Add ComponentArraysExt to fix mixed index/property access on ComponentArrays (#1230) #1373
Fixes TuringLang#1230.
ComponentArray variables that mix index-based (x[1]) and property-based (x.a) access in the same model would crash with a MethodError in _setindex_optic!!.
Added ext/DynamicPPLComponentArraysExt.jl with two overloads:
- make_leaf for ComponentArray templates
- _setindex_optic!! for PartialArray{ComponentArray} + Property optic
Codecov Report
✅ All modified and coverable lines are covered by tests.
@@            Coverage Diff             @@
##             main    #1373      +/-   ##
==========================================
+ Coverage   82.28%   82.35%   +0.07%
==========================================
  Files          49       50       +1
  Lines        3516     3531      +15
==========================================
+ Hits         2893     2908      +15
  Misses        623      623
All required CI checks are now passing. The remaining failures are
Could a maintainer please review and merge? |
Your PR was only open for an hour, please be patient :) |
penelopeysm
left a comment
I haven't yet looked at the actual extension code, only the general outline. The reason is that this PR really needs tests, otherwise we cannot merge it. If you look in test/varnamedtuple.jl you should see tests for different array and index types; OffsetArrays is in there, for example, and you can copy that.
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ComponentArrays should be a weak dependency rather than a dependency.
That is, it should be under the [weakdeps] section rather than the [deps]. If you need a guide on extensions, you can find one at e.g. https://pkgdocs.julialang.org/dev/toml-files/#extensions-section
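As a rough sketch of what that move could look like in Project.toml: the UUID is the one from the diff in this thread and the extension name is the one this PR adds. The matching [compat] entry is left out here because the appropriate version range is not stated in this conversation.

```toml
[weakdeps]
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"

[extensions]
# Loaded automatically once both DynamicPPL and ComponentArrays are loaded
DynamicPPLComponentArraysExt = "ComponentArrays"
```

With this layout ComponentArrays no longer appears under [deps], which is the usual arrangement for a package extension.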
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
JuliaFormatter should not be a project dependency. There are some helpful instructions on formatting in https://turinglang.org/docs/contributing/code-formatting/ but the TL;DR is that you shouldn't add it to the project environment, it should go in your global environment.
using DynamicPPL

- Aqua.test_all(DynamicPPL)
+ Aqua.test_all(DynamicPPL; stale_deps=false)
You wouldn't need to do this if you fix the main issues in the Project.toml.
All feedback addressed. Ready for re-review.
penelopeysm
left a comment
Thank you! Can you please revert the formatting changes so that it's possible to see what the true diff is?
I think you are probably using the wrong version of JuliaFormatter. We use JuliaFormatter v1 not v2 (the docs page linked previously will have info on this).
02d8f33 to
20f3706
|
Reverted the unrelated formatting changes. Only formatted the files changed in this PR (ext/DynamicPPLComponentArraysExt.jl and test/varnamedtuple.jl) using JuliaFormatter v1. |
    template,
    permissions::SetPermissions=AllowAll(),
) where {S}
    ax = getaxes(pa.data)[1]
Here you're accessing the first axis. I think that is fine as long as you are only doing ComponentVectors (note that ComponentArray in general can have arbitrary numbers of dimensions). I don't mind if you want to restrict this PR to vectors and not consider all possible arrays. However, in that case, I would suggest:
- Restrict the type signature to ComponentVector
- Use only(getaxes(pa.data)), since if it's a vector we can be sure that there's only one of them
In fact, though, it would likely be better if you used ComponentArrays.label2index instead: https://docs.sciml.ai/ComponentArrays/stable/api/#ComponentArrays.label2index-Tuple{ComponentVector,%20Any}. It will probably still not work for multidimensional arrays, but it allows you to avoid the ax[S].idx on the next line, which is a bit scary because it accesses a field that may or may not exist in the future.
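To make the vector-only variant concrete, here is a minimal sketch that uses only ComponentArrays itself. The helper name property_to_index is hypothetical and belongs to neither package; the ax[label].idx access is the internal-field approach discussed in this comment, shown only because it is what the PR does at this point.

```julia
using ComponentArrays

# Hypothetical helper: resolve a property label to its first flat index.
# Restricted to ComponentVector so that `only` is safe (exactly one axis).
function property_to_index(x::ComponentVector, label::Symbol)
    ax = only(getaxes(x))
    # NOTE: `.idx` is an internal field of ComponentArrays and may change.
    return first(ax[label].idx)
end

ca = ComponentArray(a=1.0, b=2.0)
i = property_to_index(ca, :b)
# The label and the resolved integer index address the same element.
@assert ca[i] == ca.b
```

Restricting the signature to ComponentVector means a multidimensional ComponentArray fails with a MethodError at the call rather than silently using only its first axis.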
@testset "ComponentArrays" begin
    ca = CA.ComponentArray(; a=1.0, b=2.0)
    test_get_set(GetSetTestCase(@varname(x[1]), 1.0, ca, []))
    test_get_set(GetSetTestCase(@varname(x.a), 1.0, ca, []))
    test_get_set(GetSetTestCase(@varname(x.b), 2.0, ca, []))
end
Thanks for adding some tests!
While I think the implementation is quite sensible, I do feel that these tests need to be rather more thorough. Here you are very much still inside the happy path: for example, everything only points to one index, and you only test one element at a time. In fact it's quite likely that some of these tests would already have passed on current main. The interesting thing about ComponentArrays is that the property and indexing actually mean the same thing, so we need some tests that exercise this property.
I'd really like to see some tests for:
- setting slices of a ComponentArray (you can use GetSetTestCase just like this)
- setting an index and then setting a property on the same VarNamedTuple (notice that the failure case in ComponentArrays break #1230 involves setting both of these). In particular, it would be good to test an index and a label that point to the same thing (i.e. x[1] and x.a in the example here), to make sure that the second set overwrites the first set.
- setting an index and then retrieving it with a property, and vice versa.
- ComponentArrays that contain other things. Say, a ComponentArray that contains Arrays, or a ComponentArray that contains a NamedTuple. For example, if x is a ComponentArray and x.a is an array, are you able to set x.a[1], and are you able to set x[1][1] as well?
I recommend you try to find cases that cause your code to fail. It's fine if you find an edge case that you haven't handled: you can either fix it, or document it. That's still an improvement over the current, because currently the whole thing breaks! However, it's very important to know how confident we can be in the code. And we can't do that without being much more rigorous and exhaustive with the tests.
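One way to see what such tests need to pin down is to look at how a plain ComponentArray behaves before DynamicPPL is involved. The sketch below uses only ComponentArrays and assumes that property access on an array-valued component returns a view into the parent data, which is worth confirming against the ComponentArrays documentation.

```julia
using ComponentArrays

# `a` is array-valued and occupies flat indices 1:2; `b` is a scalar at index 3.
x = ComponentArray(a=[1.0, 2.0], b=3.0)

# Writing through the property is visible through the flat index...
x.a[1] = 10.0
@assert x[1] == 10.0

# ...and writing through the flat index is visible through the property.
x[2] = 20.0
@assert x.a[2] == 20.0

# Scalar field: the label and the index are two names for the same slot.
@assert x[3] == x.b
```

A thorough test suite for the extension would check that a VarNamedTuple built from such a template preserves these same equivalences.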
36af64d to
f805cd8
hardik-xi11
left a comment
Please see if you can work on the following.
Also, run the specific tests locally first (that will tell you if there are any issues).
    template,
    permissions::SetPermissions=AllowAll(),
) where {S}
    idx = only(label2index(template, S))
S should be a String like String(S)
|
Hello @penelopeysm, I spent most of last night working through this, and contributing to DynamicPPL has been one of the most exciting learning experiences for me so far. The way that VarNamedTuples, weak dependencies, and Julia's extension system come together has been genuinely fascinating.
For some context on what I've learned: the core fix involved overloading _setindex_optic!! in a weak dependency extension. The goal was to ensure that when a Property{S} optic (like x.a) is used on a ComponentVector, it converts the symbol label to an integer index using ComponentArrays.getaxes and ax[S].idx. It then delegates to the existing index-based dispatch, so that both x[1] and x.a work correctly on the same VarNamedTuple without causing a crash.
I'm not just here for the merge; I'm committed to continuing to contribute to the TuringLang ecosystem long term. Thanks for your patience and for all the detailed guidance throughout this process.
|
Thanks! Could you elaborate a bit about your choices here? I'm slightly confused as to why you used label2index as per my suggestion and then switched it back. I'm not fussed about whether you take my suggestion or not, but I'd like to know what happened or what went wrong.
On the testing side, this is better but I still think we can push it further. Could you add something like
ca = CA.ComponentArray(; a=1.0, b=2.0)
vnt = VarNamedTuple()
val = rand()
vnt = DynamicPPL.templated_setindex!!(vnt, val, @varname(x[1]), ca)
@test vnt[@varname(x[1])] == vnt[@varname(x.a)] == val
That's what I meant in my previous comment about
There's also
which looks something like this
ca = CA.ComponentArray(; a=1.0, b=2.0)
vnt = VarNamedTuple()
val = rand()
vnt = DynamicPPL.templated_setindex!!(vnt, val, @varname(x[1]), ca)
@test vnt[@varname(x[1])] == vnt[@varname(x.a)] == val
val2 = rand()
vnt = DynamicPPL.templated_setindex!!(vnt, val2, @varname(x.a), ca)
@test vnt[@varname(x[1])] == vnt[@varname(x.a)] == val2
You can't use GetSetTestCase on these because that only ever creates a single VarNamedTuple with a single field.
Summary of Changes
I added an overload for _getindex_optic alongside the existing _setindex_optic overload. After doing so I ran the cross-access tests, but they were failing because the get path also needs the same Property -> index conversion logic as the set path. To fix this, both overloads now use ComponentArrays.getaxes(pa.data)[1] and first(ax[S].idx) to resolve the label.
When testing with label2index, I tried various combinations such as label2index(template, S), label2index(template, String(S)), and label2index(pa.data, S), but all of these resulted in runtime errors. Ultimately, the combination of getaxes + ax[S].idx worked as expected. I agree that the current method accesses an internal field, and I'm happy to switch to a different approach if you can point me to the correct call site.
Outcome
All 50,178 tests now pass locally, including the two new cross-access tests, so the change seems to be stable.
penelopeysm
left a comment
Thanks for explaining!
I'm aware I keep harping on about this, but I think there is still some way to go in improving the test coverage. I tried this PR and I ran into this:
julia> @model function g()
           x = ComponentArray(a = 1.0)
           x.a ~ Normal()
       end
julia> rand(g())
VarNamedTuple
└─ x => VarNamedTuple
   └─ a => 0.2547158758482418
And this suggests that the template is not actually being used at all. Now if you try to index into this with x.a it will be fine, but if you try to index with x[1] it will not be fine.
At a lower level, this is the same thing as
julia> using DynamicPPL: templated_setindex!!, VarNamedTuple
julia> vnt = VarNamedTuple()
VarNamedTuple()
julia> ca = ComponentArray(a = 1.0)
ComponentVector{Float64}(a = 1.0)
julia> templated_setindex!!(vnt, 1.0, @varname(x.a), ca)
VarNamedTuple
└─ x => VarNamedTuple
   └─ a => 1.0
If you look at the existing test you added:
val = rand()
vnt = VarNamedTuple()
vnt = DynamicPPL.templated_setindex!!(vnt, val, @varname(x[1]), ca)
@test vnt[@varname(x[1])] == vnt[@varname(x.a)] == val
val2 = rand()
vnt = DynamicPPL.templated_setindex!!(vnt, val2, @varname(x.a), ca)
@test vnt[@varname(x[1])] == vnt[@varname(x.a)] == val2
Notice that in this sequence of events you set x[1] first before setting x.a. In that case, when you set x[1] it will indeed use the template and all is well. However, if you were to tweak this slightly and set x.a first, it will break:
val = rand()
vnt = VarNamedTuple()
vnt = DynamicPPL.templated_setindex!!(vnt, val, @varname(x.a), ca)
@test vnt[@varname(x[1])] == vnt[@varname(x.a)] == val
#=
Test threw exception
Expression: vnt[#= REPL[25]:1 =# @varname(x[1])] == vnt[#= REPL[25]:1 =# @varname(x.a)] == val
MethodError: no method matching getindex(::VarNamedTuple{(:a,), Tuple{Float64}}, ::Int64)
The function `getindex` exists, but no method is defined for this combination of argument types.
=#
As I said in my previous comment, I really suggest approaching the unit tests in as adversarial a manner as possible. Write down all possible combinations of gets and sets that you can think of, and make sure that they all work exactly the way you expect them to! |
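One possible way to enumerate those combinations mechanically is sketched below. It assumes a DynamicPPL build that includes this PR's extension and relies on templated_setindex!!, which is an internal function taken from the snippets in this thread rather than public API. The testset names and the aliases tuple are illustrative only.

```julia
using Test
using DynamicPPL: DynamicPPL, VarNamedTuple, @varname
import ComponentArrays as CA

ca = CA.ComponentArray(; a=1.0, b=2.0)
# Two names for the same slot of `ca`.
aliases = (@varname(x[1]), @varname(x.a))

# Every (set, get) pairing, including property-first access.
@testset "set via $set_vn, get via $get_vn" for set_vn in aliases, get_vn in aliases
    val = rand()
    vnt = DynamicPPL.templated_setindex!!(VarNamedTuple(), val, set_vn, ca)
    @test vnt[get_vn] == val
end

# Every (first set, second set) pairing: the second write should win
# regardless of which name was used for either write.
@testset "overwrite $first_vn with $second_vn" for first_vn in aliases, second_vn in aliases
    vnt = DynamicPPL.templated_setindex!!(VarNamedTuple(), rand(), first_vn, ca)
    val2 = rand()
    vnt = DynamicPPL.templated_setindex!!(vnt, val2, second_vn, ca)
    @test all(vnt[vn] == val2 for vn in aliases)
end
```

Looping over the Cartesian product keeps the property-first ordering from being overlooked, which is the case that fails in the example above.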
|
Finally, once you fix that, I'd like for another aspect of testing that we haven't yet discussed, namely the
Start by setting up a VNT:
julia> ca = ComponentArray(a = 1.0); vnt = VarNamedTuple(); val = rand()
0.9536133882970798
julia> vnt = templated_setindex!!(vnt, val, @varname(x[1]), ca)
VarNamedTuple
└─ x => PartialArray size=(1,) data::ComponentVector{Float64, Vector{Float64}, Tuple{Axis{(a = 1,)}}}
   └─ (1,) => 0.9536133882970798
Then try to set the value again but with the
julia> vnt = DynamicPPL.VarNamedTuples.templated_setindex_no_overwrite!!(vnt, val, @varname(x[1]), ca)
ERROR: MustNotOverwriteError: Attempted to set a value for x[1], but a value already existed. This indicates that a value is being set twice (e.g. if the same variable occurs in a model twice).
[...]
Now, this currently works correctly (we do expect an error to be thrown in this case: this is a mechanism to stop people from accidentally using the same variable twice). But that's only because I used
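A sketch of how the no-overwrite check could be exercised for every pairing of the two names follows. It assumes a DynamicPPL build with this PR's extension, uses the internal functions exactly as they appear in the REPL session above, and matches on the error message text shown there rather than on the exception type, since the module that defines MustNotOverwriteError is not stated in this thread. String matching in @test_throws needs Julia 1.8 or later.

```julia
using Test
using DynamicPPL: DynamicPPL, VarNamedTuple, @varname
using DynamicPPL.VarNamedTuples: templated_setindex_no_overwrite!!
import ComponentArrays as CA

ca = CA.ComponentArray(; a=1.0)
# Two names for the same slot of `ca`.
aliases = (@varname(x[1]), @varname(x.a))

@testset "no overwrite: $first_vn then $second_vn" for first_vn in aliases, second_vn in aliases
    vnt = DynamicPPL.templated_setindex!!(VarNamedTuple(), rand(), first_vn, ca)
    # The slot is already occupied, so a second no-overwrite set is expected
    # to throw whichever name is used for it.
    @test_throws "a value already existed" templated_setindex_no_overwrite!!(
        vnt, rand(), second_vn, ca
    )
end
```

The interesting pairings are the two mixed ones, where the second set uses a different name from the first but targets the same element.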
|
@penelopeysm With this update all four key access scenarios now work correctly
Additionally all error cases now trigger correctly ensuring proper handling for attempts like these
I’ve also run a full suite of 50,178 tests locally and everything passed successfully. Thanks again for your help and patience during this process. |
Great, but could you add those tests to the package's test suite? |
penelopeysm
left a comment
Please add the tests :)
The extension itself looks good to me.
ax = ComponentArrays.getaxes(template)[1]
idx = first(ax[S].idx)
index_optic = AbstractPPL.Index((idx,), NamedTuple(), optic.child)
These three lines are repeated across the methods for make_leaf, setindex_optic and getindex_optic. Consider making a helper function for it to avoid the duplication?
Extracted as a _property_to_index helper function.
    index_optic = AbstractPPL.Index((idx,), NamedTuple(), optic.child)
    return make_leaf(value, index_optic, template)
else
    return invoke(
I think this code looks good to me. It would probably be useful to add a comment or two explaining the call to invoke() because it's not obvious at first glance why this is needed.
Added a comment explaining why invoke() is needed to avoid infinite recursion.
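For readers unfamiliar with the pattern, here is a small self-contained illustration of why invoke is needed in this situation. The function handle and its two methods are hypothetical stand-ins, not DynamicPPL code: the Real method plays the role of the existing generic method and the Int method plays the role of the extension's more specific overload.

```julia
# Generic fallback, standing in for the existing AbstractArray method.
handle(x::Real) = "generic Real path"

# More specific overload, standing in for the extension's method.
function handle(x::Int)
    if x < 0
        return "special-cased negative Int"
    else
        # A plain `handle(x)` here would dispatch straight back to this
        # method and recurse forever. `invoke` selects the less specific
        # `handle(::Real)` method explicitly.
        return invoke(handle, Tuple{Real}, x)
    end
end

handle(-1)  # takes the special branch
handle(3)   # falls through to the generic method via invoke
```

The second argument to invoke is the tuple of argument types identifying the method to call, so the call is explicit about which fallback it means.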
Done. Added tests for property-first access, cross access (set by index/get by property and vice versa), and all four MustNotOverwrite combinations in test/varnamedtuple.jl.
@penelopeysm make_leaf for Property{S} on ComponentVector: converts the property label to an integer index so that the PartialArray is created correctly when property access comes first. A shared _property_to_index helper avoids duplication across all three. invoke() is used in make_leaf to explicitly call the AbstractArray method and avoid infinite recursion with our own overload. Tests added: cross access (set by index/get by property and vice versa), property-first access, slice setting, array-valued fields, and all four MustNotOverwrite combinations.
penelopeysm
left a comment
Thank you! I'm going to add some more tests and then merge.
Thank you @penelopeysm for the thorough reviews and patience throughout this process. The detailed feedback at every step really helped me understand the codebase deeply. Looking forward to contributing more. |
Fixes #1230
What was the problem?
When using a ComponentArray as a random variable in a model, mixing index-based access (x[1]) and property-based access (x.b) in the same model would crash with a MethodError:
Why did it crash?
When x[1] ~ Normal() is evaluated first, DynamicPPL stores x internally as a PartialArray backed by a ComponentVector. When x.b ~ Normal() is evaluated next, DynamicPPL tries to call _setindex_optic!! with a Property optic (.b) on that PartialArray, but no method existed for this combination. Only Index optics were supported on a PartialArray.
Additionally, make_leaf was not routing ComponentArray into the correct AbstractArray code path due to Julia's method dispatch rules.
How did I fix it?
Added ext/DynamicPPLComponentArraysExt.jl with two method overloads:
1. make_leaf for ComponentArray templates
Explicitly routes ComponentArray into the existing AbstractArray path so that PartialArray is created correctly with axes preserved.
2. _setindex_optic!! for PartialArray{ComponentArray} + Property optic
Converts the named property (e.g. .b) to its integer index using ComponentArrays' axis system (ax[S].idx), then delegates to the existing Index-based method.
Also updated Project.toml to add ComponentArrays to [weakdeps], [extensions], and [compat].
Result
x[1] ~ Normal()
x.a ~ Normal()
x[1] ~ Normal(); x.b ~ Normal()