Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
3f871cc
Add ComponentArraysExt to fix mixed index/property access
ArpanC6 Apr 27, 2026
7bbe5ba
Fix formatting
ArpanC6 Apr 27, 2026
226c065
Add ComponentArrays to test/Project.toml
ArpanC6 Apr 27, 2026
75b6f62
Add ComponentArrays to [extras] in Project.toml
ArpanC6 Apr 27, 2026
c23056d
fix: disable Aqua stale_deps check for ComponentArrays weakdep
ArpanC6 Apr 27, 2026
20f3706
fix: move ComponentArrays to weakdeps, add tests, revert Aqua change
ArpanC6 Apr 27, 2026
d9dac33
style: format ComponentArrays test with JuliaFormatter v1
ArpanC6 Apr 27, 2026
f805cd8
fix: remove make_leaf overload, use label2index, expand tests
ArpanC6 Apr 27, 2026
fc0e846
test: expand ComponentArrays tests with cross-access and array-valued…
ArpanC6 Apr 27, 2026
5317aa4
style: remove trailing newline
ArpanC6 Apr 27, 2026
18424f1
fix: use label2index and ComponentVector, remove make_leaf overload
ArpanC6 Apr 27, 2026
d8ed680
fix: use label2index and ComponentVector, remove make_leaf overload
ArpanC6 Apr 27, 2026
180e649
fix: restore ComponentArrays to weakdeps
ArpanC6 Apr 27, 2026
799b593
fix: add ComponentArrays to compat
ArpanC6 Apr 27, 2026
3799736
style: format ComponentArraysExt with JuliaFormatter v1
ArpanC6 Apr 27, 2026
61b66ca
fix: use parent(template) and String(S) in label2index call
ArpanC6 Apr 27, 2026
f617efc
fix: remove stray entries from .gitignore
ArpanC6 Apr 27, 2026
9631c58
fix: use S directly in label2index instead of String(S)
ArpanC6 Apr 27, 2026
7ed60b4
fix: use template directly in label2index, remove parent()
ArpanC6 Apr 27, 2026
1b26bed
fix: use pa.data in label2index call
ArpanC6 Apr 27, 2026
02adfc1
fix: use ComponentArrays.getaxes and ax[S].idx for property lookup
ArpanC6 Apr 27, 2026
4857977
style: add trailing newline
ArpanC6 Apr 27, 2026
28048c3
fix: use first(ax[S].idx) to extract integer from range
ArpanC6 Apr 27, 2026
84990d8
fix: remove invalid cross-access tests
ArpanC6 Apr 28, 2026
e261e53
fix: add _getindex_optic overload for Property optic on ComponentVect…
ArpanC6 Apr 28, 2026
8766066
style: reformat with JuliaFormatter v1
ArpanC6 Apr 28, 2026
159513b
fix: handle Property optic in make_leaf for ComponentVector
ArpanC6 Apr 29, 2026
7ebae5b
style: reformat with BlueStyle
ArpanC6 Apr 29, 2026
fd6392c
refactor: extract helper function, add property-first and MustNotOver…
ArpanC6 Apr 30, 2026
5c5fb92
style: format ext with JuliaFormatter v1
ArpanC6 Apr 30, 2026
6f12d3b
Merge remote-tracking branch 'origin/main' into fix/componentarrays-ext
penelopeysm May 1, 2026
66004da
Bump patch
penelopeysm May 1, 2026
f9849d9
Use loops in tests
penelopeysm May 1, 2026
841ffde
Add more tests
penelopeysm May 1, 2026
0fbc196
test skeleton as well
penelopeysm May 1, 2026
5e3a1ca
Merge remote-tracking branch 'origin/main' into fix/componentarrays-ext
penelopeysm May 1, 2026
28a144f
Fix bad merge
penelopeysm May 1, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Comment thread
penelopeysm marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ Manifest.toml

