Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
146 commits
Select commit Hold shift + click to select a range
bedcf4c
first commit
tarang-jain Jul 9, 2025
09f9a22
Merge branch 'branch-25.08' of https://github.com/rapidsai/cuvs into …
tarang-jain Jul 9, 2025
51836d4
index header
tarang-jain Jul 9, 2025
23ef877
populate functions;ivf_list type;kmeans_predict
tarang-jain Jul 11, 2025
a7fce8e
Merge branch 'branch-25.08' of https://github.com/rapidsai/cuvs into …
tarang-jain Jul 11, 2025
6a98a88
hamming_op
tarang-jain Jul 12, 2025
76c9ee5
Merge branch 'branch-25.08' into binary-kmeans
tarang-jain Jul 12, 2025
916a4cf
rm binary_ivf
tarang-jain Jul 12, 2025
8ec4d59
Merge branch 'binary-kmeans' of https://github.com/tarang-jain/cuvs i…
tarang-jain Jul 12, 2025
1941b2e
modify ivf_flat_build
tarang-jain Jul 12, 2025
cd00b83
rm binary_ivf_flat
tarang-jain Jul 12, 2025
4cffe84
rm unused
tarang-jain Jul 12, 2025
2bc9007
updates
tarang-jain Jul 12, 2025
7803850
quantize
tarang-jain Jul 12, 2025
ff7be4a
cleanup
tarang-jain Jul 12, 2025
3149192
pre-commit
tarang-jain Jul 12, 2025
dd1b0d4
update kmeans_predict
tarang-jain Jul 12, 2025
2b9bef4
src kmeans
tarang-jain Jul 12, 2025
6ec32d8
style
tarang-jain Jul 12, 2025
5c59753
corrections to logic
tarang-jain Jul 14, 2025
2271809
clang
tarang-jain Jul 14, 2025
29388e9
testing
tarang-jain Jul 14, 2025
4a491cd
correct tests
tarang-jain Jul 14, 2025
4c27acd
binary_index flag
tarang-jain Jul 14, 2025
7085405
correct mdspan,reduction_op
tarang-jain Jul 14, 2025
2075bd0
Merge branch 'branch-25.08' into binary-kmeans
tarang-jain Jul 15, 2025
d06b4f5
correct checks
tarang-jain Jul 15, 2025
13de163
Merge branch 'branch-25.08' of https://github.com/rapidsai/cuvs into …
tarang-jain Jul 15, 2025
e79d520
Merge branch 'binary-kmeans' of https://github.com/tarang-jain/cuvs i…
tarang-jain Jul 15, 2025
c7f3ade
Merge branch 'branch-25.10' of https://github.com/rapidsai/cuvs into …
tarang-jain Jul 31, 2025
e7fdd4c
Merge branch 'branch-25.10' into binary-kmeans
tarang-jain Aug 1, 2025
3ce25af
fused-kernel
tarang-jain Aug 1, 2025
1f565ef
Merge branch 'binary-kmeans' of https://github.com/tarang-jain/cuvs i…
tarang-jain Aug 1, 2025
2c2a91d
cleanup; syntax
tarang-jain Aug 1, 2025
b9a6a65
cleanup
tarang-jain Aug 1, 2025
0f19648
temp-fix-compilation-errors
tarang-jain Aug 1, 2025
2de4bcc
Merge branch 'branch-25.10' into binary-kmeans
tarang-jain Aug 5, 2025
f572e0b
Merge branch 'branch-25.10' into binary-kmeans
tarang-jain Aug 5, 2025
66d8b94
uint32_t for storing dists
tarang-jain Aug 6, 2025
46f9676
Merge branch 'branch-25.10' of https://github.com/rapidsai/cuvs into …
tarang-jain Aug 6, 2025
fa2193f
Merge branch 'binary-kmeans' of https://github.com/tarang-jain/cuvs i…
tarang-jain Aug 6, 2025
8f18ce0
no use_norms check
tarang-jain Aug 6, 2025
c7084ac
coarse search
tarang-jain Aug 6, 2025
1ca868b
debug; fix dim() error
tarang-jain Aug 6, 2025
42859fb
Merge branch 'branch-25.10' into binary-kmeans
tarang-jain Aug 6, 2025
f0597fa
Merge branch 'branch-25.10' into binary-kmeans
tarang-jain Aug 12, 2025
ca97da6
Merge branch 'branch-25.10' into binary-kmeans
tarang-jain Aug 25, 2025
8ad530b
Merge branch 'branch-25.10' of https://github.com/rapidsai/cuvs into …
tarang-jain Aug 26, 2025
17449e4
debug
tarang-jain Sep 2, 2025
f8e6e86
Merge branch 'branch-25.10' of https://github.com/rapidsai/cuvs into …
tarang-jain Sep 2, 2025
6aa2401
fix float underflow
tarang-jain Sep 4, 2025
d1717f5
rm new file
tarang-jain Sep 4, 2025
3b284cd
Merge branch 'branch-25.10' of https://github.com/rapidsai/cuvs into …
tarang-jain Sep 4, 2025
f34103d
Merge branch 'branch-25.10' into binary-kmeans
tarang-jain Sep 5, 2025
86a6593
Merge branch 'branch-25.10' into binary-kmeans
tarang-jain Sep 11, 2025
ef37c59
Merge branch 'branch-25.10' into binary-kmeans
tarang-jain Sep 15, 2025
5e794fe
Merge branch 'branch-25.10' into binary-kmeans
tarang-jain Sep 16, 2025
359d458
Merge branch 'branch-25.10' into binary-kmeans
tarang-jain Sep 16, 2025
d999591
Merge branch 'branch-25.10' into binary-kmeans
tarang-jain Sep 17, 2025
f242561
bug fixes
tarang-jain Sep 17, 2025
5517ae3
Merge branch 'binary-kmeans' of https://github.com/tarang-jain/cuvs i…
tarang-jain Sep 17, 2025
4aa2881
Merge branch 'branch-25.12' into binary-kmeans
tarang-jain Sep 25, 2025
b0d78b6
Merge branch 'branch-25.12' into binary-kmeans
tarang-jain Sep 26, 2025
184f025
Merge branch 'branch-25.12' into binary-kmeans
tarang-jain Sep 29, 2025
113e242
Merge branch 'branch-25.12' into binary-kmeans
tarang-jain Oct 6, 2025
540db62
correct fused instantiation
tarang-jain Oct 8, 2025
d9eed14
Merge branch 'binary-kmeans' of https://github.com/tarang-jain/cuvs i…
tarang-jain Oct 8, 2025
7ceed82
Merge branch 'branch-25.12' into binary-kmeans
tarang-jain Oct 9, 2025
ac261be
Merge branch 'branch-25.12' into binary-kmeans
tarang-jain Oct 10, 2025
b66b8ee
cleanup
tarang-jain Oct 13, 2025
4a19ec1
Merge branch 'binary-kmeans' of https://github.com/tarang-jain/cuvs i…
tarang-jain Oct 13, 2025
80d5f9c
Merge branch 'branch-25.12' of https://github.com/rapidsai/cuvs into …
tarang-jain Oct 13, 2025
d2b1184
Merge branch 'branch-25.12' of https://github.com/rapidsai/cuvs into …
tarang-jain Oct 15, 2025
ab1083e
Merge branch 'branch-25.12' into binary-kmeans
tarang-jain Oct 16, 2025
9a26624
thorough equivalence testing checkpoint
tarang-jain Oct 16, 2025
8a2098f
Merge branch 'binary-kmeans' of https://github.com/tarang-jain/cuvs i…
tarang-jain Oct 16, 2025
d4aae95
cleanup-1
tarang-jain Oct 16, 2025
f5c61dc
cleanup-2
tarang-jain Oct 16, 2025
54ae1a2
cleanup-3
tarang-jain Oct 16, 2025
76fc1f4
cleanup-4 (style)
tarang-jain Oct 16, 2025
a8a3a0a
cleanup-6 tests
tarang-jain Oct 17, 2025
05363f5
cleanup-7 (tests)
tarang-jain Oct 17, 2025
de3bcc1
cleanup-8 tests
tarang-jain Oct 17, 2025
a02e563
cleanup-8 (tests)
tarang-jain Oct 17, 2025
b6181a0
cleanup-8 (tests)
tarang-jain Oct 17, 2025
5f640c1
use raft ci artifacts from pr
tarang-jain Oct 17, 2025
cc0fa18
wheels artifacts from raft pr
tarang-jain Oct 17, 2025
de606fa
style
tarang-jain Oct 17, 2025
385d8c7
shell style check
tarang-jain Oct 17, 2025
ae49732
Remove RAFT Dask channel from conda packages script
tarang-jain Oct 17, 2025
566b247
fix visible sections in simt_kernel
tarang-jain Oct 17, 2025
a24d13b
Merge branch 'branch-25.12' of https://github.com/rapidsai/cuvs into …
tarang-jain Oct 17, 2025
b4ea915
raft artifact for python build
tarang-jain Oct 17, 2025
34d2ec5
Merge branch 'branch-25.12' of https://github.com/rapidsai/cuvs into …
tarang-jain Oct 17, 2025
55b13d8
add skip flags
tarang-jain Oct 20, 2025
f44a7d8
Merge branch 'branch-25.12' of https://github.com/rapidsai/cuvs into …
tarang-jain Oct 20, 2025
942dc96
Merge branch 'main' into binary-kmeans
tarang-jain Oct 22, 2025
8d99bc7
set binary_index_ flag explicitly
tarang-jain Oct 22, 2025
5153555
Merge branch 'main' into binary-kmeans
tarang-jain Oct 22, 2025
ca41dc8
Merge branch 'main' of https://github.com/rapidsai/cuvs into binary-k…
tarang-jain Oct 25, 2025
182dadf
correct recall threshold
tarang-jain Oct 25, 2025
99eea8a
merge upstream
tarang-jain Oct 28, 2025
5220772
style and correct tests
tarang-jain Oct 28, 2025
5fc6f4b
rm extra test
tarang-jain Oct 28, 2025
37378d7
Merge branch 'main' into binary-kmeans
tarang-jain Oct 28, 2025
b948a33
rm test
tarang-jain Oct 28, 2025
c5e07b1
Merge branch 'binary-kmeans' of https://github.com/tarang-jain/cuvs i…
tarang-jain Oct 28, 2025
cbf83b8
Merge branch 'main' into binary-kmeans
tarang-jain Nov 4, 2025
c14f26f
Merge branch 'main' into binary-kmeans
tarang-jain Nov 7, 2025
b3eac6f
Merge branch 'main' into binary-kmeans
tarang-jain Nov 10, 2025
03a01a1
Merge branch 'main' into binary-kmeans
achirkin Nov 12, 2025
7b409c7
on the fly dataset expansion
tarang-jain Nov 26, 2025
080a5ea
Merge branch 'main' into binary-kmeans
tarang-jain Nov 26, 2025
107a3e4
address PR reviews
tarang-jain Nov 26, 2025
1caabeb
Merge branch 'binary-kmeans' of https://github.com/tarang-jain/cuvs i…
tarang-jain Nov 26, 2025
3681320
simplify bitwise_decode_op
tarang-jain Nov 26, 2025
118a3d7
commit suggestion
tarang-jain Nov 26, 2025
3399656
clang
tarang-jain Nov 26, 2025
d4fe3fe
Merge branch 'main' into binary-kmeans
tarang-jain Dec 2, 2025
653364e
merge upstream
tarang-jain Dec 3, 2025
ef34908
address reviews
tarang-jain Dec 3, 2025
d12a211
style
tarang-jain Dec 3, 2025
91c6734
undo ci changes
tarang-jain Dec 3, 2025
42f1bb7
fix kmeans mapping op
tarang-jain Dec 3, 2025
e59a357
other fixes to kmeans for binary data
tarang-jain Dec 3, 2025
d13c0c9
fix compilation errors
tarang-jain Dec 4, 2025
8162293
fix compilation
tarang-jain Dec 4, 2025
fcae247
fix compilation
tarang-jain Dec 4, 2025
a9599bd
doc
tarang-jain Dec 4, 2025
ffbcdcc
Merge branch 'main' of https://github.com/rapidsai/cuvs into binary-k…
tarang-jain Dec 4, 2025
3ffba85
simplify ivf-flat build
tarang-jain Dec 4, 2025
c0a99e2
fix compilation errors
tarang-jain Dec 5, 2025
dbb6423
bug fixes
tarang-jain Dec 5, 2025
248911c
debug
tarang-jain Dec 5, 2025
0656ee6
more corrections to kmeans
tarang-jain Dec 5, 2025
1423356
Merge branch 'main' into binary-kmeans
tarang-jain Dec 19, 2025
a25ddac
Merge branch 'main' into binary-kmeans
tarang-jain Dec 22, 2025
e8a8152
Merge branch 'main' of https://github.com/rapidsai/cuvs into binary-k…
tarang-jain Dec 23, 2025
8feebb8
bug fixes
tarang-jain Dec 23, 2025
997ddde
debug
tarang-jain Dec 24, 2025
89b54a1
working impl;rm debug statements
tarang-jain Dec 24, 2025
07354d1
rm debug prints:
tarang-jain Dec 24, 2025
07e1837
Merge branch 'main' into binary-kmeans
tarang-jain Dec 30, 2025
510bafd
Merge branch 'release/26.02' into binary-kmeans
tarang-jain Jan 20, 2026
0ddce5d
Merge branch 'release/26.02' into binary-kmeans
tarang-jain Jan 23, 2026
d546471
Merge branch 'release/26.02' into binary-kmeans
tarang-jain Jan 27, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,7 @@ if(NOT BUILD_CPU_ONLY)
src/distance/detail/pairwise_matrix/dispatch_rbf.cu
src/distance/detail/pairwise_matrix/dispatch_l2_expanded_double_double_double_int64_t.cu
src/distance/detail/pairwise_matrix/dispatch_l2_expanded_float_float_float_int64_t.cu
src/distance/detail/pairwise_matrix/dispatch_bitwise_hamming_uint8_t_uint32_t_uint32_t_int64_t.cu
src/distance/distance.cu
src/distance/pairwise_distance.cu
src/distance/sparse_distance.cu
Expand Down
9 changes: 9 additions & 0 deletions cpp/include/cuvs/cluster/kmeans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,15 @@ struct balanced_params : base_params {
* Number of training iterations
*/
uint32_t n_iters = 20;

/**
* If true, treats uint8_t input data as bit-packed binary data where each byte contains 8 bits.
* Bits are expanded on-the-fly to {-1, +1} floats during training.
* When enabled:
* - Input data dimension represents packed dimension (actual_dim / 8)
* - Output centroids dimension is expanded (packed_dim * 8)
*/
bool is_packed_binary = false;
};

