Skip to content

Commit 0df1a42

Browse files
committed
BUG: New implementation for ComputeNeighborhoods filter
The previous implementation has issues with the binning system and will only find that a centroid is within a radius if both centroids fall into the same bin. New unit test file was created and hand verified by MAJ.
1 parent 26f0cb2 commit 0df1a42

4 files changed

Lines changed: 221 additions & 77 deletions

File tree

src/Plugins/SimplnxCore/src/SimplnxCore/Filters/Algorithms/ComputeNeighborhoods.cpp

Lines changed: 147 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,108 @@
11
#include "ComputeNeighborhoods.hpp"
22

3+
#include "simplnx/Common/Array.hpp"
4+
#include "simplnx/Common/Range.hpp"
5+
#include "simplnx/Common/Result.hpp"
6+
#include "simplnx/Common/Types.hpp"
7+
#include "simplnx/DataStructure/AbstractDataStore.hpp"
38
#include "simplnx/DataStructure/DataArray.hpp"
4-
#include "simplnx/DataStructure/DataGroup.hpp"
9+
#include "simplnx/DataStructure/DataStructure.hpp"
510
#include "simplnx/DataStructure/Geometry/ImageGeom.hpp"
11+
#include "simplnx/Filter/IFilter.hpp"
12+
#include "simplnx/Utilities/MessageHelper.hpp"
613
#include "simplnx/Utilities/ParallelDataAlgorithm.hpp"
714

8-
#include <cmath>
15+
#include <fmt/format.h>
16+
17+
#include <atomic>
18+
#include <functional>
19+
#include <mutex>
20+
#include <unordered_map>
21+
#include <vector>
922

