66#include < iostream>
77#include < vector>
88
9+ namespace fluid {
10+
911TEST_CASE (" NMF is repeatable with user-supplied random seed" )
1012{
1113
12- using fluid:: algorithm::NMF;
13- using Tensor = fluid:: FluidTensor<double , 2 >;
14+ using algorithm::NMF;
15+ using Tensor = FluidTensor<double , 2 >;
1416 NMF algo;
1517
1618 Tensor input{{1 , 2 , 3 }, {4 , 5 , 6 }, {7 , 8 , 9 }};
@@ -42,3 +44,31 @@ TEST_CASE("NMF is repeatable with user-supplied random seed")
4244 REQUIRE_THAT (Vs[1 ], !RangeEquals (Vs[2 ]));
4345 }
4446}
47+
48+ TEST_CASE (" NMF processFrame() is repeatable with user-supplied random seed" )
49+ {
50+ using fluid::algorithm::NMF;
51+ using Tensor = fluid::FluidTensor<double , 2 >;
52+ using Vector = fluid::FluidTensor<double , 1 >;
53+ NMF algo;
54+
55+ Vector input{{1 , 0 , 1 , 0 }};
56+ Tensor bases{{0 , 0 , 1 , 0 }, {1 , 0 , 0 , 0 }};
57+ Vector v (4 );
58+
59+ std::vector outputs (3 , Vector (2 ));
60+
61+ index nIter{0 };
62+ algo.processFrame (input, bases, outputs[0 ], nIter, v, 42 ,
63+ FluidDefaultAllocator ());
64+ algo.processFrame (input, bases, outputs[1 ], nIter, v, 42 ,
65+ FluidDefaultAllocator ());
66+ algo.processFrame (input, bases, outputs[2 ], nIter, v, 7863 ,
67+ FluidDefaultAllocator ());
68+
69+ using Catch::Matchers::RangeEquals;
70+
71+ REQUIRE_THAT (outputs[1 ], RangeEquals (outputs[0 ]));
72+ REQUIRE_THAT (outputs[1 ], !RangeEquals (outputs[2 ]));
73+ }
74+ } // namespace fluid
0 commit comments