/**
Expand Down
14 changes: 13 additions & 1 deletion cpp/include/cuvs/neighbors/ivf_flat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,11 @@ struct index : cuvs::neighbors::index {
raft::device_matrix_view<float, uint32_t, raft::row_major> centers() noexcept;
raft::device_matrix_view<const float, uint32_t, raft::row_major> centers() const noexcept;

/** packed k-means cluster centers corresponding to the lists [n_lists, dim] when the
* BitwiseHamming metric is selected */
raft::device_matrix_view<uint8_t, int64_t, raft::row_major> binary_centers() noexcept;
raft::device_matrix_view<const uint8_t, int64_t, raft::row_major> binary_centers() const noexcept;

/**
* (Optional) Precomputed norms of the `centers` w.r.t. the chosen distance metric [n_lists].
*
Expand All @@ -229,7 +234,10 @@ struct index : cuvs::neighbors::index {
/** Total length of the index. */
IdxT size() const noexcept;

/** Dimensionality of the data. */
/** Dimensionality of the data.
* @note For binary index, this returns the dimensionality of the byte dataset, which is the
* number of bits / 8.
*/
uint32_t dim() const noexcept;

/** Number of clusters/inverted lists. */
Expand All @@ -255,6 +263,8 @@ struct index : cuvs::neighbors::index {

void check_consistency();

bool binary_index() const noexcept;

private:
/**
* TODO: in theory, we can lift this to the template parameter and keep it at hardware maximum
Expand All @@ -267,7 +277,9 @@ struct index : cuvs::neighbors::index {
std::vector<std::shared_ptr<list_data<T, IdxT>>> lists_;
raft::device_vector<uint32_t, uint32_t> list_sizes_;
raft::device_matrix<float, uint32_t, raft::row_major> centers_;
raft::device_matrix<uint8_t, int64_t, raft::row_major> binary_centers_;
std::optional<raft::device_vector<float, uint32_t>> center_norms_;
bool binary_index_;

// Computed members
raft::device_vector<T*, uint32_t> data_ptrs_;
Expand Down
395 changes: 322 additions & 73 deletions cpp/src/cluster/detail/kmeans_balanced.cuh

Large diffs are not rendered by default.

12 changes: 8 additions & 4 deletions cpp/src/cluster/kmeans_balanced.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ void fit(const raft::resources& handle,
raft::device_matrix_view<MathT, IndexT> centroids,
MappingOpT mapping_op = raft::identity_op())
{
RAFT_EXPECTS(X.extent(1) == centroids.extent(1),
RAFT_EXPECTS(X.extent(1) == centroids.extent(1) ||
(params.is_packed_binary && X.extent(1) * 8 == centroids.extent(1)),
"Number of features in dataset and centroids are different");
RAFT_EXPECTS(static_cast<uint64_t>(X.extent(0)) * static_cast<uint64_t>(X.extent(1)) <=
static_cast<uint64_t>(std::numeric_limits<IndexT>::max()),
Expand Down Expand Up @@ -279,14 +280,16 @@ void calc_centers_and_sizes(const raft::resources& handle,
raft::device_matrix_view<MathT, IndexT> centroids,
raft::device_vector_view<CounterT, IndexT> cluster_sizes,
bool reset_counters = true,
bool is_packed_binary = false,
MappingOpT mapping_op = raft::identity_op())
{
RAFT_EXPECTS(X.extent(0) == labels.extent(0),
"Number of rows in dataset and labels are different");
RAFT_EXPECTS(X.extent(1) == centroids.extent(1),
"Number of features in dataset and centroids are different");
RAFT_EXPECTS(
is_packed_binary ? X.extent(1) * 8 == centroids.extent(1) : X.extent(1) == centroids.extent(1),
"Number of features in dataset and centroids are different");
RAFT_EXPECTS(centroids.extent(0) == cluster_sizes.extent(0),
"Number of rows in centroids and clusyer_sizes are different");
"Number of rows in centroids and cluster_sizes are different");

cuvs::cluster::kmeans::detail::calc_centers_and_sizes(
handle,
Expand All @@ -298,6 +301,7 @@ void calc_centers_and_sizes(const raft::resources& handle,
X.extent(0),
labels.data_handle(),
reset_counters,
is_packed_binary,
mapping_op,
raft::resource::get_workspace_resource(handle));
}
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/distance/detail/distance_ops/all_ops.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

Expand All @@ -9,6 +9,7 @@
#include "cutlass.cuh"

// The distance operations:
#include "../distance_ops/bitwise_hamming.cuh"
#include "../distance_ops/canberra.cuh"
#include "../distance_ops/correlation.cuh"
#include "../distance_ops/cosine.cuh"
Expand Down
60 changes: 60 additions & 0 deletions cpp/src/distance/detail/distance_ops/bitwise_hamming.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
 * SPDX-License-Identifier: Apache-2.0
 */

#pragma once

#include <cuda_runtime.h>
#include <stdint.h>

#include <cstddef>      // size_t, used by shared_mem_size()
#include <type_traits>  // std::is_same_v, used by the static_assert in core()

namespace cuvs::distance::detail::ops {

/**
 * @brief the Bitwise Hamming distance matrix calculation
 * It computes the following equation:
 *
 * c_ij = sum_k popcount(x_ik XOR y_kj)
 *
 * where x and y are binary data packed as uint8_t
 */
template <typename DataType, typename AccType, typename IdxType>
struct bitwise_hamming_distance_op {
  using DataT = DataType;
  using AccT  = AccType;
  using IdxT  = IdxType;

  /** Number of (packed) columns; kept for parity with the other distance ops. */
  IdxT k;

  bitwise_hamming_distance_op(IdxT k_) noexcept : k(k_) {}

  // Hamming distance needs no row norms, and the per-element work is a single
  // XOR + popcount, so the cheap inner-loop path is used.
  static constexpr bool use_norms            = false;
  static constexpr bool expensive_inner_loop = false;

  template <typename Policy>
  static constexpr size_t shared_mem_size()
  {
    return Policy::SmemSize;
  }

  __device__ __forceinline__ void core(AccT& acc, DataT& x, DataT& y) const
  {
    static_assert(std::is_same_v<DataT, uint8_t>, "BitwiseHamming only supports uint8_t");
    // `x ^ y` promotes the uint8_t operands to int; mask back to the low byte so
    // __popc only counts the 8 data bits.
    // TODO: one __popc per byte is wasteful; a 4-byte storage type would let a
    // single __popc cover 32 bits (potential follow-up, see discussion in review).
    acc += static_cast<AccT>(__popc(static_cast<uint32_t>(x ^ y) & 0xffu));
  }

  // No norm-based correction is needed for Hamming distance; epilog is a no-op.
  template <typename Policy>
  __device__ __forceinline__ void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh],
                                         AccT* regxn,
                                         AccT* regyn,
                                         IdxT gridStrideX,
                                         IdxT gridStrideY) const
  {
  }
};

}  // namespace cuvs::distance::detail::ops
30 changes: 23 additions & 7 deletions cpp/src/distance/detail/fused_distance_nn.cuh
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

