@@ -31,6 +31,7 @@ def __init__(
3131 enforce_polarity : Optional [bool ] = False ,
3232 decay : float = 0.0 ,
3333 parent_feature = None ,
34+ sparse : Optional [bool ] = False ,
3435 ** kwargs ,
3536 ) -> None :
3637 # language=rst
@@ -47,6 +48,7 @@ def __init__(
4748 dimension
4849 :param decay: Constant multiple to decay weights by on each iteration
4950 :param parent_feature: Parent feature to inherit :code:`value` from
51+ :param sparse: Should :code:`value` parameter be sparse tensor or not
5052 """
5153
5254 #### Initialize class variables ####
@@ -61,6 +63,7 @@ def __init__(
6163 self .reduction = reduction
6264 self .decay = decay
6365 self .parent_feature = parent_feature
66+ self .sparse = sparse
6467 self .kwargs = kwargs
6568
6669 ## Backend ##
@@ -119,6 +122,10 @@ def __init__(
119122 self .assert_valid_range ()
120123 if value is not None :
121124 self .assert_feature_in_range ()
125+ if self .sparse :
126+ self .value = self .value .to_sparse ()
127+ assert not getattr (self , 'enforce_polarity' , False ), \
128+ "enforce_polarity isn't supported for sparse tensors"
122129
123130 @abstractmethod
124131 def reset_state_variables (self ) -> None :
@@ -161,7 +168,10 @@ def prime_feature(self, connection, device, **kwargs) -> None:
161168
162169 # Check if values/norms are the correct shape
163170 if isinstance (self .value , torch .Tensor ):
164- assert tuple (self .value .shape ) == (connection .source .n , connection .target .n )
171+ if self .sparse :
172+ assert tuple (self .value .shape [1 :]) == (connection .source .n , connection .target .n )
173+ else :
174+ assert tuple (self .value .shape ) == (connection .source .n , connection .target .n )
165175
166176 if self .norm is not None and isinstance (self .norm , torch .Tensor ):
167177 assert self .norm .shape [0 ] == connection .target .n
@@ -214,9 +224,15 @@ def normalize(self) -> None:
214224 """
215225
216226 if self .norm is not None :
217- abs_sum = self .value .sum (0 ).unsqueeze (0 )
218- abs_sum [abs_sum == 0 ] = 1.0
219- self .value *= self .norm / abs_sum
227+ if self .sparse :
228+ abs_sum = self .value .sum (1 ).to_dense ()
229+ abs_sum [abs_sum == 0 ] = 1.0
230+ abs_sum = abs_sum .unsqueeze (1 ).expand (- 1 , * self .value .shape [1 :])
231+ self .value = self .value * (self .norm / abs_sum )
232+ else :
233+ abs_sum = self .value .sum (0 ).unsqueeze (0 )
234+ abs_sum [abs_sum == 0 ] = 1.0
235+ self .value *= self .norm / abs_sum
220236
221237 def degrade (self ) -> None :
222238 # language=rst
@@ -299,11 +315,17 @@ def assert_feature_in_range(self):
299315
300316 def assert_valid_shape (self , source_shape , target_shape , f ):
301317 # Multidimensional feat
302- if len (f .shape ) > 1 :
303- assert f .shape == (
318+ if (not self .sparse and len (f .shape ) > 1 ) or (self .sparse and len (f .shape [1 :]) > 1 ):
319+ if self .sparse :
320+ f_shape = f .shape [1 :]
321+ expected = ('batch_size' , source_shape , target_shape )
322+ else :
323+ f_shape = f .shape
324+ expected = (source_shape , target_shape )
325+ assert f_shape == (
304326 source_shape ,
305327 target_shape ,
306- ), f"Feature { self .name } has an incorrect shape of { f .shape } . Should be of shape { ( source_shape , target_shape ) } "
328+ ), f"Feature { self .name } has an incorrect shape of { f .shape } . Should be of shape { expected } "
307329 # Else assume scalar, which is a valid shape
308330
309331
@@ -319,6 +341,7 @@ def __init__(
319341 reduction : Optional [callable ] = None ,
320342 decay : float = 0.0 ,
321343 parent_feature = None ,
344+ sparse : Optional [bool ] = False
322345 ) -> None :
323346 # language=rst
324347 """
@@ -336,6 +359,7 @@ def __init__(
336359 dimension
337360 :param decay: Constant multiple to decay weights by on each iteration
338361 :param parent_feature: Parent feature to inherit :code:`value` from
362+ :param sparse: Should :code:`value` parameter be sparse tensor or not
339363 """
340364
341365 ### Assertions ###
@@ -349,10 +373,25 @@ def __init__(
349373 reduction = reduction ,
350374 decay = decay ,
351375 parent_feature = parent_feature ,
376+ sparse = sparse
377+ )
378+
379+ def sparse_bernoulli (self ):
380+ values = torch .bernoulli (self .value .values ())
381+ mask = values != 0
382+ indices = self .value .indices ()[:, mask ]
383+ non_zero = values [mask ]
384+ return torch .sparse_coo_tensor (
385+ indices ,
386+ non_zero ,
387+ self .value .size ()
352388 )
353389
354390 def compute (self , conn_spikes ) -> Union [torch .Tensor , float , int ]:
355- return conn_spikes * torch .bernoulli (self .value )
391+ if self .sparse :
392+ return conn_spikes * self .sparse_bernoulli ()
393+ else :
394+ return conn_spikes * torch .bernoulli (self .value )
356395
357396 def reset_state_variables (self ) -> None :
358397 pass
@@ -395,12 +434,14 @@ def __init__(
395434 self ,
396435 name : str ,
397436 value : Union [torch .Tensor , float , int ] = None ,
437+ sparse : Optional [bool ] = False
398438 ) -> None :
399439 # language=rst
400440 """
401441 Boolean mask which determines whether or not signals are allowed to traverse certain synapses.
402442 :param name: Name of the feature
403443 :param value: Boolean mask. :code:`True` means a signal can pass, :code:`False` means the synapse is impassable
444+ :param sparse: Should :code:`value` parameter be sparse tensor or not
404445 """
405446
406447 ### Assertions ###
@@ -419,11 +460,9 @@ def __init__(
419460 super ().__init__ (
420461 name = name ,
421462 value = value ,
463+ sparse = sparse
422464 )
423465
424- self .name = name
425- self .value = value
426-
427466 def compute (self , conn_spikes ) -> torch .Tensor :
428467 return conn_spikes * self .value
429468
@@ -505,6 +544,7 @@ def __init__(
505544 reduction : Optional [callable ] = None ,
506545 enforce_polarity : Optional [bool ] = False ,
507546 decay : float = 0.0 ,
547+ sparse : Optional [bool ] = False
508548 ) -> None :
509549 # language=rst
510550 """
@@ -523,6 +563,7 @@ def __init__(
523563 dimension
524564 :param enforce_polarity: Will prevent synapses from changing signs if :code:`True`
525565 :param decay: Constant multiple to decay weights by on each iteration
566+ :param sparse: Should :code:`value` parameter be sparse tensor or not
526567 """
527568
528569 self .norm_frequency = norm_frequency
@@ -536,6 +577,7 @@ def __init__(
536577 nu = nu ,
537578 reduction = reduction ,
538579 decay = decay ,
580+ sparse = sparse
539581 )
540582
541583 def reset_state_variables (self ) -> None :
@@ -589,6 +631,7 @@ def __init__(
589631 value : Union [torch .Tensor , float , int ] = None ,
590632 range : Optional [Sequence [float ]] = None ,
591633 norm : Optional [Union [torch .Tensor , float , int ]] = None ,
634+ sparse : Optional [bool ] = False
592635 ) -> None :
593636 # language=rst
594637 """
@@ -598,13 +641,15 @@ def __init__(
598641 :param range: Range of acceptable values for the :code:`value` parameter
599642 :param norm: Value which all values in :code:`value` will sum to. Normalization of values occurs after each sample
600643 and after the value has been updated by the learning rule (if there is one)
644+ :param sparse: Should :code:`value` parameter be sparse tensor or not
601645 """
602646
603647 super ().__init__ (
604648 name = name ,
605649 value = value ,
606650 range = [- torch .inf , + torch .inf ] if range is None else range ,
607651 norm = norm ,
652+ sparse = sparse
608653 )
609654
610655 def reset_state_variables (self ) -> None :
@@ -629,15 +674,17 @@ def __init__(
629674 name : str ,
630675 value : Union [torch .Tensor , float , int ] = None ,
631676 range : Optional [Sequence [float ]] = None ,
677+ sparse : Optional [bool ] = False
632678 ) -> None :
633679 # language=rst
634680 """
635681 Adds scalars to signals
636682 :param name: Name of the feature
637683 :param value: Values to scale signals by
684+ :param sparse: Should :code:`value` parameter be sparse tensor or not
638685 """
639686
640- super ().__init__ (name = name , value = value , range = range )
687+ super ().__init__ (name = name , value = value , range = range , sparse = sparse )
641688
642689 def reset_state_variables (self ) -> None :
643690 pass
@@ -666,6 +713,7 @@ def __init__(
666713 value : Union [torch .Tensor , float , int ] = None ,
667714 degrade_function : callable = None ,
668715 parent_feature : Optional [AbstractFeature ] = None ,
716+ sparse : Optional [bool ] = False
669717 ) -> None :
670718 # language=rst
671719 """
@@ -676,10 +724,11 @@ def __init__(
676724 :param degrade_function: Callable function which takes a single argument (:code:`value`) and returns a tensor or
677725 constant to be *subtracted* from the propagating spikes.
678726 :param parent_feature: Parent feature with desired :code:`value` to inherit
727+ :param sparse: Should :code:`value` parameter be sparse tensor or not
679728 """
680729
681730 # Note: parent_feature will override value. See abstract constructor
682- super ().__init__ (name = name , value = value , parent_feature = parent_feature )
731+ super ().__init__ (name = name , value = value , parent_feature = parent_feature , sparse = sparse )
683732
684733 self .degrade_function = degrade_function
685734
@@ -698,6 +747,7 @@ def __init__(
698747 ann_values : Union [list , tuple ] = None ,
699748 const_update_rate : float = 0.1 ,
700749 const_decay : float = 0.001 ,
750+ sparse : Optional [bool ] = False
701751 ) -> None :
702752 # language=rst
703753 """
@@ -708,6 +758,7 @@ def __init__(
708758 :param value: Values to be use to build an initial mask for the synapses.
709759 :param const_update_rate: The mask upatate rate of the ANN decision.
710760 :param const_decay: The spontaneous activation of the synapses.
761+ :param sparse: Should :code:`value` parameter be sparse tensor or not
711762 """
712763
713764 # Define the ANN
@@ -743,16 +794,18 @@ def forward(self, x):
743794 self .const_update_rate = const_update_rate
744795 self .const_decay = const_decay
745796
746- super ().__init__ (name = name , value = value )
797+ super ().__init__ (name = name , value = value , sparse = sparse )
747798
748799 def compute (self , conn_spikes ) -> Union [torch .Tensor , float , int ]:
749800
750801 # Update the spike buffer
751802 if self .start_counter == False or conn_spikes .sum () > 0 :
752803 self .start_counter = True
753- self .spike_buffer [:, self .counter % self .spike_buffer .shape [1 ]] = (
754- conn_spikes .flatten ()
755- )
804+ if self .sparse :
805+ flat_conn_spikes = conn_spikes .to_dense ().flatten ()
806+ else :
807+ flat_conn_spikes = conn_spikes .flatten ()
808+ self .spike_buffer [:, self .counter % self .spike_buffer .shape [1 ]] = flat_conn_spikes
756809 self .counter += 1
757810
758811 # Update the masks
@@ -767,6 +820,8 @@ def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:
767820
768821 # self.mask = torch.clamp(self.mask, -1, 1)
769822 self .value = (self .mask > 0 ).float ()
823+ if self .sparse :
824+ self .value = self .value .to_sparse ()
770825
771826 return conn_spikes * self .value
772827
@@ -788,6 +843,7 @@ def __init__(
788843 ann_values : Union [list , tuple ] = None ,
789844 const_update_rate : float = 0.1 ,
790845 const_decay : float = 0.01 ,
846+ sparse : Optional [bool ] = False
791847 ) -> None :
792848 # language=rst
793849 """
@@ -798,6 +854,7 @@ def __init__(
798854 :param value: Values to be use to build an initial mask for the synapses.
799855 :param const_update_rate: The mask upatate rate of the ANN decision.
800856 :param const_decay: The spontaneous activation of the synapses.
857+ :param sparse: Should :code:`value` parameter be sparse tensor or not
801858 """
802859
803860 # Define the ANN
@@ -833,16 +890,18 @@ def forward(self, x):
833890 self .const_update_rate = const_update_rate
834891 self .const_decay = const_decay
835892
836- super ().__init__ (name = name , value = value )
893+ super ().__init__ (name = name , value = value , sparse = sparse )
837894
838895 def compute (self , conn_spikes ) -> Union [torch .Tensor , float , int ]:
839896
840897 # Update the spike buffer
841898 if self .start_counter == False or conn_spikes .sum () > 0 :
842899 self .start_counter = True
843- self .spike_buffer [:, self .counter % self .spike_buffer .shape [1 ]] = (
844- conn_spikes .flatten ()
845- )
900+ if self .sparse :
901+ flat_conn_spikes = conn_spikes .to_dense ().flatten ()
902+ else :
903+ flat_conn_spikes = conn_spikes .flatten ()
904+ self .spike_buffer [:, self .counter % self .spike_buffer .shape [1 ]] = flat_conn_spikes
846905 self .counter += 1
847906
848907 # Update the masks
@@ -857,6 +916,8 @@ def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:
857916
858917 # self.mask = torch.clamp(self.mask, -1, 1)
859918 self .value = (self .mask > 0 ).float ()
919+ if self .sparse :
920+ self .value = self .value .to_sparse ()
860921
861922 return conn_spikes * self .value
862923
0 commit comments