Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3124,6 +3124,53 @@ def Torch_AtenBitwiseOr_TensorOp : Torch_Op<"aten.bitwise_or_.Tensor", [
}];
}

def Torch_AtenBitwiseOrScalarOp : Torch_Op<"aten.bitwise_or.Scalar", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::bitwise_or.Scalar : (Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchScalarType:$other
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenBitwiseOrScalarOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenBitwiseOrScalarOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}

def Torch_AtenBitwiseOr_ScalarOp : Torch_Op<"aten.bitwise_or_.Scalar", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::bitwise_or_.Scalar : (Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self,
AnyTorchScalarType:$other
);
let results = (outs
AnyTorchOptionalNonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenBitwiseOr_ScalarOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenBitwiseOr_ScalarOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}

def Torch_AtenBitwiseXorTensorOp : Torch_Op<"aten.bitwise_xor.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down Expand Up @@ -3171,6 +3218,53 @@ def Torch_AtenBitwiseXor_TensorOp : Torch_Op<"aten.bitwise_xor_.Tensor", [
}];
}

def Torch_AtenBitwiseXorScalarOp : Torch_Op<"aten.bitwise_xor.Scalar", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::bitwise_xor.Scalar : (Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchScalarType:$other
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenBitwiseXorScalarOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenBitwiseXorScalarOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}

def Torch_AtenBitwiseXor_ScalarOp : Torch_Op<"aten.bitwise_xor_.Scalar", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::bitwise_xor_.Scalar : (Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self,
AnyTorchScalarType:$other
);
let results = (outs
AnyTorchOptionalNonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenBitwiseXor_ScalarOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenBitwiseXor_ScalarOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}

def Torch_AtenBitwiseLeftShiftTensorOp : Torch_Op<"aten.bitwise_left_shift.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
37 changes: 37 additions & 0 deletions lib/Conversion/TorchToStablehlo/Basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2067,6 +2067,40 @@ LogicalResult ConvertAtenOp<AtenFmodTensorOp>::matchAndRewrite(
return success();
}

// AtenFmodScalarOp
template <>
LogicalResult ConvertAtenOp<AtenFmodScalarOp>::matchAndRewrite(
AtenFmodScalarOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op->getLoc();
Value lhs = adaptor.getSelf();
Value rhs = adaptor.getOther();

auto resultType =
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
lhs = hlo::promoteType(rewriter, op->getLoc(), lhs,
resultType.getElementType());

rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs,
resultType.getElementType());
rhs = hlo::promoteAndBroadcast(rewriter, rhs, resultType, std::nullopt);

stablehlo::MulOp mul;
auto div = stablehlo::DivOp::create(rewriter, loc, lhs, rhs);
if (isa<mlir::FloatType>(resultType.getElementType())) {

auto sign = stablehlo::SignOp::create(rewriter, loc, div);
auto abs = stablehlo::AbsOp::create(rewriter, loc, div);
auto floor = stablehlo::FloorOp::create(rewriter, loc, abs);
auto trunc = stablehlo::MulOp::create(rewriter, loc, sign, floor);
mul = stablehlo::MulOp::create(rewriter, loc, trunc, rhs);
} else {
mul = stablehlo::MulOp::create(rewriter, loc, div, rhs);
}
rewriter.replaceOpWithNewOp<stablehlo::SubtractOp>(op, lhs, mul);
return success();
}

