Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,57 @@ def RotateOp : EnzymeXLA_Op<"rotate", [Pure, SameOperandsAndResultType]> {
let results = (outs HLO_Tensor:$result);
}

def MultiRotateOp : EnzymeXLA_Op<"multi_rotate", [Pure]> {
let summary = "Produces multiple rotated versions of the input tensor";
let description = [{
MultiRotate operation produces multiple rotated versions of the input tensor.
Given left_amount=L and right_amount=R, it produces L+R+1 results:
- L results for left rotations (from L to 1)
- 1 result for no rotation (amount=0)
- R results for right rotations (from 1 to R)

For example, with left_amount=2 and right_amount=2:
results[0] = rotate left by 2
results[1] = rotate left by 1
results[2] = no rotation (amount=0)
results[3] = rotate right by 1
results[4] = rotate right by 2
}];
let arguments = (ins
HLO_Tensor:$operand,
SI32Attr:$dimension,
SI32Attr:$left_amount,
SI32Attr:$right_amount
);
let hasFolder = 1;
let results = (outs Variadic<HLO_Tensor>:$results);
}

def MultiSliceOp : EnzymeXLA_Op<"multi_slice", [Pure]> {
let summary = "Produces multiple slice versions of the input tensor";
let description = [{
MultiSlice operation produces multiple slice versions of the input tensor.
Given left_amount=L and right_amount=R, it produces L+R+1 results:
- L results for slices shifted left
- 1 result for the center slice
- R results for slices shifted right

The slice parameters (start_indices, limit_indices, strides) define the center slice.
Each left/right result is offset along the specified dimension.
}];
let arguments = (ins
HLO_Tensor:$operand,
I64ArrayAttr:$start_indices,
I64ArrayAttr:$limit_indices,
I64ArrayAttr:$strides,
SI32Attr:$dimension,
SI32Attr:$left_amount,
SI32Attr:$right_amount
);
let hasFolder = 1;
let results = (outs Variadic<HLO_Tensor>:$results);
}

def WrapOp: EnzymeXLA_Op<
"wrap",
[Pure, SameOperandsAndResultElementType,
Expand Down
15 changes: 15 additions & 0 deletions src/enzyme_ad/jax/Dialect/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2487,3 +2487,18 @@ OpFoldResult ExtendOp::fold(FoldAdaptor adaptor) {
}
return nullptr;
}

OpFoldResult MultiRotateOp::fold(FoldAdaptor adaptor) {
// If all amounts are zero and only one result, fold to operand
if (getLeftAmount() == 0 && getRightAmount() == 0) {
// Single result case - return the operand directly
return getOperand();
}
return nullptr;
}

OpFoldResult MultiSliceOp::fold(FoldAdaptor adaptor) {
// Basic folding for multi-slice
// More complex folding can be done in patterns
return nullptr;
}
Loading