Skip to content

Conversation

Copy link
Contributor

Copilot AI commented Jan 10, 2026

Implements optimization pattern to merge two consecutive dynamic_update_slice operations where the second subsumes the first, eliminating redundant padding and DUS operations.

Pattern

Matches:

%slice = stablehlo.slice %operand [8:12, 9:6136, ...]
%pad1 = stablehlo.pad %slice, cst, low=[0,1,0], high=[0,1,0]
%dus1 = stablehlo.dynamic_update_slice %operand, %pad1, [8,8,0]

%extend = enzymexla.extend %slice {dimension=0, lhs=1, rhs=1}
%pad2 = stablehlo.pad %extend, cst, low=[0,1,0], high=[0,0,0]
%dus2 = stablehlo.dynamic_update_slice %dus1, %pad2, [7,8,0]

Transforms to:

%slice = stablehlo.slice %operand [8:12, 9:6136, ...]
%extend = enzymexla.extend %slice {dimension=0, lhs=1, rhs=1}
%combined_pad = stablehlo.pad %extend, cst, low=[0,2,0], high=[0,1,0]
%dus = stablehlo.dynamic_update_slice %operand, %combined_pad, [7,8,0]

Changes

  • EnzymeHLOOpt.cpp: Added DUSDUSPadPadToDUSPad pattern following structure of DUSDUSToDUSExtend and DUSDUSToDUSPad
  • TransformOps.td: Registered pattern as dusduspadpad_to_duspad
  • Test: Added dusduspadpad_to_duspad.mlir covering the optimization case

Pattern applies when both slices are identical, padding values match, and the second DUS's update region subsumes the first.

Warning

Firewall rules blocked me from connecting to one or more addresses (expand for details)

I tried to connect to the following addresses, but was blocked by firewall rules:

  • releases.bazel.build
    • Triggering command: /usr/local/lib/node_modules/@bazel/bazelisk/bazelisk-linux_amd64 /usr/local/lib/node_modules/@bazel/bazelisk/bazelisk-linux_amd64 build //src/enzyme_ad/jax:enzymexlamlir-opt --compilation_mode=opt (dns block)

If you need me to access, download, or install something from one of these locations, you can either:

Original prompt

This section details on the original issue you should resolve

