Skip to content

Commit 3bc22a6

Browse files
committed
Add NMF::processFrame random seed test
1 parent 644e9a3 commit 3bc22a6

File tree

1 file changed

+32
-2
lines changed

1 file changed

+32
-2
lines changed

tests/algorithms/public/TestNMF.cpp

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66
#include <iostream>
77
#include <vector>
88

9+
namespace fluid {
10+
911
TEST_CASE("NMF is repeatable with user-supplied random seed")
1012
{
1113

12-
using fluid::algorithm::NMF;
13-
using Tensor = fluid::FluidTensor<double, 2>;
14+
using algorithm::NMF;
15+
using Tensor = FluidTensor<double, 2>;
1416
NMF algo;
1517

1618
Tensor input{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};
@@ -42,3 +44,31 @@ TEST_CASE("NMF is repeatable with user-supplied random seed")
4244
REQUIRE_THAT(Vs[1], !RangeEquals(Vs[2]));
4345
}
4446
}
47+
48+
TEST_CASE("NMF processFrame() is repeatable with user-supplied random seed")
49+
{
50+
using fluid::algorithm::NMF;
51+
using Tensor = fluid::FluidTensor<double, 2>;
52+
using Vector = fluid::FluidTensor<double, 1>;
53+
NMF algo;
54+
55+
Vector input{{1, 0, 1, 0}};
56+
Tensor bases{{0, 0, 1, 0}, {1, 0, 0, 0}};
57+
Vector v(4);
58+
59+
std::vector outputs(3, Vector(2));
60+
61+
index nIter{0};
62+
algo.processFrame(input, bases, outputs[0], nIter, v, 42,
63+
FluidDefaultAllocator());
64+
algo.processFrame(input, bases, outputs[1], nIter, v, 42,
65+
FluidDefaultAllocator());
66+
algo.processFrame(input, bases, outputs[2], nIter, v, 7863,
67+
FluidDefaultAllocator());
68+
69+
using Catch::Matchers::RangeEquals;
70+
71+
REQUIRE_THAT(outputs[1], RangeEquals(outputs[0]));
72+
REQUIRE_THAT(outputs[1], !RangeEquals(outputs[2]));
73+
}
74+
} // namespace fluid

0 commit comments

Comments
 (0)