Skip to content

Fix np.mean with list/tuple containing ArrayBox objects#757

Open
tylerflex wants to merge 1 commit intoHIPS:masterfrom
tylerflex:fix-mean-list-input
Open

Fix np.mean with list/tuple containing ArrayBox objects#757
tylerflex wants to merge 1 commit intoHIPS:masterfrom
tylerflex:fix-mean-list-input

Conversation

@tylerflex
Copy link

Summary

Fix autograd.numpy.mean() failing when passed a list/tuple containing ArrayBox objects.

import autograd
import autograd.numpy as np

def f(x):
    return np.mean([x, x+2])

g = autograd.grad(f)
g(0.0)  # Previously raised TypeError, now returns 1.0

Previously raised: TypeError: float() argument must be a string or a real number, not 'ArrayBox'

Root cause

When np.mean() receives a list containing ArrayBox objects, numpy's internal _mean implementation calls ret.dtype.type(result) to cast the output.

Since ArrayBox.dtype returns float64, this becomes np.float64(ArrayBox), which calls float() on the ArrayBox and fails.

Fix

  • Add a mean wrapper that converts list/tuple inputs to arrays using autograd's
    array() function before calling the primitive
  • Update VJP/JVP registrations to use _primitive_mean (the underlying primitive)
  • Add **kwargs to grad_np_mean to accept optional arguments like dtype

ran tests and linting with nox

When np.mean() is called with a list containing ArrayBox objects (e.g.,
np.mean([x, x+2])), numpy's internal implementation calls float() on
the result, which fails for ArrayBox.

This fix adds a wrapper for mean() that converts list/tuple inputs to
arrays first, avoiding the problematic code path in numpy's _mean.

Fixes the error:
TypeError: float() argument must be a string or a real number, not 'ArrayBox'

Co-Authored-By: Claude Opus 4.5 <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant