Skip to content

Commit e04085f

Browse files
authored
Improve data loading and sampling for SGD (#296)
* Ensures source and target IDs are paired up in SGD for map-like sources * Can validate map-like sources * Can use a simple sampling scheme for non-map-like sources
1 parent efb0813 commit e04085f

File tree

12 files changed

+538
-76
lines changed

12 files changed

+538
-76
lines changed

include/flucoma/algorithms/public/MLP.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@ class MLP
3030
using ArrayXXd = Eigen::ArrayXXd;
3131

3232
public:
33-
explicit MLP() = default;
34-
~MLP() = default;
3533

3634
void init(index inputSize, index outputSize,
3735
FluidTensor<index, 1> hiddenSizes, index hiddenAct, index outputAct)

include/flucoma/algorithms/public/SGD.hpp

Lines changed: 40 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@ under the European Union’s Horizon 2020 research and innovation programme
1515
#include "../../data/FluidDataSet.hpp"
1616
#include "../../data/FluidIndex.hpp"
1717
#include "../../data/FluidTensor.hpp"
18+
#include "../../data/SimpleDataSampler.hpp"
1819
#include "../../data/TensorTypes.hpp"
1920
#include <Eigen/Core>
2021
#include <limits>
21-
#include <random>
2222

2323
namespace fluid {
2424
namespace algorithm {
@@ -27,72 +27,63 @@ class SGD
2727
{
2828
using ArrayXd = Eigen::ArrayXd;
2929
using ArrayXXd = Eigen::ArrayXXd;
30-
using Permutation = Eigen::PermutationMatrix<Eigen::Dynamic, Eigen::Dynamic>;
3130

3231
public:
33-
explicit SGD() = default;
34-
~SGD() = default;
3532

36-
double train(MLP& model, const RealMatrixView in, RealMatrixView out,
33+
double train(MLP& model, InputRealMatrixView in, RealMatrixView out,
3734
index nIter, index batchSize, double learningRate,
3835
double momentum, double valFrac)
36+
{
37+
return train(model, in, out,
38+
SimpleDataSampler(in.rows(), batchSize, valFrac, true), nIter,
39+
learningRate, momentum);
40+
}
41+
42+
template <typename Sampler>
43+
double train(MLP& model, InputRealMatrixView in, RealMatrixView out,
44+
Sampler&& loader, index nIter, double learningRate,
45+
double momentum)
3946
{
4047
using namespace _impl;
4148
using namespace std;
4249
using namespace Eigen;
43-
index nExamples = in.rows();
44-
index inputSize = in.cols();
45-
index outputSize = out.cols();
46-
ArrayXXd input = asEigen<Eigen::Array>(in);
47-
ArrayXXd output = asEigen<Eigen::Array>(out);
48-
Permutation valPerm(nExamples);
49-
valPerm.setIdentity();
50-
shuffle(valPerm.indices().data(),
51-
valPerm.indices().data() + valPerm.indices().size(),
52-
mt19937{random_device{}()});
53-
input = valPerm * input.matrix();
54-
output = valPerm * output.matrix();
55-
index nVal = std::lround(nExamples * valFrac);
56-
index nTrain = nExamples - nVal;
50+
index nExamples = in.rows();
51+
// index inputSize = in.cols();
52+
index outputSize = out.cols();
5753

58-
ArrayXXd trainInput = input.block(0, 0, nTrain, inputSize);
59-
ArrayXXd trainOutput = output.block(0, 0, nTrain, outputSize);
60-
ArrayXXd valInput = input.block(nTrain, 0, nVal, inputSize);
61-
ArrayXXd valOutput = output.block(nTrain, 0, nVal, outputSize);
54+
auto valIdx = loader.validationSet();
55+
std::optional<ArrayXXd> valInput;
56+
std::optional<ArrayXXd> valOutput;
57+
if (valIdx)
58+
{
59+
valInput =
60+
ArrayXXd(asEigen<Eigen::Array>(in)(valIdx->col(0), Eigen::all));
61+
valOutput =
62+
ArrayXXd(asEigen<Eigen::Array>(out)(valIdx->col(1), Eigen::all));
63+
}
6264

63-
Permutation iterPerm(nTrain);
64-
iterPerm.setIdentity();
6565
double error = 0;
6666
index patience = mInitialPatience;
6767
double prevValLoss = std::numeric_limits<double>::max();
6868
while (nIter-- > 0)
6969
{
70-
shuffle(iterPerm.indices().data(),
71-
iterPerm.indices().data() + iterPerm.indices().size(),
72-
mt19937{random_device{}()});
73-
ArrayXXd inPerm = iterPerm * trainInput.matrix();
74-
ArrayXXd outPerm = iterPerm * trainOutput.matrix();
75-
for (index batchStart = 0; batchStart < inPerm.rows();
76-
batchStart += batchSize)
70+
for (auto batch : loader)
7771
{
78-
index thisBatchSize = (batchStart + batchSize) <= nTrain
79-
? batchSize
80-
: nTrain - batchStart;
81-
ArrayXXd batchIn =
82-
inPerm.block(batchStart, 0, thisBatchSize, inPerm.cols());
72+
index thisBatchSize = batch->rows();
73+
ArrayXXd batchIn = asEigen<Eigen::Array>(in)(batch->col(0), Eigen::all);
8374
ArrayXXd batchOut =
84-
outPerm.block(batchStart, 0, thisBatchSize, outPerm.cols());
75+
asEigen<Eigen::Array>(out)(batch->col(1), Eigen::all);
8576
ArrayXXd batchPred = ArrayXXd::Zero(thisBatchSize, outputSize);
8677
model.forward(batchIn, batchPred);
8778
ArrayXXd diff = batchPred - batchOut;
8879
model.backward(diff);
8980
model.update(learningRate, momentum);
9081
}
91-
if (nVal > 0)
82+
if (valIdx)
9283
{
93-
ArrayXXd valPred = ArrayXXd::Zero(nVal, outputSize);
94-
model.forward(valInput, valPred);
95-
double valLoss = model.loss(valPred, valOutput);
84+
ArrayXXd valPred = ArrayXXd::Zero(valInput->rows(), outputSize);
85+
model.forward(*valInput, valPred);
86+
double valLoss = model.loss(valPred, *valOutput);
9687
if (valLoss < prevValLoss)
9788
patience = mInitialPatience;
9889
else
@@ -101,6 +92,13 @@ class SGD
10192
prevValLoss = valLoss;
10293
}
10394
}
95+
96+
auto trainingIdx = loader.trainingSet();
97+
nExamples = trainingIdx->rows();
98+
ArrayXXd input = asEigen<Eigen::Array>(in)(trainingIdx->col(0), Eigen::all);
99+
ArrayXXd output =
100+
asEigen<Eigen::Array>(out)(trainingIdx->col(1), Eigen::all);
101+
104102
ArrayXXd finalPred = ArrayXXd::Zero(nExamples, outputSize);
105103
model.forward(input, finalPred);
106104
bool isNan = !((finalPred == finalPred)).all();

include/flucoma/clients/nrt/CommonResults.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ static const std::string FileRead{"Couldn't read file"};
4040
static const std::string FileWrite{"Couldn't write file"};
4141
static const std::string NotImplemented{"Not implemented"};
4242
static const std::string SizesDontMatch{"Sizes do not match"};
43+
static const std::string TooFewOutputPoints{"Not enough output points"};
4344
static const std::string DimensionsDontMatch{"Dimensions do not match"};
4445

4546
template <typename T>

include/flucoma/clients/nrt/MLPClassifierClient.hpp

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ under the European Union’s Horizon 2020 research and innovation programme
1616
#include "../../algorithms/public/LabelSetEncoder.hpp"
1717
#include "../../algorithms/public/MLP.hpp"
1818
#include "../../algorithms/public/SGD.hpp"
19+
#include "../../data/FluidDataSetSampler.hpp"
1920
#include <string>
2021

2122
namespace fluid {
@@ -149,8 +150,8 @@ class MLPClassifierClient : public FluidBaseClient,
149150
if (!targetClientPtr) return Error<double>(NoLabelSet);
150151
auto targetDataSet = targetClientPtr->getLabelSet();
151152
if (targetDataSet.size() == 0) return Error<double>(EmptyLabelSet);
152-
if (sourceDataSet.size() != targetDataSet.size())
153-
return Error<double>(SizesDontMatch);
153+
if (sourceDataSet.size() > targetDataSet.size())
154+
return Error<double>(TooFewOutputPoints);
154155

155156
mAlgorithm.encoder.fit(targetDataSet);
156157

@@ -162,24 +163,40 @@ class MLPClassifierClient : public FluidBaseClient,
162163
mAlgorithm.encoder.numLabels(), get<kHidden>(),
163164
get<kActivation>(), 1); // sigmoid output
164165
}
165-
mAlgorithm.mlp.setTrained(false);
166-
DataSet result(1);
167-
auto data = sourceDataSet.getData();
168-
auto tgt = targetDataSet.getData();
169-
170-
RealMatrix oneHot(targetDataSet.size(), mAlgorithm.encoder.numLabels());
171-
oneHot.fill(0);
172-
for (index i = 0; i < targetDataSet.size(); i++)
166+
167+
if (auto missingIDs = sourceDataSet.checkIDs(targetDataSet);
168+
missingIDs.size() == 0)
173169
{
174-
mAlgorithm.encoder.encodeOneHot(tgt.row(i)(0), oneHot.row(i));
175-
}
170+
mAlgorithm.mlp.setTrained(false);
171+
DataSet result(1);
172+
auto data = sourceDataSet.getData();
173+
auto tgt = targetDataSet.getData();
174+
175+
RealMatrix oneHot(targetDataSet.size(), mAlgorithm.encoder.numLabels());
176+
oneHot.fill(0);
177+
for (index i = 0; i < targetDataSet.size(); i++)
178+
{
179+
mAlgorithm.encoder.encodeOneHot(tgt.row(i)(0), oneHot.row(i));
180+
}
181+
182+
FluidDataSetSampler sampler(sourceDataSet, targetDataSet,
183+
get<kBatchSize>(), get<kVal>(), true);
176184

177-
algorithm::SGD sgd;
178-
double error =
179-
sgd.train(mAlgorithm.mlp, data, oneHot, get<kIter>(), get<kBatchSize>(),
180-
get<kRate>(), get<kMomentum>(), get<kVal>());
185+
algorithm::SGD sgd;
186+
double error = sgd.train(mAlgorithm.mlp, data, oneHot, sampler, get<kIter>(),
187+
get<kRate>(), get<kMomentum>());
181188

182-
return error;
189+
return error;
190+
}
191+
else
192+
{
193+
std::ostringstream oss;
194+
oss << "Can't train because these IDs are missing from the output data: ";
195+
std::copy(missingIDs.begin(), missingIDs.end() - 1,
196+
std::ostream_iterator<std::string>(oss, ","));
197+
oss << missingIDs.back();
198+
return {Result::Status::kError, oss.str()};
199+
}
183200
}
184201

185202
MessageResult<void> predict(InputDataSetClientRef srcClient,

include/flucoma/clients/nrt/MLPRegressorClient.hpp

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@ under the European Union’s Horizon 2020 research and innovation programme
1414
#include "NRTClient.hpp"
1515
#include "../../algorithms/public/MLP.hpp"
1616
#include "../../algorithms/public/SGD.hpp"
17+
#include "../../data/FluidDataSetSampler.hpp"
1718
#include <string>
1819

20+
1921
namespace fluid {
2022
namespace client {
2123
namespace mlpregressor {
@@ -108,8 +110,8 @@ class MLPRegressorClient : public FluidBaseClient,
108110
if (!targetClientPtr) return Error<double>(NoDataSet);
109111
auto targetDataSet = targetClientPtr->getDataSet();
110112
if (targetDataSet.size() == 0) return Error<double>(EmptyDataSet);
111-
if (sourceDataSet.size() != targetDataSet.size())
112-
return Error<double>(SizesDontMatch);
113+
if (sourceDataSet.size() > targetDataSet.size())
114+
return Error<double>(TooFewOutputPoints);
113115
index outputAct = get<kOutputActivation>() == -1 ? get<kActivation>()
114116
: get<kOutputActivation>();
115117
if (!mAlgorithm.initialized() ||
@@ -121,15 +123,29 @@ class MLPRegressorClient : public FluidBaseClient,
121123
get<kHidden>(), get<kActivation>(), outputAct);
122124
}
123125

124-
mAlgorithm.setTrained(false);
125-
DataSet result(1);
126-
auto data = sourceDataSet.getData();
127-
auto tgt = targetDataSet.getData();
128-
algorithm::SGD sgd;
129-
double error =
130-
sgd.train(mAlgorithm, data, tgt, get<kIter>(), get<kBatchSize>(),
131-
get<kRate>(), get<kMomentum>(), get<kVal>());
132-
return error;
126+
if (auto missingIDs = sourceDataSet.checkIDs(targetDataSet);
127+
missingIDs.size() == 0)
128+
{
129+
mAlgorithm.setTrained(false);
130+
DataSet result(1);
131+
auto data = sourceDataSet.getData();
132+
auto tgt = targetDataSet.getData();
133+
FluidDataSetSampler sampler(sourceDataSet, targetDataSet,
134+
get<kBatchSize>(), get<kVal>(), true);
135+
algorithm::SGD sgd;
136+
double error = sgd.train(mAlgorithm, data, tgt, sampler, get<kIter>(),
137+
get<kRate>(), get<kMomentum>());
138+
return error;
139+
}
140+
else
141+
{
142+
std::ostringstream oss;
143+
oss << "Can't train because these IDs are missing from the output data: ";
144+
std::copy(missingIDs.begin(), missingIDs.end() - 1,
145+
std::ostream_iterator<std::string>(oss, ","));
146+
oss << missingIDs.back();
147+
return {Result::Status::kError, oss.str()};
148+
}
133149
}
134150

135151
MessageResult<void> predict(InputDataSetClientRef srcClient,

include/flucoma/data/FluidDataSet.hpp

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ namespace fluid {
1313
template <typename idType, typename dataType, index N>
1414
class FluidDataSet
1515
{
16+
template<typename, typename, index>
17+
friend class FluidDataSet;
1618

1719
public:
1820
explicit FluidDataSet() = default;
@@ -64,7 +66,7 @@ class FluidDataSet
6466
}
6567
}
6668

67-
bool add(idType const& id, FluidTensorView<dataType, N> point)
69+
bool add(idType const& id, FluidTensorView<const dataType, N> point)
6870
{
6971
assert(sameExtents(mDim, point.descriptor()));
7072
index pos = mData.rows();
@@ -196,6 +198,40 @@ class FluidDataSet
196198
return result.str();
197199
}
198200

201+
template<typename T, index M>
202+
auto indexMap(FluidDataSet<idType, T, M> const& x) const
203+
-> std::pair<std::vector<index>,std::vector<index>>
204+
{
205+
using std::pair, std::vector, std::begin, std::end;
206+
207+
pair<vector<index>, vector<index>> result;
208+
result.first.reserve(asUnsigned(x.size()));
209+
result.second.reserve(asUnsigned(x.size()));
210+
211+
auto firstID = begin(x.getIds());
212+
auto lastID = end(x.getIds());
213+
214+
std::transform(firstID, lastID, std::back_inserter(result.first),
215+
[this](auto const& id) { return mIndex.at(id); });
216+
std::transform(firstID, lastID, std::back_inserter(result.second),
217+
[&x](auto const& id) { return x.mIndex.at(id); });
218+
219+
return result;
220+
}
221+
222+
template <class U, index M>
223+
std::vector<idType> checkIDs(FluidDataSet<idType, U, M> const& other) const
224+
{
225+
std::vector<idType> result;
226+
227+
std::for_each(mIndex.begin(), mIndex.end(), [&result, &other](auto& item) {
228+
if (other.mIndex.find(item.first) == other.mIndex.end())
229+
result.push_back(item.first);
230+
});
231+
232+
return result;
233+
}
234+
199235
private:
200236
void initFromData()
201237
{
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/*
2+
Part of the Fluid Corpus Manipulation Project (http://www.flucoma.org/)
3+
Copyright University of Huddersfield.
4+
Licensed under the BSD-3 License.
5+
See license.md file in the project root for full license information.
6+
This project has received funding from the European Research Council (ERC)
7+
under the European Union’s Horizon 2020 research and innovation programme
8+
(grant agreement No 725899).
9+
*/
10+
11+
/*
12+
Pairing two unordered DataSets for supervised use
13+
*/
14+
#pragma once
15+
16+
#include "FluidDataSet.hpp"
17+
#include "FluidIndex.hpp"
18+
#include "detail/DataSampler.hpp"
19+
#include <optional>
20+
#include <random>
21+
22+
23+
namespace fluid {
24+
25+
class FluidDataSetSampler : public detail::DataSampler<FluidDataSetSampler>
26+
{
27+
friend detail::DataSampler<FluidDataSetSampler>;
28+
29+
std::pair<std::vector<index>, std::vector<index>> mIdxMaps;
30+
31+
template <class InputIter>
32+
FluidTensorView<index, 2> map(InputIter start, InputIter end,
33+
FluidTensorView<index, 2> dst)
34+
{
35+
using std::begin;
36+
auto inputSamples = dst.col(0);
37+
auto outputSamples = dst.col(1);
38+
39+
transform(start, end, begin(inputSamples),
40+
[&idx = mIdxMaps](index i) { return idx.first[asUnsigned(i)]; });
41+
transform(start, end, begin(outputSamples),
42+
[&idx = mIdxMaps](index i) { return idx.second[asUnsigned(i)]; });
43+
return dst;
44+
}
45+
46+
public:
47+
template <typename DataSetA, typename DataSetB>
48+
FluidDataSetSampler(DataSetA const& in, DataSetB const& out, index batchSize,
49+
double validationFraction, bool shuffle = true)
50+
: detail::DataSampler<FluidDataSetSampler>(in.size(), batchSize,
51+
validationFraction, shuffle),
52+
mIdxMaps{in.indexMap(out)}
53+
{}
54+
};
55+
} // namespace fluid

0 commit comments

Comments
 (0)