Skip to content

Commit 143b1ea

Browse files
improve getindex error message (#1367)
closes #994 --------- Co-authored-by: Penelope Yong <penelopeysm@gmail.com>
1 parent b854e0d commit 143b1ea

7 files changed

Lines changed: 54 additions & 13 deletions

File tree

HISTORY.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
# 0.41.7
2+
3+
Accessing a nonexistent variable in a `VarNamedTuple` now throws a `KeyError` with the original `VarName`, instead of an opaque `type NamedTuple has no field ...` error.
4+
15
# 0.41.6
26

37
Add a `factorize::Bool` keyword argument for `pointwise_logdensities(model, values)`, which controls whether pointwise logdensities for factorisable distributions (e.g. `MvNormal`, `product_distribution`, etc.) are returned as a single log-density for the whole distribution, or as an array of log-densities for each factor.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.41.6"
3+
version = "0.41.7"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/varnamedtuple.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ function AbstractPPL.hasvalue(vnt::VarNamedTuple, vn::VarName, dist::LKJCholesky
226226
for k in keys(val)
227227
# VarNamedTuples have VarNames as keys, PartialArrays have Index optics.
228228
subvn = val isa VarNamedTuple ? prefix(k, vn) : AbstractPPL.append_optic(vn, k)
229-
dval[subvn] = _getindex_optic(val, k)
229+
dval[subvn] = _getindex_optic(val, k, subvn)
230230
end
231231
return AbstractPPL.hasvalue(dval, vn, dist)
232232
end
@@ -244,7 +244,7 @@ function AbstractPPL.getvalue(vnt::VarNamedTuple, vn::VarName, dist::LKJCholesky
244244
for k in keys(val)
245245
# VarNamedTuples have VarNames as keys, PartialArrays have Index optics.
246246
subvn = val isa VarNamedTuple ? prefix(k, vn) : AbstractPPL.append_optic(vn, k)
247-
dval[subvn] = _getindex_optic(val, k)
247+
dval[subvn] = _getindex_optic(val, k, subvn)
248248
end
249249
return AbstractPPL.getvalue(dval, vn, dist)
250250
end

src/varnamedtuple/getset.jl

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ const IndexWithoutChild = AbstractPPL.Index{<:Tuple,<:NamedTuple,AbstractPPL.Ide
1111
_unimplemented() = error("Not implemented")
1212

1313
"""
14-
DynamicPPL._getindex_optic(collection, optic::AbstractPPL.Optic)
14+
DynamicPPL._getindex_optic(collection, optic::AbstractPPL.Optic, orig_vn::VarName)
1515
DynamicPPL._getindex_optic(collection, vn::VarName)
1616
1717
Access the value in `collection` at the location specified by the given `optic`. If a `VarName`
@@ -27,16 +27,28 @@ Note that it is only valid to index into a `VarNamedTuple` with a `Property` opt
2727
`PartialArray` with an `Index` optic. Other combinations are not valid. When we have reached
2828
the leaf of the VNT i.e. a value, we could still handle pure `Index` optics if the value is
2929
an `AbstractArray`, but otherwise the only valid optic is `Iden`.
30+
31+
`orig_vn` is used to keep track of the original VarName used to index into a VarNamedTuple,
32+
and is only for error reporting purposes.
3033
"""
3134
function _getindex_optic(vnt::VarNamedTuple, vn::VarName)
32-
return _getindex_optic(vnt, AbstractPPL.varname_to_optic(vn))
35+
return _getindex_optic(vnt, AbstractPPL.varname_to_optic(vn), vn)
36+
end
37+
function _getindex_optic(vnt::VarNamedTuple, vn::VarName, orig_vn)
38+
return _getindex_optic(vnt, AbstractPPL.varname_to_optic(vn), orig_vn)
3339
end
34-
@inline _getindex_optic(@nospecialize(x::Any), ::AbstractPPL.Iden) = x
35-
@inline _getindex_optic(x::Any, o::AbstractPPL.AbstractOptic) = o(x)
36-
function _getindex_optic(vnt::VarNamedTuple, optic::AbstractPPL.Property{S}) where {S}
37-
return _getindex_optic(getindex(vnt.data, S), optic.child)
40+
41+
@inline _getindex_optic(@nospecialize(x::Any), ::AbstractPPL.Iden, orig_vn) = x
42+
@inline _getindex_optic(x::Any, o::AbstractPPL.AbstractOptic, orig_vn) = o(x)
43+
function _getindex_optic(
44+
vnt::VarNamedTuple, optic::AbstractPPL.Property{S}, orig_vn
45+
) where {S}
46+
if !haskey(vnt.data, S)
47+
throw(KeyError(orig_vn))
48+
end
49+
return _getindex_optic(getindex(vnt.data, S), optic.child, orig_vn)
3850
end
39-
function _getindex_optic(pa::PartialArray, optic::AbstractPPL.Index)
51+
function _getindex_optic(pa::PartialArray, optic::AbstractPPL.Index, orig_vn)
4052
coptic = AbstractPPL.concretize_top_level(optic, pa.data)
4153
child_value =
4254
if _is_multiindex(pa, coptic.ix...; coptic.kw...) &&
@@ -49,9 +61,9 @@ function _getindex_optic(pa::PartialArray, optic::AbstractPPL.Index)
4961
else
5062
getindex(pa, coptic.ix...; coptic.kw...)
5163
end
52-
return _getindex_optic(child_value, optic.child)
64+
return _getindex_optic(child_value, optic.child, orig_vn)
5365
end
54-
function _getindex_optic(arr::AbstractArray, optic::IndexWithoutChild)
66+
function _getindex_optic(arr::AbstractArray, optic::IndexWithoutChild, orig_vn)
5567
coptic = AbstractPPL.concretize_top_level(optic, arr)
5668
return Base.getindex(arr, coptic.ix...; coptic.kw...)
5769
end

test/logdensityfunction.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ end
348348
@test_throws ArgumentError to_vector_params(vecvals, ldf)
349349

350350
accs = OnlyAccsVarInfo(VectorParamAccumulator(ldf))
351-
@test_throws ErrorException init!!(
351+
@test_throws KeyError init!!(
352352
extra_model, accs, InitFromPrior(), transform_strategy
353353
)
354354
end

test/varinfo.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,19 @@ end
8787
vi, TransformedValue(x, NoTransform()), Normal(), vn, x
8888
)
8989
@test !isempty(vi)
90+
91+
@testset "KeyError for missing varname" begin
92+
@model function test_model()
93+
x ~ Normal()
94+
return nothing
95+
end
96+
vi2 = VarInfo(test_model())
97+
# KeyError propagates from VarNamedTuple through VarInfo
98+
@test_throws KeyError DynamicPPL.getindex_internal(vi2, @varname(y))
99+
@test_throws KeyError DynamicPPL.get_transformed_value(vi2, @varname(y))
100+
# Direct VarNamedTuple access also throws KeyError
101+
@test_throws KeyError vi2.values[@varname(y)]
102+
end
90103
end
91104

92105
@testset "get/set/acclogp" begin

test/varnamedtuple.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,18 @@ Base.size(st::SizedThing) = st.size
346346
end
347347
end
348348

349+
@testset "KeyError for missing properties" begin
350+
vnt = @vnt begin
351+
x.a := 1.0
352+
end
353+
# Should throw KeyError for missing top-level symbol
354+
@test_throws KeyError vnt[@varname(y)]
355+
# Should throw KeyError for missing nested property
356+
@test_throws KeyError vnt[@varname(x.b)]
357+
# Sanity check: accessing existing property should work
358+
@test vnt[@varname(x.a)] == 1.0
359+
end
360+
349361
@testset "haskey on PartialArray" begin
350362
@testset "no ALBs" begin
351363
vnt = @vnt begin

0 commit comments

Comments
 (0)