Skip to content

Add forward and reverse mode support for numpy.take#765

Open
KryptosAI wants to merge 3 commits intoHIPS:masterfrom
KryptosAI:add-numpy-take-grad-support
Open

Add forward and reverse mode support for numpy.take#765
KryptosAI wants to merge 3 commits intoHIPS:masterfrom
KryptosAI:add-numpy-take-grad-support

Conversation

@KryptosAI
Copy link
Copy Markdown

Summary

autograd.numpy.take is currently wrapped, but gradients with respect to the input array are not defined, so calls like anp.take(x, idx, axis=1) raise NotImplementedError in reverse mode.

This PR adds both reverse-mode and forward-mode support for numpy.take with respect to the input array.

What was missing before

  • autograd.numpy.take had no VJP for the array argument
  • there was no matching JVP registration for forward mode
  • the issue example in support numpy.take #743 failed even though equivalent indexing syntax already worked

Implementation

This keeps the change tight and reuses the existing indexing backprop machinery.

It handles three cases:

  • plain indexing-style scatter for the existing __getitem__ path
  • flattened take(..., axis=None) behavior
  • axis-aware take(..., axis=<int>) behavior

For reverse mode, the gradient is treated as a scatter-add back into the source array via the existing internal untake() sparse helper.

Repeated indices

Repeated indices accumulate gradients additively, matching NumPy gather semantics and the existing scatter logic already used for indexing.

Validation

I added regression coverage for:

  • the issue-style axis case
  • 1D take
  • repeated indices
  • axis=None
  • negative indices
  • out-of-bounds parity with NumPy

I also re-ran the full tests/test_numpy.py suite from a fresh virtualenv with -n 0 and it passed.

Closes #743

@agriyakhetarpal
Copy link
Copy Markdown
Collaborator

Hi @KryptosAI, thanks for this! I've launched the workflows. Looks like they are all green!

Also, I apologise for asking this in this manner in advance, but considering that you have "AI" in your username and the PR description is well-formatted and well-described: we'd like to know whether this PR was generated by AI or an agentic coding tool of any sort, and whether we will be talking to a human on the other end?

@KryptosAI
Copy link
Copy Markdown
Author

Hey @agriyakhetarpal — thanks for running the workflows and for asking directly, totally fair question.

Yes, you're talking to a human! I'm William. I use AI coding assistants (Claude Code, Codex) as part of my workflow, but I wrote the logic, chose the implementation approach, and reviewed everything before submitting. I can explain any part of the code in detail — happy to walk through the take implementation or the gradient derivation if that would help.

The "AI" in KryptosAI is the name of my GitHub org (I work on AI products), not an indicator that a bot is submitting PRs.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

support numpy.take

2 participants