Required prerequisites
What version of TorchOpt are you using?
0.7.3
System information
How TorchOpt was installed: pip install torchopt
Python: 3.10.12 (main, Jul 29 2024, 16:56:48) [GCC 11.4.0] on linux
TorchOpt: 0.7.3
PyTorch: 2.5.0a0+872d972e41.nv24.08
functorch: 2.5.0a0+872d972e41.nv24.08
Problem description
When using the functional API with the adamw optimizer and the mask parameter specified, the expectation is that the update is applied with weight decay skipped for the masked parameters. Instead, the update fails with AttributeError: 'MaskedNode' object has no attribute 'add'.
The docstring for MaskedNode states: "This node is ignored when mapping functions across the tree, e.g. using :func:`pytree.tree_map`, since it is a container without children. It can therefore be used to mask out parts of a tree." However, this does not appear to be the case.
Reproducible example code
The Python snippets:
Command lines:
Extra dependencies:
Steps to reproduce:
Traceback
  File "torchopt_test.py", line 230, in <module>
    functorch_original.test_train_step_fn(weights, opt_state, points, labels)
  File "torchopt_test.py", line 160, in test_train_step_fn
    loss, (weights, opt_state) = self.train_step_fn((weights, opt_state), points, labels)
  File "torchopt_test.py", line 154, in train_step_fn
    updates, new_opt_state = optimizer.update(grads, opt_state, params=weights, inplace=False)
  File "/usr/local/lib/python3.10/dist-packages/torchopt/combine.py", line 92, in update_fn
    flat_updates, state = inner.update(flat_updates, state, params=flat_params, inplace=inplace)
  File "/usr/local/lib/python3.10/dist-packages/torchopt/base.py", line 196, in update_fn
    updates, new_s = fn(updates, s, params=params, inplace=inplace)
  File "/usr/local/lib/python3.10/dist-packages/torchopt/transform/add_decayed_weights.py", line 132, in update_fn
    new_masked_updates, new_inner_state = inner.update(
  File "/usr/local/lib/python3.10/dist-packages/torchopt/transform/add_decayed_weights.py", line 243, in update_fn
    updates = tree_map(f, params, updates)
  File "/usr/local/lib/python3.10/dist-packages/torchopt/transform/utils.py", line 65, in tree_map_flat
    return flat_arg.__class__(map(fn, flat_arg, *flat_args))  # type: ignore[call-arg]
  File "/usr/local/lib/python3.10/dist-packages/torchopt/transform/utils.py", line 63, in fn
    return func(x, *xs) if x is not None else None
  File "/usr/local/lib/python3.10/dist-packages/torchopt/transform/add_decayed_weights.py", line 241, in f
    return g.add(p, alpha=weight_decay) if g is not None else g
AttributeError: 'MaskedNode' object has no attribute 'add'
Expected behavior
The expectation is that when a mask is supplied to adamw, the update succeeds and weight decay is skipped for the masked parameters.
Additional context
No response