#include "distance_ops/l2_exp.cuh" // ops::l2_exp_distance_op
#include "fused_distance_nn/cutlass_base.cuh"
#include "fused_distance_nn/fused_bitwise_hamming_nn.cuh"
#include "fused_distance_nn/fused_cosine_nn.cuh"
#include "fused_distance_nn/fused_l2_nn.cuh"
#include "fused_distance_nn/helper_structs.cuh"
Expand Down Expand Up @@ -68,16 +69,31 @@ void fusedDistanceNNImpl(OutT* min,

switch (metric) {
case cuvs::distance::DistanceType::CosineExpanded:
fusedCosineNN<DataT, OutT, IdxT, P, ReduceOpT, KVPReduceOpT>(
min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, stream);
if constexpr (std::is_same_v<DataT, uint8_t> || std::is_same_v<DataT, int8_t>) {
RAFT_FAIL("Cosine distance is not supported for uint8_t/int8_t data types");
} else {
fusedCosineNN<DataT, OutT, IdxT, P, ReduceOpT, KVPReduceOpT>(
min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, stream);
}
break;
case cuvs::distance::DistanceType::L2SqrtExpanded:
case cuvs::distance::DistanceType::L2Expanded:
// initOutBuffer is take care by fusedDistanceNNImpl() so we set it false to fusedL2NNImpl.
fusedL2NNImpl<DataT, OutT, IdxT, P, ReduceOpT, KVPReduceOpT>(
min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, false, stream);
if constexpr (std::is_same_v<DataT, uint8_t> || std::is_same_v<DataT, int8_t>) {
RAFT_FAIL("L2 distance is not supported for uint8_t/int8_t data types");
} else {
fusedL2NNImpl<DataT, OutT, IdxT, P, ReduceOpT, KVPReduceOpT>(
min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, false, stream);
}
break;
default: assert("only cosine/l2 metric is supported with fusedDistanceNN\n"); break;
case cuvs::distance::DistanceType::BitwiseHamming:
if constexpr (std::is_same_v<DataT, uint8_t>) {
fusedBitwiseHammingNN<DataT, OutT, IdxT, P, ReduceOpT, KVPReduceOpT>(
min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, stream);
} else {
RAFT_FAIL("BitwiseHamming distance only supports uint8_t data type");
}
break;
default: RAFT_FAIL("only cosine/l2/bitwise hamming metric is supported with fusedDistanceNN");
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
 * SPDX-License-Identifier: Apache-2.0
 */
#pragma once

#include "../distance_ops/bitwise_hamming.cuh"  // ops::bitwise_hamming_distance_op
#include "../pairwise_distance_base.cuh"        // PairwiseDistances
#include "helper_structs.cuh"
#include "simt_kernel.cuh"

namespace cuvs {
namespace distance {
namespace detail {

/**
 * @brief Fused BitwiseHamming distance and 1-nearest-neighbor
 *
 * For each of the m rows of x, finds the nearest of the n rows of y under the
 * bitwise Hamming distance and reduces the (index, distance) result into `min`.
 *
 * This implementation is only meaningful for uint8_t data type.
 * The if constexpr in fusedDistanceNNImpl ensures it's only called for uint8_t.
 */
template <typename DataT,
          typename OutT,
          typename IdxT,
          typename Policy,
          typename ReduceOpT,
          typename KVPReduceOpT>
void fusedBitwiseHammingNN(OutT* min,
                           const DataT* x,
                           const DataT* y,
                           const DataT* xn,
                           const DataT* yn,
                           IdxT m,
                           IdxT n,
                           IdxT k,
                           int* workspace,
                           ReduceOpT redOp,
                           KVPReduceOpT pairRedOp,
                           bool sqrt,
                           cudaStream_t stream)
{
  using P = Policy;

  // Distances are accumulated as 32-bit popcount sums; the output pairs a row
  // index with such a distance.
  using kv_pair_type     = raft::KeyValuePair<IdxT, uint32_t>;
  using distance_op_type = ops::bitwise_hamming_distance_op<DataT, uint32_t, IdxT>;
  distance_op_type distance_op{k};

  // NOTE(review): this sentinel is the max of DataT (255 for uint8_t) while the
  // accumulator is uint32_t; accumulated distances can exceed 255 once k > 31
  // bytes — confirm the kernel's init value is wide enough for large dimensions.
  constexpr auto maxVal = std::numeric_limits<DataT>::max();

  auto kernel = fusedDistanceNNkernel<DataT,
                                      kv_pair_type,
                                      IdxT,
                                      P,
                                      ReduceOpT,
                                      KVPReduceOpT,
                                      distance_op_type,
                                      raft::identity_op>;

  constexpr size_t shmemSize = P::SmemSize;
  dim3 block_dim(P::Nthreads);
  dim3 grid_dim = launchConfigGenerator<P>(m, n, shmemSize, kernel);

  // Hamming distance uses no row norms, so xn/yn are not forwarded (nullptr).
  kernel<<<grid_dim, block_dim, shmemSize, stream>>>(min,
                                                     x,
                                                     y,
                                                     nullptr,
                                                     nullptr,
                                                     m,
                                                     n,
                                                     k,
                                                     maxVal,
                                                     workspace,
                                                     redOp,
                                                     pairRedOp,
                                                     distance_op,
                                                     raft::identity_op{});

  RAFT_CUDA_TRY(cudaGetLastError());
}

}  // namespace detail
}  // namespace distance
}  // namespace cuvs
25 changes: 19 additions & 6 deletions cpp/src/distance/detail/fused_distance_nn/helper_structs.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2021-2024, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

Expand Down Expand Up @@ -27,8 +27,15 @@ namespace detail {
template <typename LabelT, typename DataT>
struct KVPMinReduceImpl {
  typedef raft::KeyValuePair<LabelT, DataT> KVP;
  // Pick the pair with the smaller distance; when the distances are equal,
  // prefer the smaller key (index) so the reduction result is deterministic
  // regardless of the order in which pairs are combined.
  // (The superseded non-tie-breaking one-line overloads were removed; keeping
  // both would be a duplicate-definition error.)
  DI KVP operator()(LabelT rit, const KVP& a, const KVP& b)
  {
    return (b.value < a.value || (b.value == a.value && b.key < a.key)) ? b : a;
  }
  DI KVP operator()(const KVP& a, const KVP& b)
  {
    return (b.value < a.value || (b.value == a.value && b.key < a.key)) ? b : a;
  }

};  // KVPMinReduce

Expand All @@ -38,14 +45,16 @@ struct MinAndDistanceReduceOpImpl {

DI void operator()(LabelT rid, KVP* out, const KVP& other) const
{
if (other.value < out->value) {
// Use index as tiebreaker for consistent behavior when distances are equal
if (other.value < out->value || (other.value == out->value && other.key < out->key)) {
out->key = other.key;
out->value = other.value;
}
}
DI void operator()(LabelT rid, volatile KVP* out, const KVP& other) const
{
if (other.value < out->value) {
// Use index as tiebreaker for consistent behavior when distances are equal
if (other.value < out->value || (other.value == out->value && other.key < out->key)) {
out->key = other.key;
out->value = other.value;
}
Expand Down Expand Up @@ -123,7 +132,11 @@ struct kvp_cg_min_reduce_op {
using AccTypeT = AccType;
using IndexT = Index;
// functor signature.
__host__ __device__ KVP operator()(KVP a, KVP b) const { return a.value < b.value ? a : b; }
// Use index as tiebreaker for consistent behavior when distances are equal
__host__ __device__ KVP operator()(KVP a, KVP b) const
{
return (a.value < b.value || (a.value == b.value && a.key < b.key)) ? a : b;
}

__host__ __device__ AccType operator()(AccType a, AccType b) const { return min(a, b); }

Expand Down
Loading
Loading