Skip to content

Commit 0308b27

Browse files
committed
SparseConnection support
1 parent dd03ea3 commit 0308b27

File tree

2 files changed

+36
-106
lines changed

2 files changed

+36
-106
lines changed

bindsnet/learning/learning.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,10 @@ def update(self) -> None:
9898
(self.connection.wmin != -np.inf).any()
9999
or (self.connection.wmax != np.inf).any()
100100
) and not isinstance(self, NoOp):
101-
self.connection.w.clamp_(self.connection.wmin, self.connection.wmax)
101+
if self.connection.w.is_sparse:
102+
raise Exception("SparseConnection isn't supported for wmin\\wmax")
103+
else:
104+
self.connection.w.clamp_(self.connection.wmin, self.connection.wmax)
102105

103106

104107
class NoOp(LearningRule):
@@ -396,7 +399,10 @@ def _connection_update(self, **kwargs) -> None:
396399
if self.nu[0].any():
397400
source_s = self.source.s.view(batch_size, -1).unsqueeze(2).float()
398401
target_x = self.target.x.view(batch_size, -1).unsqueeze(1) * self.nu[0]
399-
self.connection.w -= self.reduction(torch.bmm(source_s, target_x), dim=0)
402+
update = self.reduction(torch.bmm(source_s, target_x), dim=0)
403+
if self.connection.w.is_sparse:
404+
update = update.to_sparse()
405+
self.connection.w -= update
400406
del source_s, target_x
401407

402408
# Post-synaptic update.
@@ -405,7 +411,10 @@ def _connection_update(self, **kwargs) -> None:
405411
self.target.s.view(batch_size, -1).unsqueeze(1).float() * self.nu[1]
406412
)
407413
source_x = self.source.x.view(batch_size, -1).unsqueeze(2)
408-
self.connection.w += self.reduction(torch.bmm(source_x, target_s), dim=0)
414+
update = self.reduction(torch.bmm(source_x, target_s), dim=0)
415+
if self.connection.w.is_sparse:
416+
update = update.to_sparse()
417+
self.connection.w += update
409418
del source_x, target_s
410419

411420
super().update()
@@ -1113,10 +1122,14 @@ def _connection_update(self, **kwargs) -> None:
11131122

11141123
# Pre-synaptic update.
11151124
update = self.reduction(torch.bmm(source_s, target_x), dim=0)
1125+
if self.connection.w.is_sparse:
1126+
update = update.to_sparse()
11161127
self.connection.w += self.nu[0] * update
11171128

11181129
# Post-synaptic update.
11191130
update = self.reduction(torch.bmm(source_x, target_s), dim=0)
1131+
if self.connection.w.is_sparse:
1132+
update = update.to_sparse()
11201133
self.connection.w += self.nu[1] * update
11211134

11221135
super().update()
@@ -1542,8 +1555,10 @@ def _connection_update(self, **kwargs) -> None:
15421555
a_minus = torch.tensor(a_minus, device=self.connection.w.device)
15431556

15441557
# Compute weight update based on the eligibility value of the past timestep.
1545-
update = reward * self.eligibility
1546-
self.connection.w += self.nu[0] * self.reduction(update, dim=0)
1558+
update = self.reduction(reward * self.eligibility, dim=0)
1559+
if self.connection.w.is_sparse:
1560+
update = update.to_sparse()
1561+
self.connection.w += self.nu[0] * update
15471562

15481563
# Update P^+ and P^- values.
15491564
self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus)
@@ -2214,10 +2229,11 @@ def _connection_update(self, **kwargs) -> None:
22142229
self.eligibility_trace *= torch.exp(-self.connection.dt / self.tc_e_trace)
22152230
self.eligibility_trace += self.eligibility / self.tc_e_trace
22162231

2232+
update = self.nu[0] * self.connection.dt * reward * self.eligibility_trace
2233+
if self.connection.w.is_sparse:
2234+
update = update.to_sparse()
22172235
# Compute weight update.
2218-
self.connection.w += (
2219-
self.nu[0] * self.connection.dt * reward * self.eligibility_trace
2220-
)
2236+
self.connection.w += update
22212237

22222238
# Update P^+ and P^- values.
22232239
self.p_plus *= torch.exp(-self.connection.dt / self.tc_plus)
@@ -2936,6 +2952,9 @@ def _connection_update(self, **kwargs) -> None:
29362952
) * source_x[:, None]
29372953

29382954
# Compute weight update.
2939-
self.connection.w += self.nu[0] * reward * self.eligibility_trace
2955+
update = self.nu[0] * reward * self.eligibility_trace
2956+
if self.connection.w.is_sparse:
2957+
update = update.to_sparse()
2958+
self.connection.w += update
29402959

29412960
super().update()

bindsnet/network/topology.py

Lines changed: 8 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,13 @@ def update(self, **kwargs) -> None:
126126

127127
mask = kwargs.get("mask", None)
128128
if mask is not None:
129+
if self.w.is_sparse:
130+
raise Exception("Mask isn't supported for SparseConnection")
129131
self.w.masked_fill_(mask, 0)
130132