<issue_title>DUSDUSPadPadToDUSPad</issue_title>
<issue_description>```
%cst_161 = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc)

%c_169 = stablehlo.constant dense<7> : tensor<i32> loc(#loc) 
%c_171 = stablehlo.constant dense<8> : tensor<i32> loc(#loc)
%c_172 = stablehlo.constant dense<0> : tensor<i32> loc(#loc) 

  %503 = stablehlo.slice %iterArg_177 [8:12, 9:6136, 0:12272] : (tensor<20x6144x12272xf64>) -> tensor<4x6127x12272xf64> loc(#loc2288)

  %504 = stablehlo.pad %503, %cst_161, low = [0, 1, 0], high = [0, 1, 0], interior = [0, 0, 0] : (tensor<4x6127x12272xf64>, tensor<f64>) -> tensor<4x6129x12272xf64> loc(#loc2288)
  %505 = stablehlo.dynamic_update_slice %iterArg_177, %504, %c_171, %c_171, %c_172 : (tensor<20x6144x12272xf64>, tensor<4x6129x12272xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<20x6144x12272xf64> loc(#loc2288)

      %506 = stablehlo.slice %505 [8:12, 7:6135, 0:12272] : (tensor<20x6144x12272xf64>) -> tensor<4x6128x12272xf64> loc(#loc2896)
      %509 = stablehlo.slice %505 [8:12, 10:6138, 0:12272] : (tensor<20x6144x12272xf64>) -> tensor<4x6128x12272xf64> loc(#loc2896)

      %2533 = stablehlo.slice %505 [8:12, 6:6135, 0:12272] : (tensor<20x6144x12272xf64>) -> tensor<4x6129x12272xf64> loc(#loc1390)
      %2535 = stablehlo.slice %505 [8:12, 7:6136, 0:12272] : (tensor<20x6144x12272xf64>) -> tensor<4x6129x12272xf64> loc(#loc1390)
      %2539 = stablehlo.slice %505 [8:12, 9:6138, 0:12272] : (tensor<20x6144x12272xf64>) -> tensor<4x6129x12272xf64> loc(#loc1390)

  %512 = "enzymexla.extend"(%503) <{dimension = 0 : i64, lhs = 1 : i64, rhs = 1 : i64}> : (tensor<4x6127x12272xf64>) -> tensor<6x6127x12272xf64> loc(#loc3088)

  %513 = stablehlo.pad %512, %cst_161, low = [0, 1, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<6x6127x12272xf64>, tensor<f64>) -> tensor<6x6128x12272xf64> loc(#loc3088)
  %514 = stablehlo.dynamic_update_slice %505, %513, %c_169, %c_171, %c_172 : (tensor<20x6144x12272xf64>, tensor<6x6128x12272xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<20x6144x12272xf64> loc(#loc3088)

should be written similarly to https://github.com/EnzymeAD/Enzyme-JAX/pull/1911 (in this case here we can eliminate and/or fuse the two pads together, if possible.

</issue_description>

<agent_instructions>Write this similarly to DUSDUSToDUSExtend and DUSDUSToDUSPad</agent_instructions>

## Comments on the Issue (you are @copilot in this section)

<comments>
<comment_new><author>@wsmoses</author><body>
best mwe so far for what generates the weird pad: https://gist.github.com/wsmoses/52b5ca46e796baa376908f0803092528</body></comment_new>
<comment_new><author>@wsmoses</author><body>
%cst_161 = stablehlo.constant dense<0.000000e+00> : tensor<f64> loc(#loc)

%c_169 = stablehlo.constant dense<7> : tensor<i32> loc(#loc) 
%c_171 = stablehlo.constant dense<8> : tensor<i32> loc(#loc)
%c_172 = stablehlo.constant dense<0> : tensor<i32> loc(#loc) 

  %503 = stablehlo.slice %iterArg_177 [8:12, 9:6136, 0:12272] : (tensor<20x6144x12272xf64>) -> tensor<4x6127x12272xf64> loc(#loc2288)

  %504 = stablehlo.pad %503, %cst_161, low = [0, 1, 0], high = [0, 1, 0], interior = [0, 0, 0] : (tensor<4x6127x12272xf64>, tensor<f64>) -> tensor<4x6129x12272xf64> loc(#loc2288)

  // 8:12, 8:6137, 0:12272
  // 8:12, (8:9, 6136:6137), 0:12272
  %505 = stablehlo.dynamic_update_slice %iterArg_177, %504, %c_171, %c_171, %c_172 : (tensor<20x6144x12272xf64>, tensor<4x6129x12272xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<20x6144x12272xf64> loc(#loc2288)

      %506 = stablehlo.slice %505 [8:12, 7:6135, 0:12272] : (tensor<20x6144x12272xf64>) -> tensor<4x6128x12272xf64> loc(#loc2896)
      %509 = stablehlo.slice %505 [8:12, 10:6138, 0:12272] : (tensor<20x6144x12272xf64>) -> tensor<4x6128x12272xf64> loc(#loc2896)

      %2533 = stablehlo.slice %505 [8:12, 6:6135, 0:12272] : (tensor<20x6144x12272xf64>) -> tensor<4x6129x12272xf64> loc(#loc1390)
      %2535 = stablehlo.slice %505 [8:12, 7:6136, 0:12272] : (tensor<20x6144x12272xf64>) -> tensor<4x6129x12272xf64> loc(#loc1390)
      %2539 = stablehlo.slice %505 [8:12, 9:6138, 0:12272] : (tensor<20x6144x12272xf64>) -> tensor<4x6129x12272xf64> loc(#loc1390)

  %512 = "enzymexla.extend"(%503) <{dimension = 0 : i64, lhs = 1 : i64, rhs = 1 : i64}> : (tensor<4x6127x12272xf64>) -> tensor<6x6127x12272xf64> loc(#loc3088)

  %513 = stablehlo.pad %512, %cst_161, low = [0, 1, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<6x6127x12272xf64>, tensor<f64>) -> tensor<6x6128x12272xf64> loc(#loc3088)

  // 7:13, 8:6136, 0:12272
  %514 = stablehlo.dynamic_update_slice %505, %513, %c_169, %c_171, %c_172 : (tensor<20x6144x12272xf64>, tensor<6x6128x12272xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<
  20x6144x12272xf64> loc(#loc3088)
the latte...

</details>



<!-- START COPILOT CODING AGENT SUFFIX -->

- Fixes EnzymeAD/Enzyme-JAX#1919

<!-- START COPILOT CODING AGENT TIPS -->
---

💡 You can make Copilot smarter by setting up custom instructions, customizing its development environment and configuring Model Context Protocol (MCP) servers. Learn more [Copilot coding agent tips](https://gh.io/copilot-coding-agent-tips) in the docs.

Copilot AI changed the title [WIP] Fix DUSDUSPadPadToDUSPad issue in tensor operations Add DUSDUSPadPadToDUSPad optimization to merge consecutive padded dynamic update slices Jan 10, 2026
Copilot AI requested a review from wsmoses January 10, 2026 13:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants