Skip to content

Commit 863d2b7

Browse files
authored
fix a slew of unit errors (#232)
1 parent 0824413 commit 863d2b7

13 files changed

Lines changed: 136 additions & 29 deletions

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ Mmap = "a63ad114-7e13-5084-954f-fe012c677804"
2020
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
2121
Tensors = "48a634ad-e948-5137-8d70-aa71f2a747f4"
2222
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
23+
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
2324

2425
[targets]
25-
test = ["CellListMap", "LinearAlgebra", "Mmap", "Tensors", "Test", "StableRNGs"]
26+
test = ["CellListMap", "LinearAlgebra", "Mmap", "Tensors", "Test", "StableRNGs", "Unitful"]

src/NearestNeighbors.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,7 @@ function check_input(::NNTree{V1}, m::AbstractMatrix) where {V1}
4444
end
4545
end
4646

47-
get_T(::Type{T}) where {T <: AbstractFloat} = T
48-
get_T(::T) where {T} = Float64
47+
get_T(::Type{T}) where {T} = typeof(float(zero(T)))
4948

5049
get_tree(tree::NNTree) = tree
5150

@@ -63,6 +62,13 @@ include("datafreetree.jl")
6362
include("knn.jl")
6463
include("inrange.jl")
6564

65+
# Type for internal distance calculations (before eval_end)
66+
dist_type_internal(tree::NNTree{V}) where V = get_T(eltype(V))
67+
dist_type_internal(tree::KDTree{V}) where V = typeof(eval_pow(tree.metric, zero(get_T(eltype(V)))))
68+
69+
# Get the "infinity" value in the correct distance space for the tree
70+
dist_typemax(tree::NNTree{V}) where V = typemax(dist_type_internal(tree))
71+
6672
for dim in (2, 3)
6773
for Tree in (KDTree, BallTree)
6874
tree = Tree(rand(dim, 10))

src/ball_tree.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ function _knn(tree::BallTree,
166166
point::AbstractVector,
167167
best_idxs::Union{Integer, AbstractVector{<:Integer}},
168168
best_dists::Union{Number, AbstractVector},
169+
::Union{Nothing, AbstractVector},
169170
skip::F) where {F}
170171
return knn_kernel!(tree, 1, point, best_idxs, best_dists, skip, nothing)
171172
end
@@ -321,6 +322,7 @@ end
321322
function _add_balltree_self_leaf_pairs!(results::Vector{NTuple{2,Int}}, tree::BallTree, leaf_idx::Int, other_leaf_idx::Int, r::Number, skip)
322323
point_range = get_leaf_range(tree.tree_data, leaf_idx)
323324
is_minkowski = tree.metric isa MinkowskiMetric
325+
r_cmp = is_minkowski ? eval_pow(tree.metric, r) : r
324326
if leaf_idx == other_leaf_idx
325327
@inbounds for i in point_range
326328
idx_i = tree.indices[i]
@@ -331,7 +333,7 @@ function _add_balltree_self_leaf_pairs!(results::Vector{NTuple{2,Int}}, tree::Ba
331333
if skip(idx_j)
332334
continue
333335
end
334-
if evaluate_maybe_end(tree.metric, point_i, tree.data[tree.reordered ? j : tree.indices[j]], !is_minkowski) <= r
336+
if evaluate_maybe_end(tree.metric, point_i, tree.data[tree.reordered ? j : tree.indices[j]], !is_minkowski) <= r_cmp
335337
a, b = idx_i < idx_j ? (idx_i, idx_j) : (idx_j, idx_i)
336338
push!(results, (a, b))
337339
end
@@ -348,7 +350,7 @@ function _add_balltree_self_leaf_pairs!(results::Vector{NTuple{2,Int}}, tree::Ba
348350
if skip(idx_j)
349351
continue
350352
end
351-
if evaluate_maybe_end(tree.metric, point_i, tree.data[tree.reordered ? j : tree.indices[j]], !is_minkowski) <= r
353+
if evaluate_maybe_end(tree.metric, point_i, tree.data[tree.reordered ? j : tree.indices[j]], !is_minkowski) <= r_cmp
352354
a, b = idx_i < idx_j ? (idx_i, idx_j) : (idx_j, idx_i)
353355
push!(results, (a, b))
354356
end
@@ -408,8 +410,7 @@ end
408410
function _inrange_pairs(tree::BallTree{V}, radius::Number, sortres, skip::F) where {V, F}
409411
isempty(tree.data) && return NTuple{2,Int}[]
410412
pairs = NTuple{2,Int}[]
411-
r = tree.metric isa MinkowskiMetric ? eval_pow(tree.metric, radius) : radius
412-
_inrange_balltree_self!(pairs, tree, 1, 1, r, skip)
413+
_inrange_balltree_self!(pairs, tree, 1, 1, radius, skip)
413414
sortres && sort!(pairs)
414415
return pairs
415416
end

src/brute_tree.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ function _knn(tree::BruteTree{V},
4545
point::AbstractVector,
4646
best_idxs::Union{Integer, AbstractVector{<:Integer}},
4747
best_dists::Union{Number, AbstractVector},
48+
::Union{Nothing, AbstractVector},
4849
skip::F) where {V, F}
4950

5051
return knn_kernel!(tree, point, best_idxs, best_dists, skip, nothing)

src/hyperrectangles.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ end
2525
@inline distance_function_min(vald, maxd, mind) = max(zero(eltype(vald)), max(mind - vald, vald - maxd))
2626

2727
function get_min_max_distance_no_end(f::Function, m::Metric, rec::HyperRectangle, point::AbstractVector{T}) where {T}
28-
s = zero(T)
2928
p = Distances.parameters(m)
29+
s = p === nothing ? eval_op(m, zero(T), zero(T)) : eval_op(m, zero(T), zero(T), p[1])
3030
@inbounds @simd for dim in eachindex(point)
3131
v = f(point[dim], rec.maxes[dim], rec.mins[dim])
3232
v_op = p === nothing ? eval_op(m, v, zero(T)) : eval_op(m, v, zero(T), p[dim])
@@ -58,8 +58,9 @@ end
5858
function get_min_max_distance(m::Metric, r1::HyperRectangle{V}, r2::HyperRectangle{V}) where {V}
5959
p = Distances.parameters(m)
6060
T = eltype(V)
61-
min_acc = zero(T)
62-
max_acc = zero(T)
61+
zero_op = p === nothing ? eval_op(m, zero(T), zero(T)) : eval_op(m, zero(T), zero(T), p[1])
62+
min_acc = zero_op
63+
max_acc = zero_op
6364
@inbounds for dim in eachindex(r1.mins)
6465
lo1 = r1.mins[dim]; hi1 = r1.maxes[dim]
6566
lo2 = r2.mins[dim]; hi2 = r2.maxes[dim]
@@ -82,7 +83,9 @@ end
8283
# Compute per-dimension contributions for max distance
8384
function get_max_distance_contributions(m::Metric, rec::HyperRectangle{V}, point::AbstractVector{T}) where {V,T}
8485
p = Distances.parameters(m)
85-
return V(
86+
sample_op = p === nothing ? eval_op(m, zero(T), zero(T)) : eval_op(m, zero(T), zero(T), p[1])
87+
ResultV = SVector{length(V), typeof(sample_op)}
88+
return ResultV(
8689
@inbounds begin
8790
v = distance_function_max(point[dim], rec.maxes[dim], rec.mins[dim])
8891
p === nothing ? eval_op(m, v, zero(T)) : eval_op(m, v, zero(T), p[dim])

src/hyperspheres.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
const NormMetric = Union{Euclidean,Chebyshev,Cityblock,Minkowski,WeightedEuclidean,WeightedCityblock,WeightedMinkowski,Mahalanobis}
22

3-
struct HyperSphere{N,T <: AbstractFloat}
3+
struct HyperSphere{N,T}
44
center::SVector{N,T}
55
r::T
66
end
@@ -72,14 +72,14 @@ function create_bsphere(data::AbstractVector{V}, metric::Metric, indices::Vector
7272
T = get_T(eltype(V))
7373
center = sum(data[indices[r]] for r in range) * (one(T) / length(range))
7474
r = maximum(evaluate(metric, data[indices[i]], center) for i in range)
75-
r += eps(T)
75+
r += eps(r)
7676
return HyperSphere(center, r)
7777
end
7878

7979
# Creates a bounding sphere from two other spheres
8080
function create_bsphere(m::Metric,
8181
s1::HyperSphere{N,T},
82-
s2::HyperSphere{N,T}) where {N, T <: AbstractFloat}
82+
s2::HyperSphere{N,T}) where {N, T}
8383
if encloses(m, s1, s2)
8484
return HyperSphere(s2.center, s2.r)
8585
elseif encloses(m, s2, s1)

src/inrange.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
check_radius(r) = r < 0 && throw(ArgumentError("the query radius r must be ≧ 0"))
1+
check_radius(r) = r < zero(r) && throw(ArgumentError("the query radius r must be ≧ 0"))
22

33
"""
44
inrange(tree::NNTree, points, radius) -> indices

src/kd_tree.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,14 +161,15 @@ function _knn(tree::KDTree,
161161
point::AbstractVector,
162162
best_idxs::Union{Integer, AbstractVector{<:Integer}},
163163
best_dists::Union{Number, AbstractVector},
164+
best_dists_final::Union{Nothing, AbstractVector},
164165
skip::F) where {F}
165166
init_min = get_min_distance_no_end(tree.metric, tree.hyper_rec, point)
166167
best_idxs, best_dists = knn_kernel!(tree, 1, point, best_idxs, best_dists, init_min, tree.hyper_rec, skip, nothing)
167168
best_dists isa Number && return best_idxs, eval_end(tree.metric, best_dists)
168169
@simd for i in eachindex(best_dists)
169-
@inbounds best_dists[i] = eval_end(tree.metric, best_dists[i])
170+
@inbounds best_dists_final[i] = eval_end(tree.metric, best_dists[i])
170171
end
171-
return best_idxs, best_dists
172+
return best_idxs, best_dists_final
172173
end
173174

174175
function knn_kernel!(tree::KDTree{V},

src/knn.jl

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,27 @@ end
3838
knn_point!(tree::NNTree{V}, point::AbstractVector{T}, sortres, dist, idx, skip::F) where {V, T <: Number, F} =
3939
_knn_point!(tree, point, sortres, dist, idx, skip)
4040

41-
function _knn_point!(tree::NNTree{V}, point::AbstractVector{T}, sortres, dist, idx, skip::F) where {V, T <: Number, F}
41+
function _knn_point!(tree::NNTree{V}, point::AbstractVector{T}, sortres, dist_final, idx, skip::F) where {V, T <: Number, F}
4242
fill!(idx, -1)
43-
fill!(dist, typemax(get_T(eltype(V))))
44-
_knn(tree, point, idx, dist, skip)
43+
inner_tree = get_tree(tree)
44+
45+
T_internal = dist_type_internal(inner_tree)
46+
T_final = eltype(dist_final)
47+
if T_internal === T_final
48+
dist_internal = dist_final
49+
else
50+
dist_internal = Vector{T_internal}(undef, length(dist_final))
51+
end
52+
fill!(dist_internal, dist_typemax(inner_tree))
53+
54+
_knn(tree, point, idx, dist_internal, dist_final, skip)
55+
4556
if skip !== Returns(false)
4657
skipped_idxs = findall(==(-1), idx)
4758
deleteat!(idx, skipped_idxs)
48-
deleteat!(dist, skipped_idxs)
59+
deleteat!(dist_final, skipped_idxs)
4960
end
50-
sortres && heap_sort_inplace!(dist, idx)
51-
inner_tree = get_tree(tree)
61+
sortres && heap_sort_inplace!(dist_final, idx)
5262
if inner_tree.reordered
5363
for j in eachindex(idx)
5464
@inbounds idx[j] = inner_tree.indices[idx[j]]
@@ -78,6 +88,8 @@ function knn!(idxs::AbstractVector{<:Integer}, dists::AbstractVector, tree::NNTr
7888
check_k(tree, k)
7989
length(idxs) == k || throw(ArgumentError("idxs must be of length k"))
8090
length(dists) == k || throw(ArgumentError("dists must be of length k"))
91+
expected_dist_type = get_T(eltype(V))
92+
eltype(dists) === expected_dist_type || throw(ArgumentError("dists must have eltype $expected_dist_type, got $(eltype(dists))"))
8193
knn_point!(tree, point, sortres, dists, idxs, skip)
8294
return idxs, dists
8395
end
@@ -126,7 +138,7 @@ See also: `knn`.
126138
function nn(tree::NNTree{V}, point::AbstractVector{T}, skip::F=Returns(false)) where {V, T <: Number, F <: Function}
127139
check_for_nan_in_points(point)
128140
check_k(tree, 1)
129-
best_idx, best_dist = _knn(tree, point, -1, typemax(get_T(eltype(V))), skip)
141+
best_idx, best_dist = _knn(tree, point, -1, dist_typemax(get_tree(tree)), nothing, skip)
130142
inner_tree = get_tree(tree)
131143
final_idx = inner_tree.reordered ? inner_tree.indices[best_idx] : best_idx
132144
return final_idx, best_dist

src/periodic_tree.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ struct PeriodicTree{V<:AbstractVector, M, Tree <: NNTree{V, M}, D, W} <: NNTree{
7373
# Check for valid box dimensions (finite dimensions must be positive)
7474
for i in 1:dim
7575
actual_width = maxs_vec[i] - mins_vec[i]
76-
if isfinite(actual_width) && actual_width <= 0
76+
if isfinite(actual_width) && actual_width <= zero(actual_width)
7777
throw(ArgumentError("Box width in dimension $i must be positive, got $actual_width"))
7878
end
7979
end
@@ -89,7 +89,7 @@ struct PeriodicTree{V<:AbstractVector, M, Tree <: NNTree{V, M}, D, W} <: NNTree{
8989
end
9090

9191
# Find periodic dimensions (those with non-zero box widths)
92-
periodic_dims = findall(>(0), box_widths)
92+
periodic_dims = findall(w -> w > zero(w), box_widths)
9393
n_periodic = length(periodic_dims)
9494

9595
# Generate combinations only for periodic dimensions
@@ -179,6 +179,7 @@ function _knn(tree::PeriodicTree{V,M},
179179
point::AbstractVector,
180180
best_idxs::Union{Integer, AbstractVector{<:Integer}},
181181
best_dists::Union{Number, AbstractVector},
182+
best_dists_final::Union{Nothing, AbstractVector},
182183
skip::F) where {V, M, F}
183184

184185
dedup_state = empty!(tree.dedup_set)
@@ -222,12 +223,16 @@ function _knn(tree::PeriodicTree{V,M},
222223
best_dists = eval_end(tree.tree.metric, best_dists)
223224
else
224225
@simd for i in eachindex(best_dists)
225-
@inbounds best_dists[i] = eval_end(tree.tree.metric, best_dists[i])
226+
@inbounds best_dists_final[i] = eval_end(tree.tree.metric, best_dists[i])
226227
end
227228
end
228229
end
229230
empty!(dedup_state)
230-
return best_idxs, best_dists
231+
if best_dists isa Number
232+
return best_idxs, best_dists
233+
else
234+
return best_idxs, best_dists_final
235+
end
231236
end
232237

233238
function _inrange(tree::PeriodicTree{V},

0 commit comments

Comments
 (0)