1023
using namespace nx::core;
1124
namespace
1225
{
26+
struct BinKey
27+
{
28+
int64 bx = 0;
29+
int64 by = 0;
30+
int64 bz = 0;
31+
32+
BinKey(int64 x, int64 y, int64 z)
33+
: bx(x)
34+
, by(y)
35+
, bz(z)
36+
{
37+
}
38+
39+
BinKey(const std::vector<int64>& bins, usize tupleIndex)
40+
{
41+
bx = bins[3 * tupleIndex + 0];
42+
by = bins[3 * tupleIndex + 1];
43+
bz = bins[3 * tupleIndex + 2];
44+
}
45+
46+
bool operator==(const BinKey& other) const noexcept
47+
{
48+
return bx == other.bx && by == other.by && bz == other.bz;
49+
}
50+
};
51+
52+
struct BinKeyHasher
53+
{
54+
usize operator()(const BinKey& key) const noexcept
55+
{
56+
// simple hash combine
57+
const usize h1 = std::hash<int64>{}(key.bx);
58+
const usize h2 = std::hash<int64>{}(key.by);
59+
const usize h3 = std::hash<int64>{}(key.bz);
60+
usize seed = h1;
61+
seed ^= h2 + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
62+
seed ^= h3 + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
63+
return seed;
64+
}
65+
};
66+
1367
class ComputeNeighborhoodsImpl
1468
{
1569
public:
16-
ComputeNeighborhoodsImpl(ComputeNeighborhoods* filter, usize totalFeatures, const std::vector<int64_t>& bins, const std::vector<float>& criticalDistance, const std::atomic_bool& shouldCancel,
17-
ProgressMessageHelper& progressMessageHelper)
70+
ComputeNeighborhoodsImpl(ComputeNeighborhoods* filter, const nx::core::AbstractDataStore<float>& centroids, const std::vector<int64>& bins, float32 avgDiam, float32 multiplesOfAverage,
71+
const std::atomic_bool& shouldCancel, ProgressMessageHelper& progressMessageHelper)
1872
: m_Filter(filter)
19-
, m_TotalFeatures(totalFeatures)
73+
, m_Centroids(centroids)
2074
, m_Bins(bins)
21-
, m_CriticalDistance(criticalDistance)
75+
, m_AvgDiam(avgDiam)
76+
, m_MultiplesOfAverage(multiplesOfAverage)
2277
, m_ShouldCancel(shouldCancel)
2378
, m_ProgressMessageHelper(progressMessageHelper)
2479
{
2580
}
2681

2782
void convert(usize start, usize end) const
2883
{
29-
int64 bin1x, bin2x, bin1y, bin2y, bin1z, bin2z;
30-
float32 dBinX, dBinY, dBinZ;
31-
float32 criticalDistance1, criticalDistance2;
32-
33-
auto increment = static_cast<float64>(end - start) / 100.0;
34-
float64 incCount = 0.0;
35-
// NEVER start at 0.
36-
if(start == 0)
84+
const auto increment = static_cast<int64>((end - start) / 100.0);
85+
int64 incCount = 0.0;
86+
87+
const usize totalFeatures = m_Centroids.getNumberOfTuples();
88+
89+
// 1. Build spatial grid: BinKey -> list of features
90+
std::unordered_map<BinKey, std::vector<usize>, BinKeyHasher> binToFeatures;
91+
binToFeatures.reserve(totalFeatures);
92+
93+
for(usize i = 1; i < totalFeatures; ++i) // assuming feature 0 is background
3794
{
38-
start = 1;
95+
const BinKey key(m_Bins, i);
96+
binToFeatures[key].push_back(i);
3997
}
4098

99+
// 2. Precompute radius info
100+
const float32 radius = m_AvgDiam * m_MultiplesOfAverage / 2.0f;
101+
const float32 radiusSq = radius * radius;
102+
const int64 k = static_cast<int64>(std::ceil(m_MultiplesOfAverage));
103+
41104
ProgressMessenger progressMessenger = m_ProgressMessageHelper.createProgressMessenger();
42-
for(usize featureIdx = start; featureIdx < end; featureIdx++)
105+
for(usize i = start; i < end; i++)
43106
{
44107
incCount++;
45108
if(incCount >= increment)
@@ -52,31 +115,56 @@ class ComputeNeighborhoodsImpl
52115
{
53116
return;
54117
}
118+
// (a) Get feature's i position
119+
const float32 xi = m_Centroids[3 * i + 0];
120+
const float32 yi = m_Centroids[3 * i + 1];
121+
const float32 zi = m_Centroids[3 * i + 2];
55122

56-
bin1x = m_Bins[3 * featureIdx];
57-
bin1y = m_Bins[3 * featureIdx + 1];
58-
bin1z = m_Bins[3 * featureIdx + 2];
59-
criticalDistance1 = m_CriticalDistance[featureIdx];
123+
// (b) Get its bin
124+
const int64 bx0 = m_Bins[3 * i + 0];
125+
const int64 by0 = m_Bins[3 * i + 1];
126+
const int64 bz0 = m_Bins[3 * i + 2];
60127

61-
for(usize j = featureIdx + 1; j < m_TotalFeatures; j++)
128+
// (c) Scan all bins within +/- k in each dimension
129+
for(int64 dbx = -k; dbx <= k; ++dbx)
62130
{
63-
bin2x = m_Bins[3 * j];
64-
bin2y = m_Bins[3 * j + 1];
65-
bin2z = m_Bins[3 * j + 2];
66-
criticalDistance2 = m_CriticalDistance[j];
131+
for(int64 dby = -k; dby <= k; ++dby)
132+
{
133+
for(int64 dbz = -k; dbz <= k; ++dbz)
134+
{
135+
const BinKey nbKey{bx0 + dbx, by0 + dby, bz0 + dbz};
67136

68-
dBinX = std::abs(static_cast<float32>(bin2x - bin1x));
69-
dBinY = std::abs(static_cast<float32>(bin2y - bin1y));
70-
dBinZ = std::abs(static_cast<float32>(bin2z - bin1z));
137+
auto it = binToFeatures.find(nbKey);
138+
if(it == binToFeatures.end())
139+
{
140+
continue; // no features in this bin
141+
}
71142

72-
if(dBinX < criticalDistance1 && dBinY < criticalDistance1 && dBinZ < criticalDistance1)
73-
{
74-
m_Filter->updateNeighborHood(featureIdx, j);
75-
}
143+
const std::vector<usize>& candidates = it->second;
76144

77-
if(dBinX < criticalDistance2 && dBinY < criticalDistance2 && dBinZ < criticalDistance2)
78-
{
79-
m_Filter->updateNeighborHood(j, featureIdx);
145+
// (d) Check actual distances to candidates in this bin
146+
for(const usize j : candidates)
147+
{
148+
if(j == i)
149+
{
150+
continue; // skip self
151+
}
152+
153+
const float32 xj = m_Centroids[3 * j + 0];
154+
const float32 yj = m_Centroids[3 * j + 1];
155+
const float32 zj = m_Centroids[3 * j + 2];
156+
157+
const float32 dx = xi - xj;
158+
const float32 dy = yi - yj;
159+
const float32 dz = zi - zj;
160+
161+
const float32 distSq = dx * dx + dy * dy + dz * dz;
162+
if(distSq <= radiusSq)
163+
{
164+
m_Filter->updateNeighborHood(i, j);
165+
}
166+
}
167+
}
80168
}
81169
}
82170
}
@@ -90,9 +178,10 @@ class ComputeNeighborhoodsImpl
90178

91179
private:
92180
ComputeNeighborhoods* m_Filter = nullptr;
93-
usize m_TotalFeatures = 0;
181+
const nx::core::AbstractDataStore<float>& m_Centroids;
94182
const std::vector<int64>& m_Bins;
95-
const std::vector<float32>& m_CriticalDistance;
183+
float32 m_AvgDiam;
184+
float32 m_MultiplesOfAverage;
96185
const std::atomic_bool& m_ShouldCancel;
97186
ProgressMessageHelper& m_ProgressMessageHelper;
98187
};
@@ -111,16 +200,10 @@ ComputeNeighborhoods::ComputeNeighborhoods(DataStructure& dataStructure, const I
111200
// -----------------------------------------------------------------------------
112201
ComputeNeighborhoods::~ComputeNeighborhoods() noexcept = default;
113202

114-
// -----------------------------------------------------------------------------
115-
const std::atomic_bool& ComputeNeighborhoods::getCancel()
116-
{
117-
return m_ShouldCancel;
118-
}
119-
120203
// -----------------------------------------------------------------------------
121204
void ComputeNeighborhoods::updateNeighborHood(usize sourceIndex, usize destIndex)
122205
{
123-
const std::lock_guard<std::mutex> lock(m_Mutex);
206+
const std::scoped_lock lock(m_Mutex);
124207
(*m_Neighborhoods)[sourceIndex].inc();
125208
m_LocalNeighborhoodList[sourceIndex].push_back(static_cast<int32_t>(destIndex));
126209
}
@@ -129,60 +212,59 @@ void ComputeNeighborhoods::updateNeighborHood(usize sourceIndex, usize destIndex
129212
Result<> ComputeNeighborhoods::operator()()
130213
{
131214
// m_ProgressCounter initialized to zero on filter creation
132-
std::vector<float32> criticalDistance;
133-
134215
auto multiplesOfAverage = m_InputValues->MultiplesOfAverage;
135216
const auto& equivalentDiameters = m_DataStructure.getDataAs<Float32Array>(m_InputValues->EquivalentDiametersArrayPath)->getDataStoreRef();
136217
const auto& centroids = m_DataStructure.getDataAs<Float32Array>(m_InputValues->CentroidsArrayPath)->getDataStoreRef();
137218

138219
m_Neighborhoods = m_DataStructure.getDataAs<Int32Array>(m_InputValues->NeighborhoodsArrayName);
139220

140-
usize totalFeatures = equivalentDiameters.getNumberOfTuples();
221+
const usize totalFeatures = equivalentDiameters.getNumberOfTuples();
141222

142223
ProgressMessageHelper progressMessageHelper = m_MessageHelper.createProgressMessageHelper();
143224
progressMessageHelper.setMaxProgresss(totalFeatures);
144225
progressMessageHelper.setProgressMessageTemplate("Finding Feature Neighborhoods: {:.2f}%");
145226

146227
m_LocalNeighborhoodList.resize(totalFeatures);
147-
criticalDistance.resize(totalFeatures);
148228

149-
float32 aveDiam = 0.0f;
229+
// (a) This section finds the average equivalent spherical (ESD) diameter of ALL features
230+
float32 avgDiameter = 0.0f;
150231
for(usize i = 1; i < totalFeatures; i++)
151232
{
152233
(*m_Neighborhoods)[i] = 0;
153-
aveDiam += equivalentDiameters[i];
154-
criticalDistance[i] = equivalentDiameters[i] * multiplesOfAverage;
155-
}
156-
aveDiam /= static_cast<float32>(totalFeatures);
157-
for(usize i = 1; i < totalFeatures; i++)
158-
{
159-
criticalDistance[i] /= aveDiam;
234+
avgDiameter += equivalentDiameters[i];
160235
}
236+
avgDiameter /= static_cast<float32>(totalFeatures);
237+
m_MessageHandler(IFilter::Message::Type::Info, fmt::format("Feature Average Diameter: '{}'", avgDiameter));
161238

239+
// (c) We are going to place each feature's centroid into a bin in the normalized 3D space.
240+
// The centroid is normalized by the Average Diameter
162241
std::vector<int64> bins(3 * totalFeatures, 0);
163242
FloatVec3 origin = m_DataStructure.getDataAs<ImageGeom>(m_InputValues->InputImageGeometry)->getOrigin();
164243
for(usize i = 1; i < totalFeatures; i++)
165244
{
166-
float32 x = centroids[3 * i];
167-
float32 y = centroids[3 * i + 1];
168-
float32 z = centroids[3 * i + 2];
169-
bins[3 * i] = static_cast<int64>((x - origin[0]) / aveDiam); // x-Bin
170-
bins[3 * i + 1] = static_cast<int64>((y - origin[1]) / aveDiam); // y-Bin
171-
bins[3 * i + 2] = static_cast<int64>((z - origin[2]) / aveDiam); // z-Bin
245+
const float32 x = centroids[3 * i];
246+
const float32 y = centroids[3 * i + 1];
247+
const float32 z = centroids[3 * i + 2];
248+
bins[3 * i] = static_cast<int64>((x - origin[0]) / avgDiameter); // x-Bin
249+
bins[3 * i + 1] = static_cast<int64>((y - origin[1]) / avgDiameter); // y-Bin
250+
bins[3 * i + 2] = static_cast<int64>((z - origin[2]) / avgDiameter); // z-Bin
251+
}
252+
if(m_ShouldCancel)
253+
{
254+
return {};
172255
}
173-
174256
ParallelDataAlgorithm parallelAlgorithm;
175257
parallelAlgorithm.setRange(Range(0, totalFeatures));
176258
parallelAlgorithm.setParallelizationEnabled(true);
177-
parallelAlgorithm.execute(ComputeNeighborhoodsImpl(this, totalFeatures, bins, criticalDistance, m_ShouldCancel, progressMessageHelper));
259+
parallelAlgorithm.execute(ComputeNeighborhoodsImpl(this, centroids, bins, avgDiameter, multiplesOfAverage, m_ShouldCancel, progressMessageHelper));
178260

179261
// Output Variables
180262
auto& outputNeighborList = m_DataStructure.getDataRefAs<NeighborList<int32>>(m_InputValues->NeighborhoodListArrayName);
181263
// Set the vector for each list into the NeighborList Object
182264
for(usize i = 1; i < totalFeatures; i++)
183265
{
184266
// Construct a shared vector<int32> through the std::vector<> copy constructor.
185-
NeighborList<int32>::SharedVectorType sharedMisOrientationList(new std::vector<int32>(m_LocalNeighborhoodList[i]));
267+
const NeighborList<int32>::SharedVectorType sharedMisOrientationList(new std::vector<int32>(m_LocalNeighborhoodList[i]));
186268
outputNeighborList.setList(static_cast<int32>(i), sharedMisOrientationList);
187269
}
188270

src/Plugins/SimplnxCore/src/SimplnxCore/Filters/Algorithms/ComputeNeighborhoods.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,6 @@ class SIMPLNXCORE_EXPORT ComputeNeighborhoods
4141

4242
Result<> operator()();
4343

44-
const std::atomic_bool& getCancel();
45-
4644
void updateNeighborHood(usize sourceIndex, usize targetIndex);
4745

4846
private:

src/Plugins/SimplnxCore/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,7 @@ if(EXISTS "${DREAM3D_DATA_DIR}" AND SIMPLNX_DOWNLOAD_TEST_FILES)
273273
download_test_data(DREAM3D_DATA_DIR ${DREAM3D_DATA_DIR} ARCHIVE_NAME ReadZeissTxmFileTest_v2.tar.gz SHA512 2d962655d919480fa7baf0c8720affbcbd1182e8482480f5413b53ead41787d35d0f15487219163ba0dcffad69be90e627730fc780107d80df3d97e14fb15cac)
274274
download_test_data(DREAM3D_DATA_DIR ${DREAM3D_DATA_DIR} ARCHIVE_NAME segment_features_neighbor_scheme_test.tar.gz SHA512 4700b42eb3a9058dc08520c786d0e6aa20e45bfa06233952462c0c8a428e437406f530125822f1742dfe57405be66e4f0fd5ac2f76db4a13379a7ae189d370cd)
275275
download_test_data(DREAM3D_DATA_DIR ${DREAM3D_DATA_DIR} ARCHIVE_NAME write_lammps_test.tar.gz SHA512 82bb5360b76e857f3233d37733c602f67fd2ac667e49b24741a70ab649e8046fb7905493df37d142808b740c2771fe7cdccd71c9d70679afafe398529ee5771e)
276+
download_test_data(DREAM3D_DATA_DIR ${DREAM3D_DATA_DIR} ARCHIVE_NAME compute_feature_neighborhoods.tar.gz SHA512 dda96546de1b924d8145e8d173d00605bbf7d311c8d3719edd9d698bb8a326009fcfdb885d2fe08204b8e8dcbebe2eb699d3255e25fe34cfc36b9f950a8e07da)
276277

277278
endif()
278279

0 commit comments

Comments
 (0)