@@ -2049,3 +2049,131 @@ def reset_state_variables(self) -> None:
        Contains resetting logic for the connection.
        """
        super().reset_state_variables()


class ForwardForwardConnection(AbstractConnection):
    """
    Connection class specifically designed for Forward-Forward training with
    arctangent surrogate gradients.
    """

    def __init__(
        self,
        source: Nodes,
        target: Nodes,
        nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
        weight_decay: float = 0.0,
        spike_threshold: float = 1.0,
        alpha: float = 2.0,  # α parameter for arctangent surrogate
        **kwargs,
    ) -> None:
        """
        Instantiates a ``ForwardForwardConnection`` object.

        :param source: A layer of nodes from which the connection originates.
        :param target: A layer of nodes to which the connection connects.
        :param nu: Learning rate; passed to the base class. Weight updates here
            come from a PyTorch optimizer rather than a BindsNET learning rule.
        :param weight_decay: Constant multiple to decay weights by on each iteration.
        :param spike_threshold: Membrane potential threshold for spike generation.
        :param alpha: Sharpness parameter of the surrogate gradient.
        """
        super().__init__(source, target, nu, weight_decay, **kwargs)

        # Initialize weights with gradient support.
        w = kwargs.get("w", None)
        if w is None:
            if (self.wmin == -np.inf).any() or (self.wmax == np.inf).any():
                w = torch.clamp(torch.randn(source.n, target.n) * 0.1, self.wmin, self.wmax)
            else:
                # Small Gaussian init, clamped so weights stay within [wmin, wmax]
                # (randn is unbounded, so scaling alone cannot guarantee the range).
                w = torch.clamp(
                    self.wmin + (torch.randn(source.n, target.n) * 0.1) * (self.wmax - self.wmin),
                    self.wmin,
                    self.wmax,
                )
        else:
            if (self.wmin == -np.inf).any() or (self.wmax == np.inf).any():
                w = torch.clamp(torch.as_tensor(w), self.wmin, self.wmax)

        # CRITICAL: enable gradients for Forward-Forward training.
        self.w = Parameter(w, requires_grad=True)

        # Surrogate gradient parameters
        self.spike_threshold = spike_threshold
        self.alpha = alpha

        # Track membrane potential for surrogate gradients
        self.membrane_potential = None

    def atan_surrogate_spike(self, x: torch.Tensor) -> torch.Tensor:
        """
        Arctangent-style surrogate spike function.

        Forward pass: Heaviside step function shifted by the threshold.
        Backward pass: the true derivative (zero almost everywhere) is replaced
        by the smooth surrogate 1 / (α * |u - threshold| + 1).
        """

        class AtanSurrogate(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input, threshold, alpha):
                ctx.save_for_backward(input)
                ctx.threshold = threshold
                ctx.alpha = alpha
                # Forward: Heaviside step function shifted by threshold.
                return (input > threshold).float()

            @staticmethod
            def backward(ctx, grad_output):
                (input,) = ctx.saved_tensors
                grad_input = grad_output.clone()
                # Backward: surrogate = 1 / (α * |input - threshold| + 1).
                surrogate_grad = 1.0 / (ctx.alpha * torch.abs(input - ctx.threshold) + 1.0)
                # No gradients flow to the threshold and alpha arguments.
                return grad_input * surrogate_grad, None, None

        return AtanSurrogate.apply(x, self.spike_threshold, self.alpha)

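    # Gradient sanity check (a hypothetical sketch, assuming `conn` is an
    # instance of this class with spike_threshold=1.0 and alpha=2.0):
    #
    #   v = torch.tensor([0.5, 1.5], requires_grad=True)
    #   conn.atan_surrogate_spike(v).sum().backward()
    #   v.grad  # -> 1 / (2.0 * |v - 1.0| + 1.0) = tensor([0.5000, 0.5000])
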
    def compute_with_surrogate(self, s: torch.Tensor) -> torch.Tensor:
        """
        Compute output spikes with surrogate gradients, using simple LIF
        dynamics: v_t = decay * v_{t-1} + s_t @ w, spike where v_t > threshold,
        then soft reset (subtract the threshold where spikes occurred).

        :param s: Incoming spikes [batch_size, source_neurons]
        :return: Output spikes with surrogate gradients [batch_size, target_neurons]
        """
        batch_size = s.shape[0]

        # (Re)initialize membrane potential if needed (e.g., batch size changed).
        if self.membrane_potential is None or self.membrane_potential.shape != (
            batch_size,
            self.target.n,
        ):
            self.membrane_potential = torch.zeros(batch_size, self.target.n, device=s.device)

        # Synaptic input: spikes @ weights (flatten in case s carries extra dims).
        synaptic_input = torch.mm(s.view(batch_size, -1).float(), self.w)

        # Simple LIF dynamics with exponential decay (can be made configurable).
        decay_factor = 0.9
        self.membrane_potential = decay_factor * self.membrane_potential + synaptic_input

        # Generate spikes with surrogate gradients.
        spikes = self.atan_surrogate_spike(self.membrane_potential)

        # Soft reset: subtract the threshold from the membrane potential where spikes occurred.
        self.membrane_potential = self.membrane_potential - spikes * self.spike_threshold

        return spikes

    def compute(self, s: torch.Tensor) -> torch.Tensor:
        """
        Standard compute method (calls compute_with_surrogate for FF training).
        """
        return self.compute_with_surrogate(s)

    def reset_membrane_potential(self) -> None:
        """Reset membrane potential (call between samples/batches)."""
        self.membrane_potential = None

    def update(self, **kwargs) -> None:
        """
        Override the standard BindsNET update; FF uses PyTorch optimizers.
        """
        # Forward-Forward training uses PyTorch optimizers for weight updates,
        # so the standard BindsNET learning-rule updates are not needed here.
        pass
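        # Sketch of the intended external update (hypothetical training loop;
        # `goodness_loss` and `optimizer` are assumptions, not BindsNET API):
        #   spikes = conn.compute(s)
        #   goodness_loss(spikes).backward()
        #   optimizer.step()  # optimizer constructed over [conn.w]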

    def normalize(self) -> None:
        """
        Normalize weights so that the absolute values of the incoming weights
        of each target neuron sum to self.norm.
        """
        if self.norm is not None:
            w_abs_sum = self.w.abs().sum(0).unsqueeze(0)
            w_abs_sum[w_abs_sum == 0] = 1.0
            self.w.data *= self.norm / w_abs_sum

    def reset_state_variables(self) -> None:
        """
        Contains resetting logic for the connection.
        """
        super().reset_state_variables()
        self.reset_membrane_potential()
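

# Example usage (a minimal sketch, not part of the class above; `source_layer`,
# `target_layer`, `batches`, and `goodness_loss` are hypothetical placeholders):
#
#   conn = ForwardForwardConnection(source_layer, target_layer,
#                                   spike_threshold=1.0, alpha=2.0)
#   optimizer = torch.optim.Adam([conn.w], lr=1e-3)
#
#   for s in batches:                      # s: [batch_size, source_neurons]
#       conn.reset_membrane_potential()    # clear state between samples
#       spikes = conn.compute(s)           # differentiable via the surrogate
#       loss = goodness_loss(spikes)       # user-defined Forward-Forward objective
#       optimizer.zero_grad()
#       loss.backward()
#       optimizer.step()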