// AtenBitwiseLeftShiftTensorOp
template <>
LogicalResult ConvertAtenOp<AtenBitwiseLeftShiftTensorOp>::matchAndRewrite(
Expand Down Expand Up @@ -2311,7 +2345,9 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
INSERT_BINARY_LOGICAL_PATTERN(AtenLogicalOrOp, chlo::BroadcastOrOp);
INSERT_BINARY_LOGICAL_PATTERN(AtenLogicalAndOp, chlo::BroadcastAndOp);
INSERT_BINARY_LOGICAL_PATTERN(AtenLogicalXorOp, chlo::BroadcastXorOp);
INSERT_BINARY_LOGICAL_PATTERN(AtenBitwiseOrScalarOp, chlo::BroadcastOrOp);
INSERT_BINARY_LOGICAL_PATTERN(AtenBitwiseAndScalarOp, chlo::BroadcastAndOp);
INSERT_BINARY_LOGICAL_PATTERN(AtenBitwiseXorScalarOp, chlo::BroadcastXorOp);

#undef INSERT_BINARY_LOGICAL_PATTERN

Expand Down Expand Up @@ -2363,6 +2399,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenFillScalarOp);
INSERT_ATENOP_PATTERN(AtenFlipOp);
INSERT_ATENOP_PATTERN(AtenFmodTensorOp);
INSERT_ATENOP_PATTERN(AtenFmodScalarOp);
INSERT_ATENOP_PATTERN(AtenBitwiseLeftShiftTensorOp);
INSERT_ATENOP_PATTERN(AtenBitwiseRightShiftTensorOp);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,9 @@ def emit_with_mutating_variants(key, **kwargs):
"aten::bitwise_and.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::bitwise_and.Scalar : (Tensor, Scalar) -> (Tensor)",
"aten::bitwise_or.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::bitwise_or.Scalar : (Tensor, Scalar) -> (Tensor)",
"aten::bitwise_xor.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::bitwise_xor.Scalar : (Tensor, Scalar) -> (Tensor)",
"aten::bitwise_left_shift.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::bitwise_right_shift.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)",
Expand Down
69 changes: 69 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7147,3 +7147,72 @@ def forward(self, x):
@register_test_case(module_factory=lambda: AtenAsStridedUnknownSizeModule())
def AtenAsStridedUnknownSizeModule_basic(module, tu: TestUtils):
module.forward(torch.randn(12, 13))


# ==============================================================================


class AtenBitwiseOrScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([3, 4], torch.int64, True),
]
)
def forward(self, x):
return torch.ops.aten.bitwise_or(x, 7)


@register_test_case(module_factory=lambda: AtenBitwiseOrScalarModule())
def AtenBitwiseOrScalarModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-10, high=10))


# ==============================================================================


class AtenBitwiseXorScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([3, 4], torch.int64, True),
]
)
def forward(self, x):
return torch.ops.aten.bitwise_xor(x, 11)


@register_test_case(module_factory=lambda: AtenBitwiseXorScalarModule())
def AtenBitwiseXorScalarModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-10, high=10))


# ==============================================================================


class AtenFmodScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([3, 4], torch.float32, True),
]
)
def forward(self, x):
return torch.ops.aten.fmod(x, 7.5)


