fix: make numpy_backend.tile() and jax_backend.tile() consistent#2587
fix: make numpy_backend.tile() and jax_backend.tile() consistent#2587ligerlac wants to merge 1 commit intoscikit-hep:mainfrom
Conversation
|
@ligerlac Thanks for the PR. Today I have been clawing myself out of travel related time dependent TODOs, but I can review this on Thursday (2025-05-22). I haven't looked/thought about this yet, but I assume that this isn't something unique to |
|
It's more of a patch. You are right, the problem is not unqiue to last line fails with We could also patch that in the jax backend. But I guess a more elegant solution would be to make sure that each backend is only receiving arguments of the correct type by calling |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2587 +/- ##
==========================================
- Coverage 98.23% 98.18% -0.05%
==========================================
Files 65 65
Lines 4193 4195 +2
Branches 591 592 +1
==========================================
Hits 4119 4119
- Misses 45 46 +1
- Partials 29 30 +1
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
This fixes a bug in the jax_backend.tile() method. Consider the following minimal example:
The last line fails with
TypeError: tile requires ndarray or scalar arguments, got <class 'list'> at position 0.. However, it works fine when using the numpy backend. The problem stems from differences betweennp.tileandjnp.tile:Unlike
jnp.tile,np.tileimplicitly converts the input to the correct type.This PR ensures
tensor_inis ajnp.arrayto make the behaviour ofnumpy_backend.tile()andjax_backend.tile()consistent.