@@ -26,6 +26,11 @@ namespace NeuralAudio
2626 class Conv1DT
2727 {
2828 public:
29+ size_t GetNumWeights ()
30+ {
31+ return OutChannels * InChannels * KernelSize + (DoBias ? OutChannels : 0 );
32+ }
33+
2934 void SetWeights (std::vector<float >::iterator& inWeights)
3035 {
3136 weights.resize (KernelSize);
@@ -70,6 +75,11 @@ namespace NeuralAudio
7075 class DenseLayerT
7176 {
7277 public:
78+ size_t GetNumWeights ()
79+ {
80+ return OutSize * InSize + (DoBias ? OutSize : 0 );
81+ }
82+
7383 void SetWeights (std::vector<float >::iterator& inWeights)
7484 {
7585 for (size_t i = 0 ; i < OutSize; i++)
@@ -168,6 +178,11 @@ namespace NeuralAudio
168178#endif
169179 }
170180
181+ size_t GetNumWeights ()
182+ {
183+ return conv1D.GetNumWeights () + inputMixin.GetNumWeights () + oneByOne.GetNumWeights ();
184+ }
185+
171186 void SetWeights (std::vector<float >::iterator& weights)
172187 {
173188 conv1D.SetWeights (weights);
@@ -284,6 +299,20 @@ namespace NeuralAudio
284299 return allocNum;
285300 }
286301
302+ size_t GetNumWeights ()
303+ {
304+ size_t numWeights = rechannel.GetNumWeights ();
305+
306+ ForEachIndex<numLayers>([&](auto layerIndex)
307+ {
308+ numWeights += std::get<layerIndex>(layers).GetNumWeights ();
309+ });
310+
311+ numWeights += headRechannel.GetNumWeights ();
312+
313+ return numWeights;
314+ }
315+
287316 void SetWeights (std::vector<float >::iterator& weights)
288317 {
289318 rechannel.SetWeights (weights);
@@ -357,26 +386,41 @@ namespace NeuralAudio
357386
358387 allocNum = std::get<layerIndex>(layerArrays).AllocBuffers (allocNum);
359388 });
360-
361389 }
362390
363- void SetWeights (std::vector< float > weights )
391+ size_t GetNumWeights ( )
364392 {
365- std::vector< float >::iterator it = weights. begin () ;
393+ size_t numWeights = 0 ;
366394
367395 ForEachIndex<sizeof ...(LayerArrays)>([&](auto layerIndex)
368396 {
369- std::get<layerIndex>(layerArrays).SetWeights (it );
397+ numWeights += std::get<layerIndex>(layerArrays).GetNumWeights ( );
370398 });
371399
372- headScale = *(it++);
400+ numWeights++; // headScale;
401+
402+ return numWeights;
403+ }
404+
405+ void SetWeights (std::vector<float > weights)
406+ {
407+ size_t numWeights = GetNumWeights ();
373408
374- if (std::distance (weights. begin (), it) != ( long ) weights.size ())
409+ if (numWeights != weights.size ())
375410 {
376411 std::stringstream str;
377- str << " Wrong number of weights. Remaining: " << std::distance ( weights.begin (), it );
412+ str << " Wrong number of weights. Expected " << numWeights << " but got " << weights.size ( );
378413 throw std::runtime_error (str.str ());
379414 }
415+
416+ std::vector<float >::iterator it = weights.begin ();
417+
418+ ForEachIndex<sizeof ...(LayerArrays)>([&](auto layerIndex)
419+ {
420+ std::get<layerIndex>(layerArrays).SetWeights (it);
421+ });
422+
423+ headScale = *(it++);
380424 }
381425
382426 size_t GetMaxFrames ()
0 commit comments