Skip to content
Open
37 changes: 36 additions & 1 deletion python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3230,14 +3230,18 @@ def _impl_v18(cls, bb, inputs, attr, params):

use_dynamic_roi = roi_dynamic_vec is not None

# Convert scales to sizes if needed.
# Convert scales to sizes if needed, preserving the original spatial scales so
# the coordinate transformation uses the exact ONNX scale value rather than the
# lossy ratio derived from floor(input * scale) / input.
original_spatial_scales = None
if scales is not None:
if isinstance(scales, relax.Constant):
scales = scales.data.numpy()
elif isinstance(scales, relax.expr.ShapeExpr):
scales = [int(val.value) for val in scales.values]
else:
raise ValueError(f"Type {type(scales)} for scale is currently unsupported.")
original_spatial_scales = [float(s) for s in scales[2:]]
sizes = []

for i, dim in enumerate(x.struct_info.shape):
Expand Down Expand Up @@ -3279,8 +3283,24 @@ def _impl_v18(cls, bb, inputs, attr, params):
cubic_coeff_a,
exclude_outside,
extrapolation_value,
scales=original_spatial_scales,
)
elif ndims == 4:
if original_spatial_scales is not None:
return bb.emit_te(
topi.image.resize2d,
x,
roi_static,
sizes,
"NCHW",
topi_mode,
coord_mode,
rounding_method,
cubic_coeff_a,
exclude_outside,
extrapolation_value,
scales=original_spatial_scales,
)
return relax.op.image.resize2d(
x,
size=relax.ShapeExpr(sizes),
Expand All @@ -3295,6 +3315,21 @@ def _impl_v18(cls, bb, inputs, attr, params):
)
else: # ndims == 5
roi3d = _topi_resize3d_roi_from_onnx_ncdhw_spatial(roi_static)
if original_spatial_scales is not None:
return bb.emit_te(
topi.image.resize3d,
x,
roi3d,
sizes,
"NCDHW",
topi_mode,
coord_mode,
rounding_method,
cubic_coeff_a,
exclude_outside,
extrapolation_value,
scales=original_spatial_scales,
)
return relax.op.image.resize3d(
x,
size=relax.ShapeExpr(sizes),
Expand Down
72 changes: 67 additions & 5 deletions python/tvm/topi/image/resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,13 @@ def get_inx(
start_x=0,
end_x=-1,
use_int_div=False,
scale_x_override=None,
):
"""Infer input x from output x with various coordinate transformation methods"""
scale_x = te.div(image_width.astype("float"), target_width.astype("float"))
if scale_x_override is not None:
scale_x = scale_x_override
else:
scale_x = te.div(image_width.astype("float"), target_width.astype("float"))
if coordinate_transformation_mode == "half_pixel":
in_x = (x + 0.5) * scale_x - 0.5
elif coordinate_transformation_mode == "align_corners":
Expand Down Expand Up @@ -237,6 +241,7 @@ def _resize_1d(
alpha=-0.5,
exclude_outside=0,
out_dtype=None,
scale_x=None,
):
"""Perform resize operation on the data with selected method and options.

Expand Down Expand Up @@ -315,7 +320,15 @@ def _cast_output(value, data_dtype="float32", out_dtype=None):
if boxes is not None:
# TODO(mbrookhart): Find an example of this
raise NotImplementedError("resize1d with image boxes not yet implemented")
in_x = get_inx(x, image_width, target_width, coordinate_transformation_mode, roi[0], roi[1])
in_x = get_inx(
x,
image_width,
target_width,
coordinate_transformation_mode,
roi[0],
roi[1],
scale_x_override=scale_x,
)

if method == "nearest_neighbor":
if rounding_method == "":
Expand Down Expand Up @@ -383,6 +396,7 @@ def resize1d(
extrapolation_value=0.0,
out_dtype=None,
output_shape=None,
scales=None,
):
"""Perform resize operation on the data.

Expand Down Expand Up @@ -472,6 +486,8 @@ def resize1d(
if isinstance(size[i], int):
size[i] = tvm.tirx.IntImm("int32", size[i])

scale_x = (1.0 / scales[0]) if scales is not None else None

def compute_func(*indices):
return _resize_1d(
indices,
Expand All @@ -487,6 +503,7 @@ def compute_func(*indices):
exclude_outside=bicubic_exclude,
extrapolation_value=extrapolation_value,
out_dtype=out_dtype,
scale_x=scale_x,
)

return te.compute(output_shape, compute_func, name="resize", tag=tag.INJECTIVE)
Expand All @@ -510,6 +527,8 @@ def _resize_2d(
alpha=-0.5,
exclude_outside=0,
out_dtype=None,
scale_h=None,
scale_w=None,
):
"""Perform resize operation on the data with selected method and options.

Expand Down Expand Up @@ -618,6 +637,7 @@ def _cast_output(value, data_dtype="float32", out_dtype=None):
roi[1],
roi[3],
width_use_int_div,
scale_x_override=scale_w,
)
in_y = get_inx(
y,
Expand All @@ -627,6 +647,7 @@ def _cast_output(value, data_dtype="float32", out_dtype=None):
roi[0],
roi[2],
height_use_int_div,
scale_x_override=scale_h,
)

if method == "nearest_neighbor":
Expand Down Expand Up @@ -756,6 +777,7 @@ def resize2d(
extrapolation_value=0.0,
out_dtype=None,
output_shape=None,
scales=None,
):
"""Perform resize operation on the data.

Expand Down Expand Up @@ -839,6 +861,9 @@ def resize2d(
if isinstance(size[i], int):
size[i] = tvm.tirx.IntImm("int32", size[i])

scale_h = (1.0 / scales[0]) if scales is not None else None
scale_w = (1.0 / scales[1]) if scales is not None else None

def compute_func(*indices):
return _resize_2d(
indices,
Expand All @@ -856,6 +881,8 @@ def compute_func(*indices):
exclude_outside=bicubic_exclude,
extrapolation_value=extrapolation_value,
out_dtype=out_dtype,
scale_h=scale_h,
scale_w=scale_w,
)

return te.compute(output_shape, compute_func, name="resize", tag=tag.INJECTIVE)
Expand Down Expand Up @@ -976,6 +1003,9 @@ def _resize_3d(
alpha=-0.5,
exclude_outside=0,
out_dtype=None,
scale_d=None,
scale_h=None,
scale_w=None,
):
"""Perform resize operation on the data with selected method and options.

Expand Down Expand Up @@ -1066,9 +1096,33 @@ def _cast_output(value, data_dtype="float32", out_dtype=None):
if boxes is not None:
# TODO(mbrookhart): Find an example of this
raise NotImplementedError("resize1d with image boxes not yet implemented")
in_z = get_inx(z, image_depth, target_depth, coordinate_transformation_mode, roi[2], roi[5])
in_y = get_inx(y, image_height, target_height, coordinate_transformation_mode, roi[1], roi[4])
in_x = get_inx(x, image_width, target_width, coordinate_transformation_mode, roi[0], roi[3])
in_z = get_inx(
z,
image_depth,
target_depth,
coordinate_transformation_mode,
roi[2],
roi[5],
scale_x_override=scale_d,
)
in_y = get_inx(
y,
image_height,
target_height,
coordinate_transformation_mode,
roi[1],
roi[4],
scale_x_override=scale_h,
)
in_x = get_inx(
x,
image_width,
target_width,
coordinate_transformation_mode,
roi[0],
roi[3],
scale_x_override=scale_w,
)

if method == "nearest_neighbor":
if rounding_method == "":
Expand Down Expand Up @@ -1225,6 +1279,7 @@ def resize3d(
extrapolation_value=0.0,
out_dtype=None,
output_shape=None,
scales=None,
):
"""Perform resize operation on the data.

Expand Down Expand Up @@ -1302,6 +1357,10 @@ def resize3d(
if isinstance(size[i], int):
size[i] = tvm.tirx.IntImm("int32", size[i])

scale_d = (1.0 / scales[0]) if scales is not None else None
scale_h = (1.0 / scales[1]) if scales is not None else None
scale_w = (1.0 / scales[2]) if scales is not None else None

def compute_func(*indices):
return _resize_3d(
indices,
Expand All @@ -1321,6 +1380,9 @@ def compute_func(*indices):
exclude_outside=bicubic_exclude,
extrapolation_value=extrapolation_value,
out_dtype=out_dtype,
scale_d=scale_d,
scale_h=scale_h,
scale_w=scale_w,
)

return te.compute(output_shape, compute_func, name="resize", tag=tag.INJECTIVE)
108 changes: 108 additions & 0 deletions tests/python/relax/test_frontend_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -4012,6 +4012,114 @@ def _visit(expr):
assert seen_resize3d


@pytest.mark.parametrize(
    "coord_mode, method",
    [
        ("half_pixel", "nearest"),
        ("pytorch_half_pixel", "nearest"),
        ("asymmetric", "nearest"),
        ("half_pixel", "linear"),
    ],
)
def test_resize_noninteger_scales_2d(coord_mode, method):
    """Check a 2D Resize whose spatial scales (2.5) are not integers.

    A 3x3 input scaled by 2.5 is declared to produce a 7x7 output, since
    floor(3 * 2.5) = 7.  The ratio recomputed from the shapes, 3 / 7 = 0.4286,
    is not the same as the exact inverse scale 1 / 2.5 = 0.4, so a coordinate
    transformation built from the shapes maps some boundary pixels wrongly.

    Parameters
    ----------
    coord_mode : str
        Value for the node's ``coordinate_transformation_mode`` attribute.
    method : str
        Value for the node's ``mode`` attribute.  ``nearest_mode`` is only
        attached when this is ``"nearest"``.
    """
    # Collect the node attributes first so the optional one is added in one place.
    node_attrs = {"mode": method, "coordinate_transformation_mode": coord_mode}
    if method == "nearest":
        node_attrs["nearest_mode"] = "round_prefer_floor"

    # The empty string leaves the optional second input (roi) unset.
    node = helper.make_node("Resize", ["X", "", "scales"], ["Y"], **node_attrs)

    x_info = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 3, 3])
    scales_init = helper.make_tensor(
        "scales", TensorProto.FLOAT, [4], [1.0, 1.0, 2.5, 2.5]
    )
    y_info = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 7, 7])

    model = helper.make_model(
        helper.make_graph(
            [node],
            "resize_noninteger_2d",
            inputs=[x_info],
            initializer=[scales_init],
            outputs=[y_info],
        )
    )
    check_correctness(model, opset=18)


