Add forward and reverse mode support for numpy.take#765
Add forward and reverse mode support for numpy.take#765KryptosAI wants to merge 3 commits intoHIPS:masterfrom
Conversation
|
Hi @KryptosAI, thanks for this! I've launched the workflows. Looks like they are all green! Also, I apologise for asking this in this manner in advance, but considering that you have "AI" in your username and the PR description is well-formatted and well-described: we'd like to know whether this PR was generated by AI or an agentic coding tool of any sort, and whether we will be talking to a human on the other end? |
|
Hey @agriyakhetarpal — thanks for running the workflows and for asking directly, totally fair question. Yes, you're talking to a human! I'm William. I use AI coding assistants (Claude Code, Codex) as part of my workflow, but I wrote the logic, chose the implementation approach, and reviewed everything before submitting. I can explain any part of the code in detail — happy to walk through the The "AI" in KryptosAI is the name of my GitHub org (I work on AI products), not an indicator that a bot is submitting PRs. |
Summary
autograd.numpy.takeis currently wrapped, but gradients with respect to the input array are not defined, so calls likeanp.take(x, idx, axis=1)raiseNotImplementedErrorin reverse mode.This PR adds both reverse-mode and forward-mode support for
numpy.takewith respect to the input array.What was missing before
autograd.numpy.takehad no VJP for the array argumentImplementation
This keeps the change tight and reuses the existing indexing backprop machinery.
It handles three cases:
__getitem__pathtake(..., axis=None)behaviortake(..., axis=<int>)behaviorFor reverse mode, the gradient is treated as a scatter-add back into the source array via the existing internal
untake()sparse helper.Repeated indices
Repeated indices accumulate gradients additively, matching NumPy gather semantics and the existing scatter logic already used for indexing.
Validation
I added regression coverage for:
axiscasetakeaxis=NoneI also re-ran the full
tests/test_numpy.pysuite from a fresh virtualenv with-n 0and it passed.Closes #743