Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3888,6 +3888,28 @@ def _impl_v23(cls, bb, inputs, attr, params):
return output


def _reduce_min_max_preserve_nan(reduce_op, data, axes, keepdims):
"""Apply a min/max reduction with well-defined, order-independent NaN propagation.

relax.op.max/min legalize to a max/min fold implemented as select(x > y, x, y) with an
ordered float comparison, so NaN propagation depends on the fold position (a later non-NaN
element silently overwrites an earlier NaN). ONNX Runtime is also order-independent (it only
yields NaN when the first reduced element is NaN), which is an implementation artifact rather
than a defined semantics and is impractical to replicate portably. We instead adopt the
numpy/IEEE convention used by numpy.max/min and torch.amax/amin: for floating pint inputs,
detect NaN along the reduced axes and force the output to NaN whenever any reduced element is
NaN.
"""
y = reduce_op(data, axes, keepdims)
dtype = data.struct_info.dtype if isinstance(data.struct_info, relax.TensorStructInfo) else None
if dtype is None or not _relax_dtype_is_floating_point(dtype):
return y
nan_count = relax.op.sum(relax.op.astype(relax.op.isnan(data), dtype), axes, keepdims)
has_nan = relax.op.greater(nan_count, relax.const(0, dtype))
nan_filled = relax.op.full_like(y, relax.const(float("nan"), dtype))
return relax.op.where(has_nan, nan_filled, y)


class ReduceMax(OnnxOpConverter):
"""Converts an onnx ReduceMax node into an equivalent Relax expression."""

Expand All @@ -3896,7 +3918,7 @@ def _impl_v11(cls, bb, inputs, attr, params):
data = inputs[0]
axes = attr.get("axes", None)
keepdims = attr.get("keepdims", 1)
return relax.op.max(data, axes, keepdims)
return _reduce_min_max_preserve_nan(relax.op.max, data, axes, keepdims)

@classmethod
def _impl_v18(cls, bb, inputs, attr, params):
Expand All @@ -3913,13 +3935,13 @@ def _impl_v18(cls, bb, inputs, attr, params):

# If axes is empty and noop_with_empty_axes is False, reduce all dims
if not axes and not noop_with_empty_axes:
return relax.op.max(data, None, keepdims)
return _reduce_min_max_preserve_nan(relax.op.max, data, None, keepdims)
# If axes is empty and noop_with_empty_axes is True, return input unchanged
elif not axes and noop_with_empty_axes:
return data
# Otherwise reduce over specified axes
else:
return relax.op.max(data, axes, keepdims)
return _reduce_min_max_preserve_nan(relax.op.max, data, axes, keepdims)


class ReduceMin(OnnxOpConverter):
Expand All @@ -3930,7 +3952,7 @@ def _impl_v11(cls, bb, inputs, attr, params):
data = inputs[0]
axes = attr.get("axes", None)
keepdims = attr.get("keepdims", 1)
return relax.op.min(data, axes, keepdims)
return _reduce_min_max_preserve_nan(relax.op.min, data, axes, keepdims)

@classmethod
def _impl_v18(cls, bb, inputs, attr, params):
Expand All @@ -3947,13 +3969,13 @@ def _impl_v18(cls, bb, inputs, attr, params):

# If axes is empty and noop_with_empty_axes is False, reduce all dims
if not axes and not noop_with_empty_axes:
return relax.op.min(data, None, keepdims)
return _reduce_min_max_preserve_nan(relax.op.min, data, None, keepdims)
# If axes is empty and noop_with_empty_axes is True, return input unchanged
elif not axes and noop_with_empty_axes:
return data
# Otherwise reduce over specified axes
else:
return relax.op.min(data, axes, keepdims)
return _reduce_min_max_preserve_nan(relax.op.min, data, axes, keepdims)


class ReduceSum(OnnxOpConverter):
Expand Down
40 changes: 40 additions & 0 deletions tests/python/relax/test_frontend_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,46 @@ def test_sign_nan_preserve():
)


@pytest.mark.parametrize("op_name", ["ReduceMax", "ReduceMin"])
@pytest.mark.parametrize(
"x",
[
# NaN in different positions. TVM's max/min fold previously dropped NaN depending on
# position, ONNX Runtime only propagates NaN when it is the first reduced element, which
# is an order-dependent implementation artifact. We instead adopt the well-defined,
# order-independent numpy/IEEE semantics: any NaN in the reduced range yields NaN.
np.array([np.nan, 1.0, 2.0], dtype=np.float32),
np.array([2.0, 1.0, np.nan], dtype=np.float32),
np.array([1.0, np.nan, 2.0], dtype=np.float32),
np.array([1.0, 2.0, 3.0], dtype=np.float32),
],
)
def test_reduce_min_max_nan_preserve(op_name, x):
reduce_node = helper.make_node(op_name, ["x"], ["y"], keepdims=0)
graph = helper.make_graph(
[reduce_node],
"reduce_nan_test",
inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x.shape))],
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [])],
)
model = helper.make_model(graph, producer_name="reduce_nan_test")
model.ir_version = 8
for opset_import in model.opset_import:
if opset_import.domain in ["", "ai.onnx"]:
opset_import.version = 18
break

# Reference is numpy (NaN propagates if any element is NaN), not ONNX Runtime.
ref_out = (np.max if op_name == "ReduceMax" else np.min)(x)

tvm_out = run_in_tvm(model, inputs={"x": x}, opset=18)
out_np = (tvm_out[0] if isinstance(tvm_out, (list, tuple)) else tvm_out).numpy()

np.testing.assert_array_equal(np.isnan(out_np), np.isnan(ref_out))
if not np.isnan(ref_out):
np.testing.assert_allclose(out_np, ref_out, rtol=1e-7, atol=1e-5)


@pytest.mark.parametrize("op_name", ["Softmax", "LogSoftmax", "Hardmax"])
def test_softmax_family_opset11_default_axis_semantics(op_name: str):
verify_unary(op_name, [2, 3, 4], opset=11)
Expand Down
Loading