
ValueError when measuring MPS with jax backend #340

@thibxlv

Description

What happened?

When measuring an MPS that uses the jax backend, `measure_` raises `ValueError: output array is read-only`.

The workarounds I currently use are either:

  • converting the backend to numpy before measuring, or
  • editing tensor_1d.py, line 3719, in the `measure` method of the `MatrixProductState` class: replacing the in-place `pi /= pi.sum()` with `pi = pi / pi.sum()`.
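The second workaround helps because `x /= y` on a NumPy array writes the result into `x`'s existing buffer, which fails when that buffer is read-only, whereas `x = x / y` allocates a new array and rebinds the name. The NumPy-only sketch below assumes that `pi` is a read-only NumPy array at line 3719 and simulates this with `flags.writeable = False`. The helper names are hypothetical, not part of quimb:

```python
import numpy as np

def normalize_inplace(p):
    # Mirrors line 3719: writes the result back into p's own buffer.
    p /= p.sum()
    return p

def normalize(p):
    # Mirrors the workaround: allocates a new array instead of writing into p.
    return p / p.sum()

# Stand-in for `pi`: a read-only NumPy array (assumption about quimb's internals).
pi = np.array([1.0, 3.0])
pi.flags.writeable = False

try:
    normalize_inplace(pi)
    error = None
except ValueError as exc:
    error = str(exc)

print(error)          # output array is read-only
print(normalize(pi))  # [0.25 0.75]
```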

What did you expect to happen?

No response

Minimal Complete Verifiable Example

```python
import jax
import quimb.tensor as qtn

psi = qtn.MPS_rand_computational_state(4)
print("psi backend:", psi.backend)

psi.measure_(0)
print("numpy measure ok")

psi = qtn.MPS_rand_computational_state(4)

def to_backend(x):
    return jax.numpy.asarray(x)

psi.apply_to_arrays(to_backend)
print("psi backend:", psi.backend)

psi.measure_(0)
print("jax measure ok")
```

Relevant log output

```
psi backend: numpy
numpy measure ok
psi backend: jax
Traceback (most recent call last):
  File "PATH/bug_jax_minimal.py", line 21, in <module>
    psi.measure_(0)
  File "PATH/lib/python3.12/site-packages/quimb/tensor/tensor_1d.py", line 379, in wrapped
    return fn(self, *args, info=info, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "PATH/lib/python3.12/site-packages/quimb/tensor/tensor_1d.py", line 3719, in measure
    pi /= pi.sum()
ValueError: output array is read-only
```
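The error text comes from NumPy rather than jax, which suggests that `pi` is a NumPy array viewing jax-owned memory; NumPy typically exposes such buffers as read-only. That would also explain why converting to numpy first helps, provided the conversion copies. The sketch below uses NumPy alone, with a read-only array standing in for the jax-backed one. `to_numpy_writable` is a hypothetical helper, not a quimb function:

```python
import numpy as np

# Stand-in for a jax-backed buffer seen through NumPy: marked read-only.
source = np.array([2.0, 6.0])
source.flags.writeable = False

view = np.asarray(source)  # no copy: still read-only
copy = np.array(source)    # np.array copies by default: writable

def to_numpy_writable(x):
    """Hypothetical helper: copy any array-like into a fresh, writable NumPy array."""
    return np.array(x, copy=True)

fixed = to_numpy_writable(source)
fixed /= fixed.sum()  # in-place normalisation now succeeds on the copy

print(view.flags.writeable, copy.flags.writeable)  # False True
print(fixed)  # [0.25 0.75]
```

With quimb, a helper like this could presumably be passed to `psi.apply_to_arrays` in place of `to_backend` before calling `psi.measure_(0)`. This is a sketch, not verified against quimb 1.11.2.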

Anything else we need to know?

No response

Environment

ubuntu 24.04.2
python 3.12.3
quimb 1.11.2
jax 0.8.1
