Description
What happened?
When measuring an MPS whose arrays use the jax backend, a "ValueError: output array is read-only" is raised.
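The error text is the message NumPy raises when an in-place operation writes into a non-writeable array. This suggests that, at that point in "measure", pi is a read-only NumPy view of JAX-owned data. That is an assumption: quimb's internals are not shown in this report. The sketch below reproduces the same ValueError outside quimb with made-up probability values. It also shows why the out-of-place form pi = pi / pi.sum() avoids the error.

```python
import numpy as np
import jax.numpy as jnp

# Hypothetical unnormalized probabilities, stored as a JAX array.
p_jax = jnp.array([1.0, 3.0])

# np.asarray on a JAX array typically returns a zero-copy view of the
# JAX buffer, and that view is marked non-writeable.
pi = np.asarray(p_jax)
print("writeable:", pi.flags.writeable)

try:
    # In-place division asks NumPy to write into the read-only view.
    pi /= pi.sum()
    inplace_ok = True
except ValueError as err:
    inplace_ok = False
    print("in-place failed:", err)  # e.g. "output array is read-only"

# Out-of-place division allocates a fresh, writeable result instead.
pi = pi / pi.sum()
print(pi)
```

The out-of-place version only rebinds the name pi to a new array, so it never touches the immutable JAX buffer.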
The workaround I currently use is one of the following:
- convert the MPS back to the numpy backend before measuring, or
- modify tensor_1d.py, line 3719, in the definition of the "measure" method of the MatrixProductState class: replace "pi /= pi.sum()" with "pi = pi / pi.sum()".
What did you expect to happen?
No response
Minimal Complete Verifiable Example
import jax
import quimb.tensor as qtn
# default numpy backend: measuring works
psi = qtn.MPS_rand_computational_state(4)
print("psi backend:", psi.backend)
psi.measure_(0)
print("numpy measure ok")
# jax backend: convert every array of the MPS to a jax array, then measure
psi = qtn.MPS_rand_computational_state(4)
def to_backend(x):
    return jax.numpy.asarray(x)
psi.apply_to_arrays(to_backend)
print("psi backend:", psi.backend)
psi.measure_(0)  # raises ValueError: output array is read-only
print("jax measure ok")
Relevant log output
psi backend: numpy
numpy measure ok
psi backend: jax
Traceback (most recent call last):
  File "PATH/bug_jax_minimal.py", line 21, in <module>
    psi.measure_(0)
  File "PATH/lib/python3.12/site-packages/quimb/tensor/tensor_1d.py", line 379, in wrapped
    return fn(self, *args, info=info, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "PATH/lib/python3.12/site-packages/quimb/tensor/tensor_1d.py", line 3719, in measure
    pi /= pi.sum()
ValueError: output array is read-only
Anything else we need to know?
No response
Environment
Ubuntu 24.04.2
Python 3.12.3
quimb 1.11.2
jax 0.8.1