
Commit 08e02c9

ffssn

Author: Kevin Chang
1 parent 87618a1 commit 08e02c9

File tree

4 files changed: +667 −142 lines changed


bindsnet/network/topology.py

Lines changed: 128 additions & 0 deletions
@@ -2049,3 +2049,131 @@ def reset_state_variables(self) -> None:

```python
        Contains resetting logic for the connection.
        """
        super().reset_state_variables()


class ForwardForwardConnection(AbstractConnection):
    """
    Connection class designed for Forward-Forward training with an
    arctangent-style surrogate gradient.
    """

    def __init__(
        self,
        source: Nodes,
        target: Nodes,
        nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
        weight_decay: float = 0.0,
        spike_threshold: float = 1.0,
        alpha: float = 2.0,  # alpha parameter of the surrogate gradient
        **kwargs,
    ) -> None:
        # Pass by keyword: AbstractConnection takes `reduction` between `nu`
        # and `weight_decay`, so positional arguments would be misbound.
        super().__init__(source, target, nu=nu, weight_decay=weight_decay, **kwargs)

        # Initialize weights with gradient support.
        w = kwargs.get("w", None)
        if w is None:
            if (self.wmin == -np.inf).any() or (self.wmax == np.inf).any():
                w = torch.clamp(
                    torch.randn(source.n, target.n) * 0.1, self.wmin, self.wmax
                )
            else:
                # Scale a small Gaussian into [wmin, wmax]; clamp afterwards,
                # since randn is unbounded and can land outside the interval.
                w = self.wmin + (torch.randn(source.n, target.n) * 0.1) * (
                    self.wmax - self.wmin
                )
                w = torch.clamp(w, self.wmin, self.wmax)
        else:
            if (self.wmin == -np.inf).any() or (self.wmax == np.inf).any():
                w = torch.clamp(w, self.wmin, self.wmax)

        # CRITICAL: enable gradients for Forward-Forward training.
        self.w = Parameter(w, requires_grad=True)

        # Surrogate gradient parameters.
        self.spike_threshold = spike_threshold
        self.alpha = alpha

        # Membrane potential tracked across time steps for surrogate gradients.
        self.membrane_potential = None

    def atan_surrogate_spike(self, x: torch.Tensor) -> torch.Tensor:
        """
        Arctangent-style surrogate spike function.

        Forward pass: Heaviside step function shifted by the spike threshold.
        Backward pass: the step's zero-almost-everywhere derivative is replaced
        by the smooth surrogate 1 / (alpha * |x - threshold| + 1).
        """

        class AtanSurrogate(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input, threshold, alpha):
                ctx.save_for_backward(input)
                ctx.threshold = threshold
                ctx.alpha = alpha
                # Forward: Heaviside step function shifted by threshold.
                return (input > threshold).float()

            @staticmethod
            def backward(ctx, grad_output):
                (input,) = ctx.saved_tensors
                grad_input = grad_output.clone()
                # Backward: surrogate = 1 / (alpha * |input - threshold| + 1).
                surrogate_grad = 1.0 / (
                    ctx.alpha * torch.abs(input - ctx.threshold) + 1.0
                )
                # No gradients with respect to threshold and alpha.
                return grad_input * surrogate_grad, None, None

        return AtanSurrogate.apply(x, self.spike_threshold, self.alpha)

    def compute_with_surrogate(self, s: torch.Tensor) -> torch.Tensor:
        """
        Compute output spikes with arctangent-style surrogate gradients.

        :param s: Incoming spikes [batch_size, source_neurons].
        :return: Output spikes with surrogate gradients [batch_size, target_neurons].
        """
        batch_size = s.shape[0]

        # (Re)initialize the membrane potential if the batch size changed.
        if self.membrane_potential is None or self.membrane_potential.shape != (
            batch_size,
            self.target.n,
        ):
            self.membrane_potential = torch.zeros(
                batch_size, self.target.n, device=s.device
            )

        # Synaptic input: spikes @ weights.
        synaptic_input = torch.mm(s.float(), self.w)

        # Simple LIF dynamics with exponential decay.
        decay_factor = 0.9  # Could be made configurable.
        self.membrane_potential = decay_factor * self.membrane_potential + synaptic_input

        # Generate spikes with surrogate gradients.
        spikes = self.atan_surrogate_spike(self.membrane_potential)

        # Soft reset: subtract the threshold where spikes occurred.
        self.membrane_potential = self.membrane_potential - spikes * self.spike_threshold

        return spikes

    def compute(self, s: torch.Tensor) -> torch.Tensor:
        """
        Standard compute method (delegates to compute_with_surrogate for FF training).
        """
        return self.compute_with_surrogate(s)

    def reset_membrane_potential(self) -> None:
        """Reset the membrane potential (call between samples/batches)."""
        self.membrane_potential = None

    def update(self, **kwargs) -> None:
        """
        Override the standard BindsNET update: Forward-Forward training uses
        PyTorch optimizers for weight updates, so the built-in learning-rule
        updates are not needed.
        """
        pass

    def normalize(self) -> None:
        """
        Normalize weights so that the absolute weights afferent to each target
        neuron sum to self.norm.
        """
        if self.norm is not None:
            w_abs_sum = self.w.abs().sum(0).unsqueeze(0)
            w_abs_sum[w_abs_sum == 0] = 1.0
            self.w.data *= self.norm / w_abs_sum

    def reset_state_variables(self) -> None:
        """
        Contains resetting logic for the connection.
        """
        super().reset_state_variables()
        self.reset_membrane_potential()
```
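The surrogate's contract is easiest to see in isolation: the forward pass is a hard step at the threshold, while the backward pass substitutes the smooth kernel 1 / (alpha * |u - threshold| + 1). A minimal standalone sketch, re-stating the `AtanSurrogate` function from the diff above so it runs without BindsNET:

```python
import torch


class AtanSurrogate(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, threshold, alpha):
        ctx.save_for_backward(input)
        ctx.threshold = threshold
        ctx.alpha = alpha
        return (input > threshold).float()  # hard step in the forward pass

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        # Smooth stand-in for the step's derivative.
        surrogate = 1.0 / (ctx.alpha * torch.abs(input - ctx.threshold) + 1.0)
        return grad_output * surrogate, None, None


x = torch.linspace(-2.0, 4.0, steps=7, requires_grad=True)
y = AtanSurrogate.apply(x, 1.0, 2.0)  # threshold = 1.0, alpha = 2.0
y.sum().backward()
print(y)       # step: 0 at or below the threshold, 1 above it
print(x.grad)  # smooth: peaks at 1.0 where x == threshold, decays as 1/(2|x-1|+1)
```

Larger `alpha` narrows the surrogate around the threshold, trading gradient reach for sharpness.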

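A hypothetical end-to-end usage sketch, not part of the commit: drive the connection with Bernoulli-encoded inputs and update `conn.w` with a PyTorch optimizer on a Forward-Forward-style goodness objective (sum of squared output activity, pushed above a margin for positive data and below it for negative data, in the spirit of Hinton's Forward-Forward recipe). The synthetic `data`, the margin `theta`, and the 25-step simulation window are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

from bindsnet.network.nodes import Input, LIFNodes
from bindsnet.network.topology import ForwardForwardConnection

source = Input(n=784)
target = LIFNodes(n=256)
conn = ForwardForwardConnection(source, target, spike_threshold=1.0, alpha=2.0)
optimizer = torch.optim.Adam([conn.w], lr=1e-3)


def goodness(conn, x, steps=25):
    # Sum output spikes over Bernoulli-encoded time steps, then take the
    # squared-activity "goodness" per sample; gradients flow through the
    # surrogate at every step.
    conn.reset_membrane_potential()
    counts = 0.0
    for _ in range(steps):
        counts = counts + conn.compute(torch.bernoulli(x))
    return counts.pow(2).sum(dim=1)


theta = 2.0  # assumed goodness margin
# Stand-in positive/negative batches; a real FF pipeline supplies these pairs.
data = [(torch.rand(32, 784), torch.rand(32, 784))]

for pos_x, neg_x in data:
    optimizer.zero_grad()
    loss = (
        F.softplus(theta - goodness(conn, pos_x))   # raise goodness on positives
        + F.softplus(goodness(conn, neg_x) - theta)  # lower it on negatives
    ).mean()
    loss.backward()
    optimizer.step()
```

In the full Forward-Forward recipe each layer has its own local objective, so a stack of these connections would repeat this loop layer by layer rather than backpropagating through the whole network.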