|
38 | 38 | knn_point!(tree::NNTree{V}, point::AbstractVector{T}, sortres, dist, idx, skip::F) where {V, T <: Number, F} = |
39 | 39 | _knn_point!(tree, point, sortres, dist, idx, skip) |
40 | 40 |
|
41 | | -function _knn_point!(tree::NNTree{V}, point::AbstractVector{T}, sortres, dist, idx, skip::F) where {V, T <: Number, F} |
| 41 | +function _knn_point!(tree::NNTree{V}, point::AbstractVector{T}, sortres, dist_final, idx, skip::F) where {V, T <: Number, F} |
42 | 42 | fill!(idx, -1) |
43 | | - fill!(dist, typemax(get_T(eltype(V)))) |
44 | | - _knn(tree, point, idx, dist, skip) |
| 43 | + inner_tree = get_tree(tree) |
| 44 | + |
| 45 | + T_internal = dist_type_internal(inner_tree) |
| 46 | + T_final = eltype(dist_final) |
| 47 | + if T_internal === T_final |
| 48 | + dist_internal = dist_final |
| 49 | + else |
| 50 | + dist_internal = Vector{T_internal}(undef, length(dist_final)) |
| 51 | + end |
| 52 | + fill!(dist_internal, dist_typemax(inner_tree)) |
| 53 | + |
| 54 | + _knn(tree, point, idx, dist_internal, dist_final, skip) |
| 55 | + |
45 | 56 | if skip !== Returns(false) |
46 | 57 | skipped_idxs = findall(==(-1), idx) |
47 | 58 | deleteat!(idx, skipped_idxs) |
48 | | - deleteat!(dist, skipped_idxs) |
| 59 | + deleteat!(dist_final, skipped_idxs) |
49 | 60 | end |
50 | | - sortres && heap_sort_inplace!(dist, idx) |
51 | | - inner_tree = get_tree(tree) |
| 61 | + sortres && heap_sort_inplace!(dist_final, idx) |
52 | 62 | if inner_tree.reordered |
53 | 63 | for j in eachindex(idx) |
54 | 64 | @inbounds idx[j] = inner_tree.indices[idx[j]] |
@@ -78,6 +88,8 @@ function knn!(idxs::AbstractVector{<:Integer}, dists::AbstractVector, tree::NNTr |
78 | 88 | check_k(tree, k) |
79 | 89 | length(idxs) == k || throw(ArgumentError("idxs must be of length k")) |
80 | 90 | length(dists) == k || throw(ArgumentError("dists must be of length k")) |
| 91 | + expected_dist_type = get_T(eltype(V)) |
| 92 | + eltype(dists) === expected_dist_type || throw(ArgumentError("dists must have eltype $expected_dist_type, got $(eltype(dists))")) |
81 | 93 | knn_point!(tree, point, sortres, dists, idxs, skip) |
82 | 94 | return idxs, dists |
83 | 95 | end |
@@ -126,7 +138,7 @@ See also: `knn`. |
126 | 138 | function nn(tree::NNTree{V}, point::AbstractVector{T}, skip::F=Returns(false)) where {V, T <: Number, F <: Function} |
127 | 139 | check_for_nan_in_points(point) |
128 | 140 | check_k(tree, 1) |
129 | | - best_idx, best_dist = _knn(tree, point, -1, typemax(get_T(eltype(V))), skip) |
| 141 | + best_idx, best_dist = _knn(tree, point, -1, dist_typemax(get_tree(tree)), nothing, skip) |
130 | 142 | inner_tree = get_tree(tree) |
131 | 143 | final_idx = inner_tree.reordered ? inner_tree.indices[best_idx] : best_idx |
132 | 144 | return final_idx, best_dist |
|
0 commit comments