def test_resize_noninteger_scales_1d():
    """Check a 1D nearest-neighbour Resize with a non-integer scale of 1.5.

    The length-5 input is declared to yield a length-7 output, using
    ``half_pixel`` coordinates and ``round_prefer_floor`` rounding.
    """
    # The empty string leaves the optional second input (roi) unset.
    node = helper.make_node(
        "Resize",
        ["X", "", "scales"],
        ["Y"],
        mode="nearest",
        coordinate_transformation_mode="half_pixel",
        nearest_mode="round_prefer_floor",
    )

    x_info = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 5])
    scales_init = helper.make_tensor("scales", TensorProto.FLOAT, [3], [1.0, 1.0, 1.5])
    y_info = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 7])

    model = helper.make_model(
        helper.make_graph(
            [node],
            "resize_noninteger_1d",
            inputs=[x_info],
            initializer=[scales_init],
            outputs=[y_info],
        )
    )
    check_correctness(model, opset=18)


def test_resize_noninteger_scales_3d():
    """Check a 3D nearest-neighbour Resize with non-integer scales of 1.5.

    The 3x3x3 input is declared to yield a 4x4x4 output, using ``asymmetric``
    coordinates and ``floor`` rounding.
    """
    # The empty string leaves the optional second input (roi) unset.
    node = helper.make_node(
        "Resize",
        ["X", "", "scales"],
        ["Y"],
        mode="nearest",
        coordinate_transformation_mode="asymmetric",
        nearest_mode="floor",
    )

    x_info = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 3, 3, 3])
    scales_init = helper.make_tensor(
        "scales", TensorProto.FLOAT, [5], [1.0, 1.0, 1.5, 1.5, 1.5]
    )
    y_info = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 4, 4, 4])

    model = helper.make_model(
        helper.make_graph(
            [node],
            "resize_noninteger_3d",
            inputs=[x_info],
            initializer=[scales_init],
            outputs=[y_info],
        )
    )
    check_correctness(model, opset=18)


