Skip to content

Commit bd32248

Browse files
committed
docs(README): explain why dict keys are sorted during flattening
1 parent 4496033 commit bd32248

1 file changed

Lines changed: 45 additions & 0 deletions

File tree

README.md

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,51 @@ The same applies to [`collections.defaultdict`](https://docs.python.org/3/librar
505505
([1, 2, 3], PyTreeSpec({'a': [*, *], 'b': [*]}))
506506
```
507507
508+
Sorting ensures that equal dictionaries always flatten to the same leaf sequence, regardless of insertion order. This is critical for operations that rely on positional correspondence between leaves. Consider two parameter `dict`s that are equal but constructed in different orders:
509+
510+
```python
511+
>>> import numpy as np
512+
>>> params1 = {'weight': np.array([[1.0, 2.0], [3.0, 4.0]]), 'bias': np.array([5.0, 6.0])}
513+
>>> params2 = {'bias': np.array([5.0, 6.0]), 'weight': np.array([[1.0, 2.0], [3.0, 4.0]])}
514+
>>> optree.tree_all(optree.tree_map(np.allclose, params1, params2))
515+
True
516+
```
517+
518+
Because `tree_map` zips leaves positionally, sorted keys guarantee correct element-wise operations:
519+
520+
```python
521+
>>> optree.tree_map(lambda x, y: x - y, params1, params2)
522+
{
523+
'weight': array([[0., 0.],
524+
[0., 0.]]),
525+
'bias': array([0., 0.])
526+
}
527+
```
528+
529+
The same applies to `tree_ravel`, which concatenates all leaves into a single 1D array:
530+
531+
```python
532+
>>> from optree.integrations.numpy import tree_ravel
533+
>>> tree_ravel(params1)[0]
534+
array([5., 6., 1., 2., 3., 4.]) # 'bias' before 'weight' (sorted)
535+
>>> tree_ravel(params2)[0]
536+
array([5., 6., 1., 2., 3., 4.]) # same order, despite different insertion order
537+
```
538+
539+
Without sorting, insertion order would silently corrupt the results. Here is a counterexample using `dict_insertion_ordered`:
540+
541+
```python
542+
>>> with optree.dict_insertion_ordered(True, namespace='demo'):
543+
... flat1, _ = tree_ravel(params1, namespace='demo')
544+
... flat2, _ = tree_ravel(params2, namespace='demo')
545+
>>> flat1
546+
array([1., 2., 3., 4., 5., 6.]) # weight, bias (insertion order of params1)
547+
>>> flat2
548+
array([5., 6., 1., 2., 3., 4.]) # bias, weight (insertion order of params2)
549+
>>> flat1 - flat2 # WRONG! Should be all zeros for equal params
550+
array([-4., -4., 2., 2., 2., 2.])
551+
```
552+
508553
To preserve insertion order during pytree traversal, use [`collections.OrderedDict`](https://docs.python.org/3/library/collections.html#collections.OrderedDict), which considers key order in equality checks:
509554
510555
```python

0 commit comments

Comments
 (0)