[FEA] Binary IVF Flat Index #1099
Open
tarang-jain wants to merge 146 commits into rapidsai:release/26.02 from tarang-jain:binary-kmeans
146 commits
bedcf4c
first commit
tarang-jain 09f9a22
Merge branch 'branch-25.08' of https://github.com/rapidsai/cuvs into …
tarang-jain 51836d4
index header
tarang-jain 23ef877
populate functions;ivf_list type;kmeans_predict
tarang-jain a7fce8e
Merge branch 'branch-25.08' of https://github.com/rapidsai/cuvs into …
tarang-jain 6a98a88
hamming_op
tarang-jain 76c9ee5
Merge branch 'branch-25.08' into binary-kmeans
tarang-jain 916a4cf
rm binary_ivf
tarang-jain 8ec4d59
Merge branch 'binary-kmeans' of https://github.com/tarang-jain/cuvs i…
tarang-jain 1941b2e
modify ivf_flat_build
tarang-jain cd00b83
rm binary_ivf_flat
tarang-jain 4cffe84
rm unused
tarang-jain 2bc9007
updates
tarang-jain 7803850
quantize
tarang-jain ff7be4a
cleanup
tarang-jain 3149192
pre-commit
tarang-jain dd1b0d4
update kmeans_predict
tarang-jain 2b9bef4
src kmeans
tarang-jain 6ec32d8
style
tarang-jain 5c59753
corrections to logic
tarang-jain 2271809
clang
tarang-jain 29388e9
testing
tarang-jain 4a491cd
correct tests
tarang-jain 4c27acd
binary_index flag
tarang-jain 7085405
correct mdspan,reduction_op
tarang-jain 2075bd0
Merge branch 'branch-25.08' into binary-kmeans
tarang-jain d06b4f5
correct checks
tarang-jain 13de163
Merge branch 'branch-25.08' of https://github.com/rapidsai/cuvs into …
tarang-jain e79d520
Merge branch 'binary-kmeans' of https://github.com/tarang-jain/cuvs i…
tarang-jain c7f3ade
Merge branch 'branch-25.10' of https://github.com/rapidsai/cuvs into …
tarang-jain e7fdd4c
Merge branch 'branch-25.10' into binary-kmeans
tarang-jain 3ce25af
fused-kernel
tarang-jain 1f565ef
Merge branch 'binary-kmeans' of https://github.com/tarang-jain/cuvs i…
tarang-jain 2c2a91d
cleeanup;syntax
tarang-jain b9a6a65
cleanup
tarang-jain 0f19648
temp-fix-compilation-errors
tarang-jain 2de4bcc
Merge branch 'branch-25.10' into binary-kmeans
tarang-jain f572e0b
Merge branch 'branch-25.10' into binary-kmeans
tarang-jain 66d8b94
uint32_t for storing dists
tarang-jain 46f9676
Merge branch 'branch-25.10' of https://github.com/rapidsai/cuvs into …
tarang-jain fa2193f
Merge branch 'binary-kmeans' of https://github.com/tarang-jain/cuvs i…
tarang-jain 8f18ce0
no use_norms check
tarang-jain c7084ac
coarse search
tarang-jain 1ca868b
debug; fix dim() error
tarang-jain 42859fb
Merge branch 'branch-25.10' into binary-kmeans
tarang-jain f0597fa
Merge branch 'branch-25.10' into binary-kmeans
tarang-jain ca97da6
Merge branch 'branch-25.10' into binary-kmeans
tarang-jain 8ad530b
Merge branch 'branch-25.10' of https://github.com/rapidsai/cuvs into …
tarang-jain 17449e4
debug
tarang-jain f8e6e86
Merge branch 'branch-25.10' of https://github.com/rapidsai/cuvs into …
tarang-jain 6aa2401
fix float underflow
tarang-jain d1717f5
rm new file
tarang-jain 3b284cd
Merge branch 'branch-25.10' of https://github.com/rapidsai/cuvs into …
tarang-jain f34103d
Merge branch 'branch-25.10' into binary-kmeans
tarang-jain 86a6593
Merge branch 'branch-25.10' into binary-kmeans
tarang-jain ef37c59
Merge branch 'branch-25.10' into binary-kmeans
tarang-jain 5e794fe
Merge branch 'branch-25.10' into binary-kmeans
tarang-jain 359d458
Merge branch 'branch-25.10' into binary-kmeans
tarang-jain d999591
Merge branch 'branch-25.10' into binary-kmeans
tarang-jain f242561
bug fixes
tarang-jain 5517ae3
Merge branch 'binary-kmeans' of https://github.com/tarang-jain/cuvs i…
tarang-jain 4aa2881
Merge branch 'branch-25.12' into binary-kmeans
tarang-jain b0d78b6
Merge branch 'branch-25.12' into binary-kmeans
tarang-jain 184f025
Merge branch 'branch-25.12' into binary-kmeans
tarang-jain 113e242
Merge branch 'branch-25.12' into binary-kmeans
tarang-jain 540db62
correct fused instantiation
tarang-jain d9eed14
Merge branch 'binary-kmeans' of https://github.com/tarang-jain/cuvs i…
tarang-jain 7ceed82
Merge branch 'branch-25.12' into binary-kmeans
tarang-jain ac261be
Merge branch 'branch-25.12' into binary-kmeans
tarang-jain b66b8ee
cleanup
tarang-jain 4a19ec1
Merge branch 'binary-kmeans' of https://github.com/tarang-jain/cuvs i…
tarang-jain 80d5f9c
Merge branch 'branch-25.12' of https://github.com/rapidsai/cuvs into …
tarang-jain d2b1184
Merge branch 'branch-25.12' of https://github.com/rapidsai/cuvs into …
tarang-jain ab1083e
Merge branch 'branch-25.12' into binary-kmeans
tarang-jain 9a26624
thorough equivalence testing checkpoint
tarang-jain 8a2098f
Merge branch 'binary-kmeans' of https://github.com/tarang-jain/cuvs i…
tarang-jain d4aae95
cleanup-1
tarang-jain f5c61dc
cleanup-2
tarang-jain 54ae1a2
cleanup-3
tarang-jain 76fc1f4
cleanup-4 (style
tarang-jain a8a3a0a
cleanup-6 tests
tarang-jain 05363f5
cleanup-7 (tests)
tarang-jain de3bcc1
cleanup-8 tests
tarang-jain a02e563
cleanup-8 (tests)
tarang-jain b6181a0
cleanup-8 (tests)
tarang-jain 5f640c1
use raft ci artifacts from pr
tarang-jain cc0fa18
wheels artifacts from raft pr
tarang-jain de606fa
style
tarang-jain 385d8c7
shell stlye check
tarang-jain ae49732
Remove RAFT Dask channel from conda packages script
tarang-jain 566b247
fix visible sections in simt_kernel
tarang-jain a24d13b
Merge branch 'branch-25.12' of https://github.com/rapidsai/cuvs into …
tarang-jain b4ea915
raft artifact for python build
tarang-jain 34d2ec5
Merge branch 'branch-25.12' of https://github.com/rapidsai/cuvs into …
tarang-jain 55b13d8
add skip flags
tarang-jain f44a7d8
Merge branch 'branch-25.12' of https://github.com/rapidsai/cuvs into …
tarang-jain 942dc96
Merge branch 'main' into binary-kmeans
tarang-jain 8d99bc7
set binary_index_ flag explicitly
tarang-jain 5153555
Merge branch 'main' into binary-kmeans
tarang-jain ca41dc8
Merge branch 'main' of https://github.com/rapidsai/cuvs into binary-k…
tarang-jain 182dadf
correct recall threshold
tarang-jain 99eea8a
merge upstream
tarang-jain 5220772
style and correct tests
tarang-jain 5fc6f4b
rm extra test
tarang-jain 37378d7
Merge branch 'main' into binary-kmeans
tarang-jain b948a33
rm test
tarang-jain c5e07b1
Merge branch 'binary-kmeans' of https://github.com/tarang-jain/cuvs i…
tarang-jain cbf83b8
Merge branch 'main' into binary-kmeans
tarang-jain c14f26f
Merge branch 'main' into binary-kmeans
tarang-jain b3eac6f
Merge branch 'main' into binary-kmeans
tarang-jain 03a01a1
Merge branch 'main' into binary-kmeans
achirkin 7b409c7
on the fly dataset expansion
tarang-jain 080a5ea
Merge branch 'main' into binary-kmeans
tarang-jain 107a3e4
address PR reviews
tarang-jain 1caabeb
Merge branch 'binary-kmeans' of https://github.com/tarang-jain/cuvs i…
tarang-jain 3681320
simplify bitwise_decode_op
tarang-jain 118a3d7
commit suggestion
tarang-jain 3399656
clang
tarang-jain d4fe3fe
Merge branch 'main' into binary-kmeans
tarang-jain 653364e
merge upstream
tarang-jain ef34908
address reviews
tarang-jain d12a211
style
tarang-jain 91c6734
undo ci changes
tarang-jain 42f1bb7
fix kmeans mapping op
tarang-jain e59a357
other fixes to kmeans for binary data
tarang-jain d13c0c9
fix compilation errors
tarang-jain 8162293
fix compilation
tarang-jain fcae247
fix compilation
tarang-jain a9599bd
doc
tarang-jain ffbcdcc
Merge branch 'main' of https://github.com/rapidsai/cuvs into binary-k…
tarang-jain 3ffba85
simplify ivf-flat build
tarang-jain c0a99e2
fix compilation errors
tarang-jain dbb6423
bug fixes
tarang-jain 248911c
debug
tarang-jain 0656ee6
more corrections to kmeans
tarang-jain 1423356
Merge branch 'main' into binary-kmeans
tarang-jain a25ddac
Merge branch 'main' into binary-kmeans
tarang-jain e8a8152
Merge branch 'main' of https://github.com/rapidsai/cuvs into binary-k…
tarang-jain 8feebb8
bug fixes
tarang-jain 997ddde
debug
tarang-jain 89b54a1
working impl;rm debug statements
tarang-jain 07354d1
rm debug prints:
tarang-jain 07e1837
Merge branch 'main' into binary-kmeans
tarang-jain 510bafd
Merge branch 'release/26.02' into binary-kmeans
tarang-jain 0ddce5d
Merge branch 'release/26.02' into binary-kmeans
tarang-jain d546471
Merge branch 'release/26.02' into binary-kmeans
tarang-jain
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,60 @@ | ||
| /* | ||
| * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <cuda_runtime.h> | ||
| #include <stdint.h> | ||
|
|
||
| namespace cuvs::distance::detail::ops { | ||
|
|
||
| /** | ||
| * @brief the Bitwise Hamming distance matrix calculation | ||
| * It computes the following equation: | ||
| * | ||
| * c_ij = sum_k popcount(x_ik XOR y_kj) | ||
| * | ||
| * where x and y are binary data packed as uint8_t | ||
| */ | ||
| template <typename DataType, typename AccType, typename IdxType> | ||
| struct bitwise_hamming_distance_op { | ||
| using DataT = DataType; | ||
| using AccT = AccType; | ||
| using IdxT = IdxType; | ||
|
|
||
| IdxT k; | ||
|
|
||
| bitwise_hamming_distance_op(IdxT k_) noexcept : k(k_) {} | ||
|
|
||
| static constexpr bool use_norms = false; | ||
| static constexpr bool expensive_inner_loop = false; | ||
|
|
||
| template <typename Policy> | ||
| static constexpr size_t shared_mem_size() | ||
| { | ||
| return Policy::SmemSize; | ||
| } | ||
|
|
||
| __device__ __forceinline__ void core(AccT& acc, DataT& x, DataT& y) const | ||
| { | ||
| static_assert(std::is_same_v<DataT, uint8_t>, "BitwiseHamming only supports uint8_t"); | ||
| // Ensure proper masking and casting to avoid undefined behavior | ||
| uint32_t xor_val = static_cast<uint32_t>(static_cast<uint8_t>(x ^ y)); | ||
| uint32_t masked_val = xor_val & 0xffu; | ||
| int popcount = __popc(masked_val); | ||
| acc += static_cast<AccT>(popcount); | ||
| } | ||
|
|
||
| template <typename Policy> | ||
| __device__ __forceinline__ void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], | ||
| AccT* regxn, | ||
| AccT* regyn, | ||
| IdxT gridStrideX, | ||
| IdxT gridStrideY) const | ||
| { | ||
| } | ||
| }; | ||
|
|
||
| } // namespace cuvs::distance::detail::ops | ||
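For readers who want to check the formula in the doc comment above, here is a minimal host-side sketch of the same quantity, c_ij = sum_k popcount(x_ik XOR y_jk). It is not the cuVS kernel: the function name `hamming_matrix_ref`, the row-major layout, and the use of `std::bitset` for the popcount are assumptions made for illustration only.

```cpp
#include <bitset>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side reference (not part of the PR).
// x is m x k and y is n x k, both row-major; k counts packed bytes,
// so each row holds 8 * k binary dimensions.
std::vector<uint32_t> hamming_matrix_ref(const std::vector<uint8_t>& x,
                                         const std::vector<uint8_t>& y,
                                         std::size_t m,
                                         std::size_t n,
                                         std::size_t k)
{
  assert(x.size() == m * k && y.size() == n * k);
  std::vector<uint32_t> c(m * n, 0u);
  for (std::size_t i = 0; i < m; ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      uint32_t acc = 0;
      for (std::size_t b = 0; b < k; ++b) {
        // XOR marks the differing bits; count() is the popcount of that byte.
        acc += static_cast<uint32_t>(std::bitset<8>(x[i * k + b] ^ y[j * k + b]).count());
      }
      c[i * n + j] = acc;
    }
  }
  return c;
}
```

With x = {FF 00, 01 00} and y = {00 00, 01 00, FF FF} (hex, two bytes per row), the sketch returns the row-major matrix {8, 7, 8, 1, 0, 15}.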
82 changes: 82 additions & 0 deletions
cpp/src/distance/detail/fused_distance_nn/fused_bitwise_hamming_nn.cuh
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,82 @@ | ||
| /* | ||
| * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
| #pragma once | ||
|
|
||
| #include "../distance_ops/bitwise_hamming.cuh" // ops::bitwise_hamming_distance_op | ||
| #include "../pairwise_distance_base.cuh" // PairwiseDistances | ||
| #include "helper_structs.cuh" | ||
| #include "simt_kernel.cuh" | ||
|
|
||
| namespace cuvs { | ||
| namespace distance { | ||
| namespace detail { | ||
|
|
||
| /** | ||
| * @brief Fused BitwiseHamming distance and 1-nearest-neighbor | ||
| * | ||
| * This implementation is only meaningful for uint8_t data type. | ||
| * The if constexpr in fusedDistanceNNImpl ensures it's only called for uint8_t. | ||
| */ | ||
| template <typename DataT, | ||
| typename OutT, | ||
| typename IdxT, | ||
| typename Policy, | ||
| typename ReduceOpT, | ||
| typename KVPReduceOpT> | ||
| void fusedBitwiseHammingNN(OutT* min, | ||
| const DataT* x, | ||
| const DataT* y, | ||
| const DataT* xn, | ||
| const DataT* yn, | ||
| IdxT m, | ||
| IdxT n, | ||
| IdxT k, | ||
| int* workspace, | ||
| ReduceOpT redOp, | ||
| KVPReduceOpT pairRedOp, | ||
| bool sqrt, | ||
| cudaStream_t stream) | ||
| { | ||
| typedef Policy P; | ||
|
|
||
| dim3 blk(P::Nthreads); | ||
| constexpr auto maxVal = std::numeric_limits<DataT>::max(); | ||
| using kv_pair_type = raft::KeyValuePair<IdxT, uint32_t>; | ||
| using distance_op_type = ops::bitwise_hamming_distance_op<DataT, uint32_t, IdxT>; | ||
| distance_op_type distance_op{k}; | ||
| auto kernel = fusedDistanceNNkernel<DataT, | ||
| kv_pair_type, | ||
| IdxT, | ||
| P, | ||
| ReduceOpT, | ||
| KVPReduceOpT, | ||
| distance_op_type, | ||
| raft::identity_op>; | ||
|
|
||
| constexpr size_t shmemSize = P::SmemSize; | ||
|
|
||
| dim3 grid = launchConfigGenerator<P>(m, n, shmemSize, kernel); | ||
|
|
||
| kernel<<<grid, blk, shmemSize, stream>>>(min, | ||
| x, | ||
| y, | ||
| nullptr, | ||
| nullptr, | ||
| m, | ||
| n, | ||
| k, | ||
| maxVal, | ||
| workspace, | ||
| redOp, | ||
| pairRedOp, | ||
| distance_op, | ||
| raft::identity_op{}); | ||
|
|
||
| RAFT_CUDA_TRY(cudaGetLastError()); | ||
| } | ||
|
|
||
| } // namespace detail | ||
| } // namespace distance | ||
| } // namespace cuvs |
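As a rough picture of what a fused "distance plus 1-nearest-neighbor" step is expected to produce, the sketch below computes, for every row of x, the index of the closest row of y under bitwise Hamming distance together with that distance. It is a sequential host-side reference under stated assumptions, not the kernel launched above: the name `hamming_1nn_ref`, the `std::pair` result type, and the tie-breaking rule (lowest index wins) are all hypothetical.

```cpp
#include <bitset>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <utility>
#include <vector>

// Hypothetical result type: (index of the nearest y row, Hamming distance).
using kv_pair = std::pair<std::size_t, uint32_t>;

// Hypothetical host-side reference (not part of the PR).
// x is m x k and y is n x k, row-major, with k counted in packed bytes.
std::vector<kv_pair> hamming_1nn_ref(const std::vector<uint8_t>& x,
                                     const std::vector<uint8_t>& y,
                                     std::size_t m,
                                     std::size_t n,
                                     std::size_t k)
{
  assert(n > 0 && x.size() == m * k && y.size() == n * k);
  // Start every row at the largest representable distance.
  std::vector<kv_pair> out(m, kv_pair{0, std::numeric_limits<uint32_t>::max()});
  for (std::size_t i = 0; i < m; ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      uint32_t d = 0;
      for (std::size_t b = 0; b < k; ++b) {
        d += static_cast<uint32_t>(std::bitset<8>(x[i * k + b] ^ y[j * k + b]).count());
      }
      // Strict '<' keeps the lowest index on ties (an assumption of this sketch).
      if (d < out[i].second) { out[i] = kv_pair{j, d}; }
    }
  }
  return out;
}
```

A reference like this can be handy when writing equivalence tests, since it never materializes the full m x n distance matrix and mirrors the reduce-as-you-go structure of a fused kernel.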
This works, but it is inefficient to use `__popc` one byte at a time. Could you create an issue to note that there is a potential improvement if we use a 4-byte storage type for x and y?
Yeah, the reason I did this was to avoid separate compilation paths for larger data types such as half and uint32. The idea was to dispatch to the corresponding type (uint32 or half) depending on whether the dimension is divisible by 4 or by 2. But that would bloat the binary size, because the distance op would have to be defined for these additional types.
Do we expect use cases where padding dims to be divisible by 32 would be a problem? This discussion need not hold up this PR, though; I have created #1613 to continue it.
That actually makes a lot of sense. Padding dims to be divisible by 32 and using uint32_t seems like the more optimal approach.
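To make the padding idea from this thread concrete, here is a small host-side sketch. It assumes rows are zero-padded to a multiple of 4 bytes (32 binary dimensions) and then compared one 32-bit word at a time; since both operands carry the same zero padding, the padded bits XOR to zero and the distance is unchanged. The helper names `pack_words`, `hamming_bytes`, and `hamming_words` are hypothetical, and `std::bitset<32>::count()` stands in for a single 32-bit `__popc` call on the device.

```cpp
#include <bitset>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Copy packed bytes into 32-bit words, zero-padding the tail.
// Byte order inside a word does not matter for a popcount.
std::vector<uint32_t> pack_words(const std::vector<uint8_t>& bytes)
{
  std::vector<uint32_t> words((bytes.size() + 3) / 4, 0u);
  if (!bytes.empty()) { std::memcpy(words.data(), bytes.data(), bytes.size()); }
  return words;
}

// Byte-at-a-time distance: one popcount per byte, as in the current op.
uint32_t hamming_bytes(const std::vector<uint8_t>& a, const std::vector<uint8_t>& b)
{
  assert(a.size() == b.size());
  uint32_t acc = 0;
  for (std::size_t i = 0; i < a.size(); ++i) {
    acc += static_cast<uint32_t>(std::bitset<8>(a[i] ^ b[i]).count());
  }
  return acc;
}

// Word-at-a-time distance: one popcount per 4 bytes.
uint32_t hamming_words(const std::vector<uint32_t>& a, const std::vector<uint32_t>& b)
{
  assert(a.size() == b.size());
  uint32_t acc = 0;
  for (std::size_t i = 0; i < a.size(); ++i) {
    acc += static_cast<uint32_t>(std::bitset<32>(a[i] ^ b[i]).count());
  }
  return acc;
}
```

For a 5-byte row, `pack_words` produces two words, and `hamming_words` on the packed rows matches `hamming_bytes` on the originals while issuing two popcounts instead of five. Whether the same saving carries over to the kernel depends on load widths and tiling, which is what the follow-up issue is meant to explore.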