Skip to content

Commit 14fb7bf

Browse files
committed
Removed redundant returns. Check for correct number of weights up front.
1 parent 89fb36a commit 14fb7bf

3 files changed

Lines changed: 51 additions & 11 deletions

File tree

NeuralAudio/InternalModel.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@ namespace NeuralAudio
2020
ReadKerasConfig(modelJson);
2121

2222
return CreateModelFromKerasJson(modelJson);
23-
24-
return true;
2523
}
2624

2725
virtual bool CreateModelFromKerasJson(const nlohmann::json& modelJson)

NeuralAudio/RTNeuralModel.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,6 @@ namespace NeuralAudio
4545
ReadKerasConfig(modelJson);
4646

4747
return CreateModelFromKerasJson(modelJson);
48-
49-
return true;
5048
}
5149

5250
virtual bool CreateModelFromKerasJson(const nlohmann::json& modelJson)

NeuralAudio/WaveNet.h

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ namespace NeuralAudio
2626
class Conv1DT
2727
{
2828
public:
29+
size_t GetNumWeights()
30+
{
31+
return OutChannels * InChannels * KernelSize + (DoBias ? OutChannels : 0);
32+
}
33+
2934
void SetWeights(std::vector<float>::iterator& inWeights)
3035
{
3136
weights.resize(KernelSize);
@@ -70,6 +75,11 @@ namespace NeuralAudio
7075
class DenseLayerT
7176
{
7277
public:
78+
size_t GetNumWeights()
79+
{
80+
return OutSize * InSize + (DoBias ? OutSize : 0);
81+
}
82+
7383
void SetWeights(std::vector<float>::iterator& inWeights)
7484
{
7585
for (size_t i = 0; i < OutSize; i++)
@@ -168,6 +178,11 @@ namespace NeuralAudio
168178
#endif
169179
}
170180

181+
size_t GetNumWeights()
182+
{
183+
return conv1D.GetNumWeights() + inputMixin.GetNumWeights() + oneByOne.GetNumWeights();
184+
}
185+
171186
void SetWeights(std::vector<float>::iterator& weights)
172187
{
173188
conv1D.SetWeights(weights);
@@ -284,6 +299,20 @@ namespace NeuralAudio
284299
return allocNum;
285300
}
286301

302+
size_t GetNumWeights()
303+
{
304+
size_t numWeights = rechannel.GetNumWeights();
305+
306+
ForEachIndex<numLayers>([&](auto layerIndex)
307+
{
308+
numWeights += std::get<layerIndex>(layers).GetNumWeights();
309+
});
310+
311+
numWeights += headRechannel.GetNumWeights();
312+
313+
return numWeights;
314+
}
315+
287316
void SetWeights(std::vector<float>::iterator& weights)
288317
{
289318
rechannel.SetWeights(weights);
@@ -357,26 +386,41 @@ namespace NeuralAudio
357386

358387
allocNum = std::get<layerIndex>(layerArrays).AllocBuffers(allocNum);
359388
});
360-
361389
}
362390

363-
void SetWeights(std::vector<float> weights)
391+
size_t GetNumWeights()
364392
{
365-
std::vector<float>::iterator it = weights.begin();
393+
size_t numWeights = 0;
366394

367395
ForEachIndex<sizeof...(LayerArrays)>([&](auto layerIndex)
368396
{
369-
std::get<layerIndex>(layerArrays).SetWeights(it);
397+
numWeights += std::get<layerIndex>(layerArrays).GetNumWeights();
370398
});
371399

372-
headScale = *(it++);
400+
numWeights++; // headScale;
401+
402+
return numWeights;
403+
}
404+
405+
void SetWeights(std::vector<float> weights)
406+
{
407+
size_t numWeights = GetNumWeights();
373408

374-
if (std::distance(weights.begin(), it) != (long)weights.size())
409+
if (numWeights != weights.size())
375410
{
376411
std::stringstream str;
377-
str << "Wrong number of weights. Remaining: " << std::distance(weights.begin(), it);
412+
str << "Wrong number of weights. Expected " << numWeights << " but got " << weights.size();
378413
throw std::runtime_error(str.str());
379414
}
415+
416+
std::vector<float>::iterator it = weights.begin();
417+
418+
ForEachIndex<sizeof...(LayerArrays)>([&](auto layerIndex)
419+
{
420+
std::get<layerIndex>(layerArrays).SetWeights(it);
421+
});
422+
423+
headScale = *(it++);
380424
}
381425

382426
size_t GetMaxFrames()

0 commit comments

Comments
 (0)