@@ -15,10 +15,10 @@ under the European Union’s Horizon 2020 research and innovation programme
1515#include " ../../data/FluidDataSet.hpp"
1616#include " ../../data/FluidIndex.hpp"
1717#include " ../../data/FluidTensor.hpp"
18+ #include " ../../data/SimpleDataSampler.hpp"
1819#include " ../../data/TensorTypes.hpp"
1920#include < Eigen/Core>
2021#include < limits>
21- #include < random>
2222
2323namespace fluid {
2424namespace algorithm {
@@ -27,72 +27,63 @@ class SGD
2727{
2828 using ArrayXd = Eigen::ArrayXd;
2929 using ArrayXXd = Eigen::ArrayXXd;
30- using Permutation = Eigen::PermutationMatrix<Eigen::Dynamic, Eigen::Dynamic>;
3130
3231public:
33- explicit SGD () = default;
34- ~SGD () = default ;
3532
36- double train (MLP& model, const RealMatrixView in, RealMatrixView out,
33+ double train (MLP& model, InputRealMatrixView in, RealMatrixView out,
3734 index nIter, index batchSize, double learningRate,
3835 double momentum, double valFrac)
36+ {
37+ return train (model, in, out,
38+ SimpleDataSampler (in.rows (), batchSize, valFrac, true ), nIter,
39+ learningRate, momentum);
40+ }
41+
42+ template <typename Sampler>
43+ double train (MLP& model, InputRealMatrixView in, RealMatrixView out,
44+ Sampler&& loader, index nIter, double learningRate,
45+ double momentum)
3946 {
4047 using namespace _impl ;
4148 using namespace std ;
4249 using namespace Eigen ;
43- index nExamples = in.rows ();
44- index inputSize = in.cols ();
45- index outputSize = out.cols ();
46- ArrayXXd input = asEigen<Eigen::Array>(in);
47- ArrayXXd output = asEigen<Eigen::Array>(out);
48- Permutation valPerm (nExamples);
49- valPerm.setIdentity ();
50- shuffle (valPerm.indices ().data (),
51- valPerm.indices ().data () + valPerm.indices ().size (),
52- mt19937{random_device{}()});
53- input = valPerm * input.matrix ();
54- output = valPerm * output.matrix ();
55- index nVal = std::lround (nExamples * valFrac);
56- index nTrain = nExamples - nVal;
50+ index nExamples = in.rows ();
51+ // index inputSize = in.cols();
52+ index outputSize = out.cols ();
5753
58- ArrayXXd trainInput = input.block (0 , 0 , nTrain, inputSize);
59- ArrayXXd trainOutput = output.block (0 , 0 , nTrain, outputSize);
60- ArrayXXd valInput = input.block (nTrain, 0 , nVal, inputSize);
61- ArrayXXd valOutput = output.block (nTrain, 0 , nVal, outputSize);
54+ auto valIdx = loader.validationSet ();
55+ std::optional<ArrayXXd> valInput;
56+ std::optional<ArrayXXd> valOutput;
57+ if (valIdx)
58+ {
59+ valInput =
60+ ArrayXXd (asEigen<Eigen::Array>(in)(valIdx->col (0 ), Eigen::all));
61+ valOutput =
62+ ArrayXXd (asEigen<Eigen::Array>(out)(valIdx->col (1 ), Eigen::all));
63+ }
6264
63- Permutation iterPerm (nTrain);
64- iterPerm.setIdentity ();
6565 double error = 0 ;
6666 index patience = mInitialPatience ;
6767 double prevValLoss = std::numeric_limits<double >::max ();
6868 while (nIter-- > 0 )
6969 {
70- shuffle (iterPerm.indices ().data (),
71- iterPerm.indices ().data () + iterPerm.indices ().size (),
72- mt19937{random_device{}()});
73- ArrayXXd inPerm = iterPerm * trainInput.matrix ();
74- ArrayXXd outPerm = iterPerm * trainOutput.matrix ();
75- for (index batchStart = 0 ; batchStart < inPerm.rows ();
76- batchStart += batchSize)
70+ for (auto batch : loader)
7771 {
78- index thisBatchSize = (batchStart + batchSize) <= nTrain
79- ? batchSize
80- : nTrain - batchStart;
81- ArrayXXd batchIn =
82- inPerm.block (batchStart, 0 , thisBatchSize, inPerm.cols ());
72+ index thisBatchSize = batch->rows ();
73+ ArrayXXd batchIn = asEigen<Eigen::Array>(in)(batch->col (0 ), Eigen::all);
8374 ArrayXXd batchOut =
84- outPerm. block (batchStart, 0 , thisBatchSize, outPerm. cols () );
75+ asEigen<Eigen::Array>(out)(batch-> col ( 1 ), Eigen::all );
8576 ArrayXXd batchPred = ArrayXXd::Zero (thisBatchSize, outputSize);
8677 model.forward (batchIn, batchPred);
8778 ArrayXXd diff = batchPred - batchOut;
8879 model.backward (diff);
8980 model.update (learningRate, momentum);
9081 }
91- if (nVal > 0 )
82+ if (valIdx )
9283 {
93- ArrayXXd valPred = ArrayXXd::Zero (nVal , outputSize);
94- model.forward (valInput, valPred);
95- double valLoss = model.loss (valPred, valOutput);
84+ ArrayXXd valPred = ArrayXXd::Zero (valInput-> rows () , outputSize);
85+ model.forward (* valInput, valPred);
86+ double valLoss = model.loss (valPred, * valOutput);
9687 if (valLoss < prevValLoss)
9788 patience = mInitialPatience ;
9889 else
@@ -101,6 +92,13 @@ class SGD
10192 prevValLoss = valLoss;
10293 }
10394 }
95+
96+ auto trainingIdx = loader.trainingSet ();
97+ nExamples = trainingIdx->rows ();
98+ ArrayXXd input = asEigen<Eigen::Array>(in)(trainingIdx->col (0 ), Eigen::all);
99+ ArrayXXd output =
100+ asEigen<Eigen::Array>(out)(trainingIdx->col (1 ), Eigen::all);
101+
104102 ArrayXXd finalPred = ArrayXXd::Zero (nExamples, outputSize);
105103 model.forward (input, finalPred);
106104 bool isNan = !((finalPred == finalPred)).all ();
0 commit comments