@register_test_case(module_factory=lambda: AtenFmodScalarModule())
def AtenFmodScalarModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
67 changes: 67 additions & 0 deletions test/Conversion/TorchToStablehlo/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -339,3 +339,70 @@ func.func @torch.aten.tril(%arg0: !torch.vtensor<[2,3,5],f32>, %arg1: !torch.int
%0 = torch.aten.tril %arg0, %arg1:!torch.vtensor<[2,3,5],f32>, !torch.int -> !torch.vtensor<[2,3,5],f32>
return %0 : !torch.vtensor<[2,3,5],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.bitwise_or.Scalar(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,5,6],si32>) -> !torch.vtensor<[4,5,6],si32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,5,6],si32> -> tensor<4x5x6xi32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 5
// CHECK: %[[VAL_3:.*]] = arith.constant 5 : i64
// CHECK: %[[VAL_4:.*]] = tensor.from_elements %[[VAL_3]] : tensor<1xi64>
// CHECK: %[[VAL_5:.*]] = stablehlo.convert %[[VAL_4]] : (tensor<1xi64>) -> tensor<1xi32>
// CHECK: %[[VAL_6:.*]] = stablehlo.reshape %[[VAL_5]] : (tensor<1xi32>) -> tensor<i32>
// CHECK: %[[VAL_7:.*]] = chlo.broadcast_or %[[VAL_1]], %[[VAL_6]] : (tensor<4x5x6xi32>, tensor<i32>) -> tensor<4x5x6xi32>
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<4x5x6xi32> -> !torch.vtensor<[4,5,6],si32>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[4,5,6],si32>
// CHECK: }
func.func @torch.aten.bitwise_or.Scalar(%arg0: !torch.vtensor<[4,5,6],si32>) -> !torch.vtensor<[4,5,6],si32> {
%int5 = torch.constant.int 5
%0 = torch.aten.bitwise_or.Scalar %arg0, %int5 : !torch.vtensor<[4,5,6],si32>, !torch.int -> !torch.vtensor<[4,5,6],si32>
return %0 : !torch.vtensor<[4,5,6],si32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.bitwise_xor.Scalar(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,5,6],si32>) -> !torch.vtensor<[4,5,6],si32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,5,6],si32> -> tensor<4x5x6xi32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 8
// CHECK: %[[VAL_3:.*]] = arith.constant 8 : i64
// CHECK: %[[VAL_4:.*]] = tensor.from_elements %[[VAL_3]] : tensor<1xi64>
// CHECK: %[[VAL_5:.*]] = stablehlo.convert %[[VAL_4]] : (tensor<1xi64>) -> tensor<1xi32>
// CHECK: %[[VAL_6:.*]] = stablehlo.reshape %[[VAL_5]] : (tensor<1xi32>) -> tensor<i32>
// CHECK: %[[VAL_7:.*]] = chlo.broadcast_xor %[[VAL_1]], %[[VAL_6]] : (tensor<4x5x6xi32>, tensor<i32>) -> tensor<4x5x6xi32>
// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<4x5x6xi32> -> !torch.vtensor<[4,5,6],si32>
// CHECK: return %[[VAL_8]] : !torch.vtensor<[4,5,6],si32>
// CHECK: }
func.func @torch.aten.bitwise_xor.Scalar(%arg0: !torch.vtensor<[4,5,6],si32>) -> !torch.vtensor<[4,5,6],si32> {
%int8 = torch.constant.int 8
%0 = torch.aten.bitwise_xor.Scalar %arg0, %int8 : !torch.vtensor<[4,5,6],si32>, !torch.int -> !torch.vtensor<[4,5,6],si32>
return %0 : !torch.vtensor<[4,5,6],si32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.fmod.Scalar(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,3,2],f32>) -> !torch.vtensor<[4,3,2],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,3,2],f32> -> tensor<4x3x2xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 17
// CHECK: %[[VAL_3:.*]] = arith.constant 17 : i64
// CHECK: %[[VAL_4:.*]] = tensor.from_elements %[[VAL_3]] : tensor<1xi64>
// CHECK: %[[VAL_5:.*]] = stablehlo.convert %[[VAL_4]] : (tensor<1xi64>) -> tensor<1xf32>
// CHECK: %[[VAL_6:.*]] = stablehlo.reshape %[[VAL_5]] : (tensor<1xf32>) -> tensor<f32>
// CHECK: %[[VAL_7:.*]] = stablehlo.broadcast_in_dim %[[VAL_6]], dims = [] : (tensor<f32>) -> tensor<4x3x2xf32>
// CHECK: %[[VAL_8:.*]] = stablehlo.divide %[[VAL_1]], %[[VAL_7]] : tensor<4x3x2xf32>
// CHECK: %[[VAL_9:.*]] = stablehlo.sign %[[VAL_8]] : tensor<4x3x2xf32>
// CHECK: %[[VAL_10:.*]] = stablehlo.abs %[[VAL_8]] : tensor<4x3x2xf32>
// CHECK: %[[VAL_11:.*]] = stablehlo.floor %[[VAL_10]] : tensor<4x3x2xf32>
// CHECK: %[[VAL_12:.*]] = stablehlo.multiply %[[VAL_9]], %[[VAL_11]] : tensor<4x3x2xf32>
// CHECK: %[[VAL_13:.*]] = stablehlo.multiply %[[VAL_12]], %[[VAL_7]] : tensor<4x3x2xf32>
// CHECK: %[[VAL_14:.*]] = stablehlo.subtract %[[VAL_1]], %[[VAL_13]] : tensor<4x3x2xf32>
// CHECK: %[[VAL_15:.*]] = torch_c.from_builtin_tensor %[[VAL_14]] : tensor<4x3x2xf32> -> !torch.vtensor<[4,3,2],f32>
// CHECK: return %[[VAL_15]] : !torch.vtensor<[4,3,2],f32>
// CHECK: }
func.func @torch.aten.fmod.Scalar(%arg0: !torch.vtensor<[4,3,2],f32>) -> !torch.vtensor<[4,3,2],f32> {
%int17 = torch.constant.int 17
%0 = torch.aten.fmod.Scalar %arg0, %int17 : !torch.vtensor<[4,3,2],f32>, !torch.int -> !torch.vtensor<[4,3,2],f32>
return %0 : !torch.vtensor<[4,3,2],f32>
}