11#include " ComputeNeighborhoods.hpp"
22
3+ #include " simplnx/Common/Array.hpp"
4+ #include " simplnx/Common/Range.hpp"
5+ #include " simplnx/Common/Result.hpp"
6+ #include " simplnx/Common/Types.hpp"
7+ #include " simplnx/DataStructure/AbstractDataStore.hpp"
38#include " simplnx/DataStructure/DataArray.hpp"
4- #include " simplnx/DataStructure/DataGroup .hpp"
9+ #include " simplnx/DataStructure/DataStructure .hpp"
510#include " simplnx/DataStructure/Geometry/ImageGeom.hpp"
11+ #include " simplnx/Filter/IFilter.hpp"
12+ #include " simplnx/Utilities/MessageHelper.hpp"
613#include " simplnx/Utilities/ParallelDataAlgorithm.hpp"
714
8- #include < cmath>
15+ #include < fmt/format.h>
16+
17+ #include < atomic>
18+ #include < functional>
19+ #include < mutex>
20+ #include < unordered_map>
21+ #include < vector>
922
1023using namespace nx ::core;
1124namespace
1225{
26+ struct BinKey
27+ {
28+ int64 bx = 0 ;
29+ int64 by = 0 ;
30+ int64 bz = 0 ;
31+
32+ BinKey (int64 x, int64 y, int64 z)
33+ : bx(x)
34+ , by(y)
35+ , bz(z)
36+ {
37+ }
38+
39+ BinKey (const std::vector<int64>& bins, usize tupleIndex)
40+ {
41+ bx = bins[3 * tupleIndex + 0 ];
42+ by = bins[3 * tupleIndex + 1 ];
43+ bz = bins[3 * tupleIndex + 2 ];
44+ }
45+
46+ bool operator ==(const BinKey& other) const noexcept
47+ {
48+ return bx == other.bx && by == other.by && bz == other.bz ;
49+ }
50+ };
51+
52+ struct BinKeyHasher
53+ {
54+ usize operator ()(const BinKey& key) const noexcept
55+ {
56+ // simple hash combine
57+ const usize h1 = std::hash<int64>{}(key.bx );
58+ const usize h2 = std::hash<int64>{}(key.by );
59+ const usize h3 = std::hash<int64>{}(key.bz );
60+ usize seed = h1;
61+ seed ^= h2 + 0x9e3779b97f4a7c15ULL + (seed << 6 ) + (seed >> 2 );
62+ seed ^= h3 + 0x9e3779b97f4a7c15ULL + (seed << 6 ) + (seed >> 2 );
63+ return seed;
64+ }
65+ };
66+
1367class ComputeNeighborhoodsImpl
1468{
1569public:
16- ComputeNeighborhoodsImpl (ComputeNeighborhoods* filter, usize totalFeatures, const std::vector< int64_t >& bins , const std::vector<float >& criticalDistance, const std::atomic_bool& shouldCancel ,
17- ProgressMessageHelper& progressMessageHelper)
70+ ComputeNeighborhoodsImpl (ComputeNeighborhoods* filter, const nx::core::AbstractDataStore< float >& centroids , const std::vector<int64 >& bins, float32 avgDiam, float32 multiplesOfAverage ,
71+ const std::atomic_bool& shouldCancel, ProgressMessageHelper& progressMessageHelper)
1872 : m_Filter(filter)
19- , m_TotalFeatures(totalFeatures )
73+ , m_Centroids(centroids )
2074 , m_Bins(bins)
21- , m_CriticalDistance(criticalDistance)
75+ , m_AvgDiam(avgDiam)
76+ , m_MultiplesOfAverage(multiplesOfAverage)
2277 , m_ShouldCancel(shouldCancel)
2378 , m_ProgressMessageHelper(progressMessageHelper)
2479 {
2580 }
2681
2782 void convert (usize start, usize end) const
2883 {
29- int64 bin1x, bin2x, bin1y, bin2y, bin1z, bin2z;
30- float32 dBinX, dBinY, dBinZ;
31- float32 criticalDistance1, criticalDistance2;
32-
33- auto increment = static_cast <float64>(end - start) / 100.0 ;
34- float64 incCount = 0.0 ;
35- // NEVER start at 0.
36- if (start == 0 )
84+ const auto increment = static_cast <int64>((end - start) / 100.0 );
85+ int64 incCount = 0.0 ;
86+
87+ const usize totalFeatures = m_Centroids.getNumberOfTuples ();
88+
89+ // 1. Build spatial grid: BinKey -> list of features
90+ std::unordered_map<BinKey, std::vector<usize>, BinKeyHasher> binToFeatures;
91+ binToFeatures.reserve (totalFeatures);
92+
93+ for (usize i = 1 ; i < totalFeatures; ++i) // assuming feature 0 is background
3794 {
38- start = 1 ;
95+ const BinKey key (m_Bins, i);
96+ binToFeatures[key].push_back (i);
3997 }
4098
99+ // 2. Precompute radius info
100+ const float32 radius = m_AvgDiam * m_MultiplesOfAverage / 2 .0f ;
101+ const float32 radiusSq = radius * radius;
102+ const int64 k = static_cast <int64>(std::ceil (m_MultiplesOfAverage));
103+
41104 ProgressMessenger progressMessenger = m_ProgressMessageHelper.createProgressMessenger ();
42- for (usize featureIdx = start; featureIdx < end; featureIdx ++)
105+ for (usize i = start; i < end; i ++)
43106 {
44107 incCount++;
45108 if (incCount >= increment)
@@ -52,31 +115,56 @@ class ComputeNeighborhoodsImpl
52115 {
53116 return ;
54117 }
118+ // (a) Get feature's i position
119+ const float32 xi = m_Centroids[3 * i + 0 ];
120+ const float32 yi = m_Centroids[3 * i + 1 ];
121+ const float32 zi = m_Centroids[3 * i + 2 ];
55122
56- bin1x = m_Bins[ 3 * featureIdx];
57- bin1y = m_Bins[3 * featureIdx + 1 ];
58- bin1z = m_Bins[3 * featureIdx + 2 ];
59- criticalDistance1 = m_CriticalDistance[featureIdx ];
123+ // (b) Get its bin
124+ const int64 bx0 = m_Bins[3 * i + 0 ];
125+ const int64 by0 = m_Bins[3 * i + 1 ];
126+ const int64 bz0 = m_Bins[ 3 * i + 2 ];
60127
61- for (usize j = featureIdx + 1 ; j < m_TotalFeatures; j++)
128+ // (c) Scan all bins within +/- k in each dimension
129+ for (int64 dbx = -k; dbx <= k; ++dbx)
62130 {
63- bin2x = m_Bins[3 * j];
64- bin2y = m_Bins[3 * j + 1 ];
65- bin2z = m_Bins[3 * j + 2 ];
66- criticalDistance2 = m_CriticalDistance[j];
131+ for (int64 dby = -k; dby <= k; ++dby)
132+ {
133+ for (int64 dbz = -k; dbz <= k; ++dbz)
134+ {
135+ const BinKey nbKey{bx0 + dbx, by0 + dby, bz0 + dbz};
67136
68- dBinX = std::abs (static_cast <float32>(bin2x - bin1x));
69- dBinY = std::abs (static_cast <float32>(bin2y - bin1y));
70- dBinZ = std::abs (static_cast <float32>(bin2z - bin1z));
137+ auto it = binToFeatures.find (nbKey);
138+ if (it == binToFeatures.end ())
139+ {
140+ continue ; // no features in this bin
141+ }
71142
72- if (dBinX < criticalDistance1 && dBinY < criticalDistance1 && dBinZ < criticalDistance1)
73- {
74- m_Filter->updateNeighborHood (featureIdx, j);
75- }
143+ const std::vector<usize>& candidates = it->second ;
76144
77- if (dBinX < criticalDistance2 && dBinY < criticalDistance2 && dBinZ < criticalDistance2)
78- {
79- m_Filter->updateNeighborHood (j, featureIdx);
145+ // (d) Check actual distances to candidates in this bin
146+ for (const usize j : candidates)
147+ {
148+ if (j == i)
149+ {
150+ continue ; // skip self
151+ }
152+
153+ const float32 xj = m_Centroids[3 * j + 0 ];
154+ const float32 yj = m_Centroids[3 * j + 1 ];
155+ const float32 zj = m_Centroids[3 * j + 2 ];
156+
157+ const float32 dx = xi - xj;
158+ const float32 dy = yi - yj;
159+ const float32 dz = zi - zj;
160+
161+ const float32 distSq = dx * dx + dy * dy + dz * dz;
162+ if (distSq <= radiusSq)
163+ {
164+ m_Filter->updateNeighborHood (i, j);
165+ }
166+ }
167+ }
80168 }
81169 }
82170 }
@@ -90,9 +178,10 @@ class ComputeNeighborhoodsImpl
90178
91179private:
92180 ComputeNeighborhoods* m_Filter = nullptr ;
93- usize m_TotalFeatures = 0 ;
181+ const nx::core::AbstractDataStore< float >& m_Centroids ;
94182 const std::vector<int64>& m_Bins;
95- const std::vector<float32>& m_CriticalDistance;
183+ float32 m_AvgDiam;
184+ float32 m_MultiplesOfAverage;
96185 const std::atomic_bool& m_ShouldCancel;
97186 ProgressMessageHelper& m_ProgressMessageHelper;
98187};
@@ -111,16 +200,10 @@ ComputeNeighborhoods::ComputeNeighborhoods(DataStructure& dataStructure, const I
111200// -----------------------------------------------------------------------------
112201ComputeNeighborhoods::~ComputeNeighborhoods () noexcept = default ;
113202
114- // -----------------------------------------------------------------------------
115- const std::atomic_bool& ComputeNeighborhoods::getCancel ()
116- {
117- return m_ShouldCancel;
118- }
119-
120203// -----------------------------------------------------------------------------
121204void ComputeNeighborhoods::updateNeighborHood (usize sourceIndex, usize destIndex)
122205{
123- const std::lock_guard<std::mutex> lock (m_Mutex);
206+ const std::scoped_lock lock (m_Mutex);
124207 (*m_Neighborhoods)[sourceIndex].inc ();
125208 m_LocalNeighborhoodList[sourceIndex].push_back (static_cast <int32_t >(destIndex));
126209}
@@ -129,60 +212,59 @@ void ComputeNeighborhoods::updateNeighborHood(usize sourceIndex, usize destIndex
129212Result<> ComputeNeighborhoods::operator ()()
130213{
131214 // m_ProgressCounter initialized to zero on filter creation
132- std::vector<float32> criticalDistance;
133-
134215 auto multiplesOfAverage = m_InputValues->MultiplesOfAverage ;
135216 const auto & equivalentDiameters = m_DataStructure.getDataAs <Float32Array>(m_InputValues->EquivalentDiametersArrayPath )->getDataStoreRef ();
136217 const auto & centroids = m_DataStructure.getDataAs <Float32Array>(m_InputValues->CentroidsArrayPath )->getDataStoreRef ();
137218
138219 m_Neighborhoods = m_DataStructure.getDataAs <Int32Array>(m_InputValues->NeighborhoodsArrayName );
139220
140- usize totalFeatures = equivalentDiameters.getNumberOfTuples ();
221+ const usize totalFeatures = equivalentDiameters.getNumberOfTuples ();
141222
142223 ProgressMessageHelper progressMessageHelper = m_MessageHelper.createProgressMessageHelper ();
143224 progressMessageHelper.setMaxProgresss (totalFeatures);
144225 progressMessageHelper.setProgressMessageTemplate (" Finding Feature Neighborhoods: {:.2f}%" );
145226
146227 m_LocalNeighborhoodList.resize (totalFeatures);
147- criticalDistance.resize (totalFeatures);
148228
149- float32 aveDiam = 0 .0f ;
229+ // (a) This section finds the average equivalent spherical (ESD) diameter of ALL features
230+ float32 avgDiameter = 0 .0f ;
150231 for (usize i = 1 ; i < totalFeatures; i++)
151232 {
152233 (*m_Neighborhoods)[i] = 0 ;
153- aveDiam += equivalentDiameters[i];
154- criticalDistance[i] = equivalentDiameters[i] * multiplesOfAverage;
155- }
156- aveDiam /= static_cast <float32>(totalFeatures);
157- for (usize i = 1 ; i < totalFeatures; i++)
158- {
159- criticalDistance[i] /= aveDiam;
234+ avgDiameter += equivalentDiameters[i];
160235 }
236+ avgDiameter /= static_cast <float32>(totalFeatures);
237+ m_MessageHandler (IFilter::Message::Type::Info, fmt::format (" Feature Average Diameter: '{}'" , avgDiameter));
161238
239+ // (c) We are going to place each feature's centroid into a bin in the normalized 3D space.
240+ // The centroid is normalized by the Average Diameter
162241 std::vector<int64> bins (3 * totalFeatures, 0 );
163242 FloatVec3 origin = m_DataStructure.getDataAs <ImageGeom>(m_InputValues->InputImageGeometry )->getOrigin ();
164243 for (usize i = 1 ; i < totalFeatures; i++)
165244 {
166- float32 x = centroids[3 * i];
167- float32 y = centroids[3 * i + 1 ];
168- float32 z = centroids[3 * i + 2 ];
169- bins[3 * i] = static_cast <int64>((x - origin[0 ]) / aveDiam); // x-Bin
170- bins[3 * i + 1 ] = static_cast <int64>((y - origin[1 ]) / aveDiam); // y-Bin
171- bins[3 * i + 2 ] = static_cast <int64>((z - origin[2 ]) / aveDiam); // z-Bin
245+ const float32 x = centroids[3 * i];
246+ const float32 y = centroids[3 * i + 1 ];
247+ const float32 z = centroids[3 * i + 2 ];
248+ bins[3 * i] = static_cast <int64>((x - origin[0 ]) / avgDiameter); // x-Bin
249+ bins[3 * i + 1 ] = static_cast <int64>((y - origin[1 ]) / avgDiameter); // y-Bin
250+ bins[3 * i + 2 ] = static_cast <int64>((z - origin[2 ]) / avgDiameter); // z-Bin
251+ }
252+ if (m_ShouldCancel)
253+ {
254+ return {};
172255 }
173-
174256 ParallelDataAlgorithm parallelAlgorithm;
175257 parallelAlgorithm.setRange (Range (0 , totalFeatures));
176258 parallelAlgorithm.setParallelizationEnabled (true );
177- parallelAlgorithm.execute (ComputeNeighborhoodsImpl (this , totalFeatures , bins, criticalDistance , m_ShouldCancel, progressMessageHelper));
259+ parallelAlgorithm.execute (ComputeNeighborhoodsImpl (this , centroids , bins, avgDiameter, multiplesOfAverage , m_ShouldCancel, progressMessageHelper));
178260
179261 // Output Variables
180262 auto & outputNeighborList = m_DataStructure.getDataRefAs <NeighborList<int32>>(m_InputValues->NeighborhoodListArrayName );
181263 // Set the vector for each list into the NeighborList Object
182264 for (usize i = 1 ; i < totalFeatures; i++)
183265 {
184266 // Construct a shared vector<int32> through the std::vector<> copy constructor.
185- NeighborList<int32>::SharedVectorType sharedMisOrientationList (new std::vector<int32>(m_LocalNeighborhoodList[i]));
267+ const NeighborList<int32>::SharedVectorType sharedMisOrientationList (new std::vector<int32>(m_LocalNeighborhoodList[i]));
186268 outputNeighborList.setList (static_cast <int32>(i), sharedMisOrientationList);
187269 }
188270
0 commit comments