@@ -126,9 +126,13 @@ def update(self, **kwargs) -> None:
126126
127127 mask = kwargs .get ("mask" , None )
128128 if mask is not None :
129+ if self .w .is_sparse :
130+ raise Exception ("Mask isn't supported for SparseConnection" )
129131 self .w .masked_fill_ (mask , 0 )
130132
131133 if self .Dales_rule is not None :
134+ if self .w .is_sparse :
135+ raise Exception ("Dales_rule isn't supported for SparseConnection" )
132136 # weight that are negative and should be positive are set to 0
133137 self .w [self .w < 0 * self .Dales_rule .to (torch .float )] = 0
134138 # weight that are positive and should be negative are set to 0
@@ -1947,105 +1951,12 @@ def reset_state_variables(self) -> None:
19471951 super ().reset_state_variables ()
19481952
19491953
1950- class SparseConnection (AbstractConnection ):
1954+ class SparseConnection (Connection ):
19511955 # language=rst
19521956 """
19531957 Specifies sparse synapses between one or two populations of neurons.
19541958 """
19551959
1956- def __init__ (
1957- self ,
1958- source : Nodes ,
1959- target : Nodes ,
1960- nu : Optional [Union [float , Sequence [float ], Sequence [torch .Tensor ]]] = None ,
1961- reduction : Optional [callable ] = None ,
1962- weight_decay : float = None ,
1963- ** kwargs ,
1964- ) -> None :
1965- # language=rst
1966- """
1967- Instantiates a :code:`Connection` object with sparse weights.
1968-
1969- :param source: A layer of nodes from which the connection originates.
1970- :param target: A layer of nodes to which the connection connects.
1971- :param nu: Learning rate for both pre- and post-synaptic events. It also
1972- accepts a pair of tensors to individualize learning rates of each neuron.
1973- In this case, their shape should be the same size as the connection weights.
1974- :param reduction: Method for reducing parameter updates along the minibatch
1975- dimension.
1976- :param weight_decay: Constant multiple to decay weights by on each iteration.
1977-
1978- Keyword arguments:
1979-
1980- :param torch.Tensor w: Strengths of synapses. Must be in ``torch.sparse`` format
1981- :param float sparsity: Fraction of sparse connections to use.
1982- :param LearningRule update_rule: Modifies connection parameters according to
1983- some rule.
1984- :param float wmin: Minimum allowed value on the connection weights.
1985- :param float wmax: Maximum allowed value on the connection weights.
1986- :param float norm: Total weight per target neuron normalization constant.
1987- """
1988- super ().__init__ (source , target , nu , reduction , weight_decay , ** kwargs )
1989-
1990- w = kwargs .get ("w" , None )
1991- self .sparsity = kwargs .get ("sparsity" , None )
1992-
1993- assert (
1994- w is not None
1995- and self .sparsity is None
1996- or w is None
1997- and self .sparsity is not None
1998- ), 'Only one of "weights" or "sparsity" must be specified'
1999-
2000- if w is None and self .sparsity is not None :
2001- i = torch .bernoulli (
2002- 1 - self .sparsity * torch .ones (* source .shape , * target .shape )
2003- )
2004- if (self .wmin == - np .inf ).any () or (self .wmax == np .inf ).any ():
2005- v = torch .clamp (
2006- torch .rand (* source .shape , * target .shape ), self .wmin , self .wmax
2007- )[i .bool ()]
2008- else :
2009- v = (
2010- self .wmin
2011- + torch .rand (* source .shape , * target .shape ) * (self .wmax - self .wmin )
2012- )[i .bool ()]
2013- w = torch .sparse .FloatTensor (i .nonzero ().t (), v )
2014- elif w is not None and self .sparsity is None :
2015- assert w .is_sparse , "Weight matrix is not sparse (see torch.sparse module)"
2016- if self .wmin != - np .inf or self .wmax != np .inf :
2017- w = torch .clamp (w , self .wmin , self .wmax )
2018-
2019- self .w = Parameter (w , requires_grad = False )
2020-
2021- def compute (self , s : torch .Tensor ) -> torch .Tensor :
2022- # language=rst
2023- """
2024- Compute convolutional pre-activations given spikes using layer weights.
2025-
2026- :param s: Incoming spikes.
2027- :return: Incoming spikes multiplied by synaptic weights (with or without
2028- decaying spike activation).
2029- """
2030- return torch .mm (self .w , s .view (s .shape [1 ], 1 ).float ()).squeeze (- 1 )
2031- # return torch.mm(self.w, s.unsqueeze(-1).float()).squeeze(-1)
2032-
2033- def update (self , ** kwargs ) -> None :
2034- # language=rst
2035- """
2036- Compute connection's update rule.
2037- """
2038-
2039- def normalize (self ) -> None :
2040- # language=rst
2041- """
2042- Normalize weights along the first axis according to total weight per target
2043- neuron.
2044- """
2045-
2046- def reset_state_variables (self ) -> None :
2047- # language=rst
2048- """
2049- Contains resetting logic for the connection.
2050- """
2051- super ().reset_state_variables ()
1960+ def __init__ (self , * args , ** kwargs ):
1961+ super ().__init__ (* args , ** kwargs )
1962+ self .w = Parameter (self .w .to_sparse (), requires_grad = False )