Skip to content

Fix sort/partition gradients for multidim arrays#758

Open
imperatormk wants to merge 2 commits intoHIPS:masterfrom
imperatormk:fix-sort-partition-jvp-multidim
Open

Fix sort/partition gradients for multidim arrays#758
imperatormk wants to merge 2 commits intoHIPS:masterfrom
imperatormk:fix-sort-partition-jvp-multidim

Conversation

@imperatormk
Copy link

Both JVP and VJP were broken for arrays with ndim > 1. The JVP used g[sort_perm] which indexes axis 0 regardless of the sort axis, returning wrong shapes. The VJP just raised NotImplementedError.

Existing tests only covered 1D so this never came up.

@mrityunjai01
Copy link

mrityunjai01 commented Feb 17, 2026

This is an LLM generated commit created by an LLM agent?

@imperatormk
Copy link
Author

Nope just LLM code by an LLM operated by a human!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants