Skip to content

Commit a9bef21

Browse files
committed
feat(functools): add Placeholder support and tests for functools.partial
Re-export `functools.Placeholder` (Python 3.14+) from `optree.functools` and add comprehensive tests verifying that `optree.functools.partial` works consistently with stdlib's `functools.partial` when Placeholders are used in positional arguments.
1 parent bd32248 commit a9bef21

3 files changed

Lines changed: 271 additions & 0 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1616
- Add `attrs` integration module `optree.integrations.attrs` with `field`, `define`, `frozen`, `mutable`, `make_class`, `register_node`, and `AttrsEntry` by [@XuehaiPan](https://github.com/XuehaiPan) in [#273](https://github.com/metaopt/optree/pull/273).
1717
- Add `optree.dataclasses.register_node` to register existing dataclasses as pytree nodes by [@XuehaiPan](https://github.com/XuehaiPan) in [#273](https://github.com/metaopt/optree/pull/273).
1818
- Extend `GetAttrEntry` to support dotted attribute paths for traversing nested attributes (e.g., `a.b.c`) by [@XuehaiPan](https://github.com/XuehaiPan).
19+
- Add `functools.Placeholder` support and re-export for `optree.functools.partial` (Python 3.14+) by [@XuehaiPan](https://github.com/XuehaiPan) in [#276](https://github.com/metaopt/optree/pull/276).
1920

2021
### Changed
2122

optree/functools.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from __future__ import annotations
1818

1919
import functools
20+
import sys
2021
from typing import TYPE_CHECKING, Any, Callable, ClassVar
2122
from typing_extensions import Self # Python 3.11+
2223

@@ -29,12 +30,18 @@
2930
if TYPE_CHECKING:
3031
from optree.accessors import PyTreeEntry
3132

33+
if sys.version_info >= (3, 14):
34+
from functools import Placeholder # pylint: disable=unused-import
35+
3236

3337
__all__ = [
3438
'partial',
3539
'reduce',
3640
]
3741

42+
if sys.version_info >= (3, 14):
43+
__all__ += ['Placeholder']
44+
3845

3946
class _HashablePartialShim:
4047
"""A shim object that delegates :meth:`__call__`, :meth:`__eq__`, and :meth:`__hash__` to a :func:`functools.partial` object.""" # pylint: disable=line-too-long
@@ -111,6 +118,17 @@ class partial( # noqa: N801 # pylint: disable=invalid-name,too-few-public-metho
111118
112119
Had we passed :func:`operator.add` to ``call_func_on_cuda`` directly, it would have resulted in
113120
a :class:`TypeError` or :class:`AttributeError`.
121+
122+
On Python 3.14+, :data:`functools.Placeholder` can be used to reserve positional argument slots:
123+
124+
>>> from optree.functools import partial, Placeholder # doctest: +SKIP
125+
>>> import operator
126+
>>> sub_from = partial(operator.sub, Placeholder, 3) # doctest: +SKIP
127+
>>> sub_from(10) # doctest: +SKIP
128+
7
129+
130+
:data:`~functools.Placeholder` objects are treated as leaves in the pytree and their identity is
131+
preserved through flatten/unflatten round-trips.
114132
"""
115133

116134
__slots__: ClassVar[tuple[()]] = ()

tests/test_functools.py

Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
# pylint: disable=missing-function-docstring,invalid-name
1717

1818
import functools
19+
import sys
20+
21+
import pytest
1922

2023
import optree
2124
from helpers import GLOBAL_NAMESPACE, parametrize
@@ -80,3 +83,252 @@ def test_partial_func_attribute_has_stable_hash():
8083
assert fn == p1.func # pylint: disable=comparison-with-callable
8184
assert p1.func == p2.func
8285
assert hash(p1.func) == hash(p2.func)
86+
87+
88+
def test_partial_placeholder_roundtrip():
89+
if sys.version_info < (3, 14):
90+
pytest.skip('Placeholder requires Python 3.14+')
91+
92+
ph = functools.Placeholder
93+
94+
def f(*args, **kwargs):
95+
return args, kwargs
96+
97+
p1 = optree.functools.partial(f, ph, 42)
98+
leaves, treespec = optree.tree_flatten(p1)
99+
p2 = optree.tree_unflatten(treespec, leaves)
100+
assert p2.func == p1.func
101+
assert p2.args == p1.args
102+
assert p2.args[0] is ph
103+
assert p2.keywords == p1.keywords
104+
assert p2('x') == f('x', 42)
105+
106+
107+
def test_partial_placeholder_call_after_roundtrip():
108+
if sys.version_info < (3, 14):
109+
pytest.skip('Placeholder requires Python 3.14+')
110+
111+
ph = functools.Placeholder
112+
113+
def f(*args, **kwargs):
114+
return args, kwargs
115+
116+
p1 = optree.functools.partial(f, ph, 42)
117+
leaves, treespec = optree.tree_flatten(p1)
118+
p2 = optree.tree_unflatten(treespec, leaves)
119+
120+
# Fill placeholder
121+
assert p2('x') == (('x', 42), {})
122+
123+
# Extra args beyond placeholder
124+
assert p2('x', 'y') == (('x', 42, 'y'), {})
125+
126+
# Missing placeholder arg
127+
with pytest.raises(TypeError, match='missing positional arguments'):
128+
p2()
129+
130+
131+
def test_partial_multiple_placeholders_roundtrip():
132+
if sys.version_info < (3, 14):
133+
pytest.skip('Placeholder requires Python 3.14+')
134+
135+
ph = functools.Placeholder
136+
137+
def f(*args, **kwargs):
138+
return args, kwargs
139+
140+
p1 = optree.functools.partial(f, ph, 42, ph, 99)
141+
leaves, treespec = optree.tree_flatten(p1)
142+
p2 = optree.tree_unflatten(treespec, leaves)
143+
assert p2.args == (ph, 42, ph, 99)
144+
assert p2.args[0] is ph
145+
assert p2.args[2] is ph
146+
assert p2('a', 'b') == (('a', 42, 'b', 99), {})
147+
148+
149+
def test_partial_placeholder_with_keywords():
150+
if sys.version_info < (3, 14):
151+
pytest.skip('Placeholder requires Python 3.14+')
152+
153+
ph = functools.Placeholder
154+
155+
def f(*args, **kwargs):
156+
return args, kwargs
157+
158+
p1 = optree.functools.partial(f, ph, 42, key='value')
159+
leaves, treespec = optree.tree_flatten(p1)
160+
p2 = optree.tree_unflatten(treespec, leaves)
161+
assert p2.args == (ph, 42)
162+
assert p2.keywords == {'key': 'value'}
163+
assert p2('x') == (('x', 42), {'key': 'value'})
164+
165+
166+
def test_partial_placeholder_is_leaf():
167+
if sys.version_info < (3, 14):
168+
pytest.skip('Placeholder requires Python 3.14+')
169+
170+
ph = functools.Placeholder
171+
172+
def f(*args, **kwargs):
173+
return args, kwargs
174+
175+
p = optree.functools.partial(f, ph, 42)
176+
leaves = optree.tree_leaves(p)
177+
assert ph in leaves
178+
assert 42 in leaves
179+
180+
181+
def test_partial_placeholder_tree_map():
182+
if sys.version_info < (3, 14):
183+
pytest.skip('Placeholder requires Python 3.14+')
184+
185+
ph = functools.Placeholder
186+
187+
def f(*args, **kwargs):
188+
return args, kwargs
189+
190+
p1 = optree.functools.partial(f, ph, 42)
191+
192+
# Identity tree_map preserves Placeholder
193+
p2 = optree.tree_map(lambda x: x, p1)
194+
assert p2.args[0] is ph
195+
assert p2.args[1] == 42
196+
assert p2('test') == (('test', 42), {})
197+
198+
199+
def test_partial_placeholder_in_larger_tree():
200+
if sys.version_info < (3, 14):
201+
pytest.skip('Placeholder requires Python 3.14+')
202+
203+
ph = functools.Placeholder
204+
205+
def f(*args, **kwargs):
206+
return args, kwargs
207+
208+
p = optree.functools.partial(f, ph, 42)
209+
tree = {'fn': p, 'data': [1, 2, 3]}
210+
leaves, treespec = optree.tree_flatten(tree)
211+
tree2 = optree.tree_unflatten(treespec, leaves)
212+
assert tree2['fn'].args[0] is ph
213+
assert tree2['fn']('test') == (('test', 42), {})
214+
assert tree2['data'] == [1, 2, 3]
215+
216+
217+
def test_partial_wrapping_stdlib_partial_with_placeholder():
218+
if sys.version_info < (3, 14):
219+
pytest.skip('Placeholder requires Python 3.14+')
220+
221+
ph = functools.Placeholder
222+
223+
def f(*args, **kwargs):
224+
return args, kwargs
225+
226+
stdlib_p = functools.partial(f, ph, 42)
227+
op1 = optree.functools.partial(stdlib_p, 'extra')
228+
229+
# Anti-merge: outer args are separate
230+
assert op1.args == ('extra',)
231+
assert op1() == (('extra', 42), {})
232+
233+
# Roundtrip
234+
leaves, treespec = optree.tree_flatten(op1)
235+
op2 = optree.tree_unflatten(treespec, leaves)
236+
assert op2.args == ('extra',)
237+
assert op2() == (('extra', 42), {})
238+
239+
240+
def test_partial_wrapping_stdlib_partial_with_placeholder_no_extra_args():
241+
if sys.version_info < (3, 14):
242+
pytest.skip('Placeholder requires Python 3.14+')
243+
244+
ph = functools.Placeholder
245+
246+
def f(*args, **kwargs):
247+
return args, kwargs
248+
249+
stdlib_p = functools.partial(f, ph, 42)
250+
op1 = optree.functools.partial(stdlib_p)
251+
assert op1.args == ()
252+
assert op1('hello') == (('hello', 42), {})
253+
254+
# Roundtrip
255+
leaves, treespec = optree.tree_flatten(op1)
256+
op2 = optree.tree_unflatten(treespec, leaves)
257+
assert op2('hello') == (('hello', 42), {})
258+
259+
260+
def test_partial_nested_optree_partial_with_placeholder():
261+
if sys.version_info < (3, 14):
262+
pytest.skip('Placeholder requires Python 3.14+')
263+
264+
ph = functools.Placeholder
265+
266+
def f(*args, **kwargs):
267+
return args, kwargs
268+
269+
inner = optree.functools.partial(f, ph, 42)
270+
outer = optree.functools.partial(inner, 'extra')
271+
272+
# Anti-merge behavior
273+
assert outer.args == ('extra',)
274+
assert outer() == (('extra', 42), {})
275+
276+
# Roundtrip of outer
277+
leaves, treespec = optree.tree_flatten(outer)
278+
outer2 = optree.tree_unflatten(treespec, leaves)
279+
assert outer2() == (('extra', 42), {})
280+
281+
282+
def test_partial_trailing_placeholder_rejection():
283+
if sys.version_info < (3, 14):
284+
pytest.skip('Placeholder requires Python 3.14+')
285+
286+
ph = functools.Placeholder
287+
288+
def f(*args, **kwargs):
289+
return args, kwargs
290+
291+
with pytest.raises(TypeError, match='trailing Placeholders are not allowed'):
292+
optree.functools.partial(f, 42, ph)
293+
294+
with pytest.raises(TypeError, match='trailing Placeholders are not allowed'):
295+
optree.functools.partial(f, ph)
296+
297+
with pytest.raises(TypeError, match='trailing Placeholders are not allowed'):
298+
optree.functools.partial(f, ph, 1, ph)
299+
300+
301+
def test_partial_keyword_placeholder_rejection():
302+
if sys.version_info < (3, 14):
303+
pytest.skip('Placeholder requires Python 3.14+')
304+
305+
ph = functools.Placeholder
306+
307+
def f(*args, **kwargs):
308+
return args, kwargs
309+
310+
with pytest.raises(TypeError, match='Placeholder'):
311+
optree.functools.partial(f, kw=ph)
312+
313+
314+
def test_partial_repr_with_placeholder():
315+
if sys.version_info < (3, 14):
316+
pytest.skip('Placeholder requires Python 3.14+')
317+
318+
ph = functools.Placeholder
319+
320+
def f(*args, **kwargs):
321+
return args, kwargs
322+
323+
p = optree.functools.partial(f, ph, 42)
324+
r = repr(p)
325+
assert 'Placeholder' in r
326+
assert '42' in r
327+
328+
329+
def test_partial_placeholder_reexport():
330+
if sys.version_info < (3, 14):
331+
pytest.skip('Placeholder requires Python 3.14+')
332+
333+
assert hasattr(optree.functools, 'Placeholder')
334+
assert optree.functools.Placeholder is functools.Placeholder

0 commit comments

Comments
 (0)