131133
if self.Dales_rule is not None:
134+
if self.w.is_sparse:
135+
raise Exception("Dales_rule isn't supported for SparseConnection")
132136
# weight that are negative and should be positive are set to 0
133137
self.w[self.w < 0 * self.Dales_rule.to(torch.float)] = 0
134138
# weight that are positive and should be negative are set to 0
@@ -1947,105 +1951,12 @@ def reset_state_variables(self) -> None:
19471951
super().reset_state_variables()
19481952

19491953

1950-
class SparseConnection(AbstractConnection):
1954+
class SparseConnection(Connection):
19511955
# language=rst
19521956
"""
19531957
Specifies sparse synapses between one or two populations of neurons.
19541958
"""
19551959

1956-
def __init__(
1957-
self,
1958-
source: Nodes,
1959-
target: Nodes,
1960-
nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
1961-
reduction: Optional[callable] = None,
1962-
weight_decay: float = None,
1963-
**kwargs,
1964-
) -> None:
1965-
# language=rst
1966-
"""
1967-
Instantiates a :code:`Connection` object with sparse weights.
1968-
1969-
:param source: A layer of nodes from which the connection originates.
1970-
:param target: A layer of nodes to which the connection connects.
1971-
:param nu: Learning rate for both pre- and post-synaptic events. It also
1972-
accepts a pair of tensors to individualize learning rates of each neuron.
1973-
In this case, their shape should be the same size as the connection weights.
1974-
:param reduction: Method for reducing parameter updates along the minibatch
1975-
dimension.
1976-
:param weight_decay: Constant multiple to decay weights by on each iteration.
1977-
1978-
Keyword arguments:
1979-
1980-
:param torch.Tensor w: Strengths of synapses. Must be in ``torch.sparse`` format
1981-
:param float sparsity: Fraction of sparse connections to use.
1982-
:param LearningRule update_rule: Modifies connection parameters according to
1983-
some rule.
1984-
:param float wmin: Minimum allowed value on the connection weights.
1985-
:param float wmax: Maximum allowed value on the connection weights.
1986-
:param float norm: Total weight per target neuron normalization constant.
1987-
"""
1988-
super().__init__(source, target, nu, reduction, weight_decay, **kwargs)
1989-
1990-
w = kwargs.get("w", None)
1991-
self.sparsity = kwargs.get("sparsity", None)
1992-
1993-
assert (
1994-
w is not None
1995-
and self.sparsity is None
1996-
or w is None
1997-
and self.sparsity is not None
1998-
), 'Only one of "weights" or "sparsity" must be specified'
1999-
2000-
if w is None and self.sparsity is not None:
2001-
i = torch.bernoulli(
2002-
1 - self.sparsity * torch.ones(*source.shape, *target.shape)
2003-
)
2004-
if (self.wmin == -np.inf).any() or (self.wmax == np.inf).any():
2005-
v = torch.clamp(
2006-
torch.rand(*source.shape, *target.shape), self.wmin, self.wmax
2007-
)[i.bool()]
2008-
else:
2009-
v = (
2010-
self.wmin
2011-
+ torch.rand(*source.shape, *target.shape) * (self.wmax - self.wmin)
2012-
)[i.bool()]
2013-
w = torch.sparse.FloatTensor(i.nonzero().t(), v)
2014-
elif w is not None and self.sparsity is None:
2015-
assert w.is_sparse, "Weight matrix is not sparse (see torch.sparse module)"
2016-
if self.wmin != -np.inf or self.wmax != np.inf:
2017-
w = torch.clamp(w, self.wmin, self.wmax)
2018-
2019-
self.w = Parameter(w, requires_grad=False)
2020-
2021-
def compute(self, s: torch.Tensor) -> torch.Tensor:
2022-
# language=rst
2023-
"""
2024-
Compute convolutional pre-activations given spikes using layer weights.
2025-
2026-
:param s: Incoming spikes.
2027-
:return: Incoming spikes multiplied by synaptic weights (with or without
2028-
decaying spike activation).
2029-
"""
2030-
return torch.mm(self.w, s.view(s.shape[1], 1).float()).squeeze(-1)
2031-
# return torch.mm(self.w, s.unsqueeze(-1).float()).squeeze(-1)
2032-
2033-
def update(self, **kwargs) -> None:
2034-
# language=rst
2035-
"""
2036-
Compute connection's update rule.
2037-
"""
2038-
2039-
def normalize(self) -> None:
2040-
# language=rst
2041-
"""
2042-
Normalize weights along the first axis according to total weight per target
2043-
neuron.
2044-
"""
2045-
2046-
def reset_state_variables(self) -> None:
2047-
# language=rst
2048-
"""
2049-
Contains resetting logic for the connection.
2050-
"""
2051-
super().reset_state_variables()
1960+
def __init__(self, *args, **kwargs):
1961+
super().__init__(*args, **kwargs)
1962+
self.w = Parameter(self.w.to_sparse(), requires_grad=False)

0 commit comments

Comments (0)