benchmarks/*.json
LocalPreferences.toml

2 changes: 2 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# 0.41.7

Enable usage of `ComponentVector`s on the left-hand side of tilde-statements.

Accessing a nonexistent variable in a `VarNamedTuple` now throws a `KeyError` with the original `VarName`, instead of an opaque `type NamedTuple has no field ...` error.

# 0.41.6
Expand Down
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[weakdeps]
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
Expand All @@ -39,6 +40,7 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"

[extensions]
DynamicPPLComponentArraysExt = ["ComponentArrays"]
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
DynamicPPLForwardDiffExt = ["ForwardDiff"]
DynamicPPLMCMCChainsExt = ["MCMCChains"]
Expand All @@ -55,6 +57,7 @@ BangBang = "0.4.1"
Bijectors = "0.15.17"
Chairmarks = "1.3.1"
Compat = "4"
ComponentArrays = "0.15"
ConstructionBase = "1.5.4"
DifferentiationInterface = "0.6.41, 0.7"
Distributions = "0.25"
Expand Down
64 changes: 64 additions & 0 deletions ext/DynamicPPLComponentArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
module DynamicPPLComponentArraysExt
using DynamicPPL: DynamicPPL
using DynamicPPL.VarNamedTuples:
PartialArray,
AllowAll,
SetPermissions,
_setindex_optic!!,
_getindex_optic,
make_leaf,
make_leaf_singleindex,
_is_multiindex,
make_leaf_multiindex
using ComponentArrays: ComponentArrays, ComponentArray, ComponentVector
using AbstractPPL

# Helper: convert a Property optic label S to an integer Index optic
function _property_to_index(
template::ComponentVector, optic::AbstractPPL.Property{S}
) where {S}
ax = ComponentArrays.getaxes(template)[1]
idx = first(ax[S].idx)
return AbstractPPL.Index((idx,), NamedTuple(), optic.child)
end

function DynamicPPL.VarNamedTuples.make_leaf(
value, optic::AbstractPPL.Property{S}, template::ComponentVector
) where {S}
return if optic.child isa AbstractPPL.Iden
index_optic = _property_to_index(template, optic)
make_leaf(value, index_optic, template)
else
# This branch is needed to handle nested axes in ComponentArrays: the idea is that
# if x is e.g. ComponentArray(a=(b=1)) and we are trying to set `x.a.b`, then we
# first index into `x.a` to get the slice of the ComponentArray. The easiest way to
# handle this is to call the default method.
invoke(
make_leaf,
Tuple{Any,AbstractPPL.Property{S},AbstractArray},
value,
optic,
template,
)
end
end

function DynamicPPL.VarNamedTuples._setindex_optic!!(
pa::PartialArray{<:Any,<:Any,<:ComponentVector},
value,
optic::AbstractPPL.Property{S},
template,
permissions::SetPermissions=AllowAll(),
) where {S}
index_optic = _property_to_index(pa.data, optic)
return _setindex_optic!!(pa, value, index_optic, template, permissions)
Comment thread
penelopeysm marked this conversation as resolved.
end
Comment thread
penelopeysm marked this conversation as resolved.

function DynamicPPL.VarNamedTuples._getindex_optic(
pa::PartialArray{<:Any,<:Any,<:ComponentVector}, optic::AbstractPPL.Property{S}, orig_vn
) where {S}
index_optic = _property_to_index(pa.data, optic)
return _getindex_optic(pa, index_optic, orig_vn)
end

end
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,14 @@ Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Comment thread
penelopeysm marked this conversation as resolved.
[compat]
ADTypes = "1"
AbstractMCMC = "5.10"
AbstractPPL = "0.14"
Accessors = "0.1"
Aqua = "0.8"
ComponentArrays = "0.15"
BangBang = "0.4"
Bijectors = "0.15.17"
Chairmarks = "1"
Expand Down
71 changes: 71 additions & 0 deletions test/varnamedtuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ using BangBang: setindex!!, empty!!
using DimensionalData: DimensionalData as DD
using InvertedIndices: InvertedIndices as II
using OffsetArrays: OffsetArrays as OA
using ComponentArrays: ComponentArrays as CA

struct GetSetTestCase
# The VarName being set.
Expand Down Expand Up @@ -309,6 +310,60 @@ Base.size(st::SizedThing) = st.size
)
end

@testset "ComponentArray" begin
ca = CA.ComponentArray(; a=1.0, b=2.0)
test_get_set(GetSetTestCase(@varname(x[1]), 1.0, ca, []))
test_get_set(GetSetTestCase(@varname(x[2]), 2.0, ca, []))
test_get_set(GetSetTestCase(@varname(x.a), 1.0, ca, []))
test_get_set(GetSetTestCase(@varname(x.b), 2.0, ca, []))
test_get_set(GetSetTestCase(@varname(x[1:2]), [1.0, 2.0], ca, []))

# ComponentVector with array-valued fields
ca3 = CA.ComponentArray(; a=[1.0, 2.0], b=[3.0, 4.0])
test_get_set(GetSetTestCase(@varname(x.a), [1.0, 2.0], ca3, []))
test_get_set(GetSetTestCase(@varname(x.b), [3.0, 4.0], ca3, []))
test_get_set(GetSetTestCase(@varname(x.a[1]), 1.0, ca3, []))

# with nested fields
ca4 = CA.ComponentArray(; a=(; x=1.0, y=2.0))
test_get_set(GetSetTestCase(@varname(x.a.x), 10.0, ca4, []))
test_get_set(GetSetTestCase(@varname(x.a.y), 20.0, ca4, []))
test_get_set(GetSetTestCase(@varname(x[1]), 10.0, ca4, []))
test_get_set(GetSetTestCase(@varname(x[2]), 20.0, ca4, []))

# Mixed index/property access
val = rand()
vns = (@varname(x[1]), @varname(x.a))
for set_vn in vns
vnt = DynamicPPL.templated_setindex!!(VarNamedTuple(), val, set_vn, ca)
for get_vn in vns
@test vnt[get_vn] == val
end
end

# Check that setting one and overwriting with the other works
val = rand()
new_val = val + 1
for (vn1, vn2) in
((@varname(x[1]), @varname(x.a)), (@varname(x.a), @varname(x[1])))
vnt = VarNamedTuple()
vnt = DynamicPPL.templated_setindex!!(vnt, val, vn1, ca)
@test vnt[vn1] == vnt[vn2] == val # Sanity check.
vnt = DynamicPPL.templated_setindex!!(vnt, new_val, vn2, ca)
@test vnt[vn1] == vnt[vn2] == new_val
end

# Check that MustNotOverwrite is respected.
for vn1 in vns
vnt = DynamicPPL.templated_setindex!!(VarNamedTuple(), val, vn1, ca)
for vn2 in vns
@test_throws MustNotOverwriteError DynamicPPL.VarNamedTuples.templated_setindex_no_overwrite!!(
vnt, new_val, vn2, ca
)
end
end
end

@testset "InvertedIndices" begin
# TODO(penelopeysm): Templated setindex fails for II.Not(). I really don't know
# why but there is some failure in constant propagation when setting the mask
Expand Down Expand Up @@ -2029,6 +2084,15 @@ Base.size(st::SizedThing) = st.size
x[2:3] := SizedThing((2,))
end
@test densify!!(vnt) == vnt

# Check with ComponentArrays
x = CA.ComponentArray(; a=0.0, b=0.0)
vnt = @vnt begin
@template x
x.a := 1.0
x.b := 2.0
end
@test densify!!(vnt) == VarNamedTuple(; x=CA.ComponentArray(; a=1.0, b=2.0))
end

@testset "skeleton" begin
Expand Down Expand Up @@ -2147,6 +2211,13 @@ Base.size(st::SizedThing) = st.size
end
v12s = VarNamedTuple(; x=DD.DimArray(fill(nothing, 2, 3), (:a, :b)))
test_skeleton(v12, v12s)

v13 = @vnt begin
@template x = CA.ComponentArray(; a=0.0, b=0.0)
x.a := 1.0
end
v13s = VarNamedTuple(; x=CA.ComponentArray(; a=nothing, b=nothing))
test_skeleton(v13, v13s)
end
end

Expand Down
Loading