@pytest.mark.parametrize(
    "input_shape,scales,output_shape",
    [
        ([1, 1, 4, 4], [1.0, 1.0, 2.0, 2.0], [1, 1, 8, 8]),
        ([1, 1, 3, 3], [1.0, 1.0, 3.0, 3.0], [1, 1, 9, 9]),
    ],
)
def test_resize_integer_scales_regression(input_shape, scales, output_shape):
    """Check nearest-neighbour Resize with whole-number scales.

    Parameters
    ----------
    input_shape : list of int
        Declared shape of the graph input ``X``.
    scales : list of float
        Per-axis values stored in the ``scales`` initializer.
    output_shape : list of int
        Declared shape of the graph output ``Y``.
    """
    # The empty string leaves the optional second input (roi) unset.
    node = helper.make_node(
        "Resize",
        ["X", "", "scales"],
        ["Y"],
        mode="nearest",
        coordinate_transformation_mode="half_pixel",
        nearest_mode="round_prefer_floor",
    )

    x_info = helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)
    # The initializer is one-dimensional with one entry per input axis.
    scales_init = helper.make_tensor("scales", TensorProto.FLOAT, [len(scales)], scales)
    y_info = helper.make_tensor_value_info("Y", TensorProto.FLOAT, output_shape)

    model = helper.make_model(
        helper.make_graph(
            [node],
            "resize_integer_scales",
            inputs=[x_info],
            initializer=[scales_init],
            outputs=[y_info],
        )
    )
    check_correctness(model, opset=18)


def test_einsum():
eqn = "ij->i"
einsum_node = helper.make_node("Einsum", ["x"], ["y"], equation=eqn)
Expand Down
Loading