From 17f63f45200ed6a21ea759d06c77377f5f004181 Mon Sep 17 00:00:00 2001 From: cchung100m Date: Sat, 13 Jun 2026 14:09:48 +0800 Subject: [PATCH 1/4] [Relax][ONNX] Preserve NaN in ReduceMax to align with ONNX Runtime --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 25 ++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 3d9dfba9a16b..72e771017f17 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -3888,6 +3888,25 @@ def _impl_v23(cls, bb, inputs, attr, params): return output +def _reduce_min_max_preserve_nan(reduce_op, data, axes, keepdims): + """Apply a min/max reduction while preserving NaN to match ONNX Runtime. + + relax.op.max/min legalize to a max/min fold implemented as select(x > y, x, y) with an + ordered float comparison, so NaN propagation depends on the fold position (a later non-NaN + element silently overwrites an earlier NaN). ONNX Runtime instead yields NaN whenever any + reduced element is NaN. For floating point inputs, detect NaN along the reduced axes and + force the output to NaN to match that semantics. 
+ """ + y = reduce_op(data, axes, keepdims) + dtype = data.struct_info.dtype if isinstance(data.struct_info, relax.TensorStructInfo) else None + if dtype is None or not _relax_dtype_is_floating_point(dtype): + return y + nan_count = relax.op.sum(relax.op.astype(relax.op.isnan(data), dtype), axes, keepdims) + has_nan = relax.op.greater(nan_count, relax.const(0, dtype)) + nan_filled = relax.op.full_like(y, relax.const(float("nan"), dtype)) + return relax.op.where(has_nan, nan_filled, y) + + class ReduceMax(OnnxOpConverter): """Converts an onnx ReduceMax node into an equivalent Relax expression.""" @@ -3896,7 +3915,7 @@ def _impl_v11(cls, bb, inputs, attr, params): data = inputs[0] axes = attr.get("axes", None) keepdims = attr.get("keepdims", 1) - return relax.op.max(data, axes, keepdims) + return _reduce_min_max_preserve_nan(relax.op.max, data, axes, keepdims) @classmethod def _impl_v18(cls, bb, inputs, attr, params): @@ -3913,13 +3932,13 @@ def _impl_v18(cls, bb, inputs, attr, params): # If axes is empty and noop_with_empty_axes is False, reduce all dims if not axes and not noop_with_empty_axes: - return relax.op.max(data, None, keepdims) + return _reduce_min_max_preserve_nan(relax.op.max, data, None, keepdims) # If axes is empty and noop_with_empty_axes is True, return input unchanged elif not axes and noop_with_empty_axes: return data # Otherwise reduce over specified axes else: - return relax.op.max(data, axes, keepdims) + return _reduce_min_max_preserve_nan(relax.op.max, data, axes, keepdims) class ReduceMin(OnnxOpConverter): From 39d508a5c868ba6a2ea8bcb0aa24812e15101104 Mon Sep 17 00:00:00 2001 From: cchung100m Date: Sat, 13 Jun 2026 14:11:12 +0800 Subject: [PATCH 2/4] [Relax][ONNX] Preserve NaN in ReduceMin to align with ONNX Runtime --- python/tvm/relax/frontend/onnx/onnx_frontend.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 
72e771017f17..6a95acc97981 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -3949,7 +3949,7 @@ def _impl_v11(cls, bb, inputs, attr, params): data = inputs[0] axes = attr.get("axes", None) keepdims = attr.get("keepdims", 1) - return relax.op.min(data, axes, keepdims) + return _reduce_min_max_preserve_nan(relax.op.min, data, axes, keepdims) @classmethod def _impl_v18(cls, bb, inputs, attr, params): @@ -3966,13 +3966,13 @@ def _impl_v18(cls, bb, inputs, attr, params): # If axes is empty and noop_with_empty_axes is False, reduce all dims if not axes and not noop_with_empty_axes: - return relax.op.min(data, None, keepdims) + return _reduce_min_max_preserve_nan(relax.op.min, data, None, keepdims) # If axes is empty and noop_with_empty_axes is True, return input unchanged elif not axes and noop_with_empty_axes: return data # Otherwise reduce over specified axes else: - return relax.op.min(data, axes, keepdims) + return _reduce_min_max_preserve_nan(relax.op.min, data, axes, keepdims) class ReduceSum(OnnxOpConverter): From 18d2e6e532b2cc2b79688efe5a4bc3f83e517fd2 Mon Sep 17 00:00:00 2001 From: cchung100m Date: Sat, 13 Jun 2026 15:59:30 +0800 Subject: [PATCH 3/4] Add test case: test_reduce_min_max_nan_preserve --- tests/python/relax/test_frontend_onnx.py | 38 ++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 8d8c1bc54b9b..b6fd5719df3a 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -804,6 +804,44 @@ def test_sign_nan_preserve(): ) +@pytest.mark.parametrize("op_name", ["ReduceMax", "ReduceMin"]) +@pytest.mark.parametrize( + "x", + [ + np.array([np.nan, 1.0, 2.0], dtype=np.float32), + np.array([2.0, 1.0, np.nan], dtype=np.float32), + np.array([1.0, np.nan, 2.0], dtype=np.float32), + np.array([1.0, 2.0, 3.0], dtype=np.float32), + ], +) +def 
test_reduce_min_max_nan_preserve(op_name, x): + reduce_node = helper.make_node(op_name, ["x"], ["y"], keepdims=0) + graph = helper.make_graph( + [reduce_node], + "reduce_nan_test", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x.shape))], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [])], + ) + model = helper.make_model(graph, producer_name="reduce_nan_test") + model.ir_version = 8 + for opset_import in model.opset_import: + if opset_import.domain in ["", "ai.onnx"]: + opset_import.version = 18 + break + + ort_out = onnxruntime.InferenceSession( + model.SerializeToString(), providers=["CPUExecutionProvider"] + ).run([], {"x": x})[0] + + tvm_out = run_in_tvm(model, inputs={"x": x}, opset=18) + out_np = (tvm_out[0] if isinstance(tvm_out, (list, tuple)) else tvm_out).numpy() + + np.testing.assert_array_equal(np.isnan(out_np), np.isnan(ort_out)) + np.testing.assert_allclose( + out_np[~np.isnan(ort_out)], ort_out[~np.isnan(ort_out)], rtol=1e-7, atol=1e-5 + ) + + @pytest.mark.parametrize("op_name", ["Softmax", "LogSoftmax", "Hardmax"]) def test_softmax_family_opset11_default_axis_semantics(op_name: str): verify_unary(op_name, [2, 3, 4], opset=11) From 5e5eb872efcaf95da09b3757f43368db3ee7972f Mon Sep 17 00:00:00 2001 From: cchung100m Date: Sat, 13 Jun 2026 20:15:48 +0800 Subject: [PATCH 4/4] Adopt numpy/IEEE convention instead of ONNX Runtime behavior --- python/tvm/relax/frontend/onnx/onnx_frontend.py | 11 +++++++---- tests/python/relax/test_frontend_onnx.py | 16 +++++++++------- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 6a95acc97981..d9fe6efa7211 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -3889,13 +3889,16 @@ def _impl_v23(cls, bb, inputs, attr, params): def _reduce_min_max_preserve_nan(reduce_op, data, axes, keepdims): - """Apply 
a min/max reduction while preserving NaN to match ONNX Runtime. + """Apply a min/max reduction with well-defined, order-independent NaN propagation. relax.op.max/min legalize to a max/min fold implemented as select(x > y, x, y) with an ordered float comparison, so NaN propagation depends on the fold position (a later non-NaN - element silently overwrites an earlier NaN). ONNX Runtime instead yields NaN whenever any - reduced element is NaN. For floating point inputs, detect NaN along the reduced axes and - force the output to NaN to match that semantics. + element silently overwrites an earlier NaN). ONNX Runtime is also order-dependent (it only + yields NaN when the first reduced element is NaN), which is an implementation artifact rather + than a defined semantics and is impractical to replicate portably. We instead adopt the + numpy/IEEE convention used by numpy.max/min and torch.amax/amin: for floating point inputs, + detect NaN along the reduced axes and force the output to NaN whenever any reduced element is + NaN. """ y = reduce_op(data, axes, keepdims) dtype = data.struct_info.dtype if isinstance(data.struct_info, relax.TensorStructInfo) else None diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index b6fd5719df3a..8e4d66e8e821 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -808,6 +808,10 @@ def test_sign_nan_preserve(): @pytest.mark.parametrize( "x", [ + # NaN in different positions. TVM's max/min fold previously dropped NaN depending on + # position; ONNX Runtime only propagates NaN when it is the first reduced element, which + # is an order-dependent implementation artifact. We instead adopt the well-defined, + # order-independent numpy/IEEE semantics: any NaN in the reduced range yields NaN. 
np.array([np.nan, 1.0, 2.0], dtype=np.float32), np.array([2.0, 1.0, np.nan], dtype=np.float32), np.array([1.0, np.nan, 2.0], dtype=np.float32), @@ -829,17 +833,15 @@ def test_reduce_min_max_nan_preserve(op_name, x): opset_import.version = 18 break - ort_out = onnxruntime.InferenceSession( - model.SerializeToString(), providers=["CPUExecutionProvider"] - ).run([], {"x": x})[0] + # Reference is numpy (NaN propagates if any element is NaN), not ONNX Runtime. + ref_out = (np.max if op_name == "ReduceMax" else np.min)(x) tvm_out = run_in_tvm(model, inputs={"x": x}, opset=18) out_np = (tvm_out[0] if isinstance(tvm_out, (list, tuple)) else tvm_out).numpy() - np.testing.assert_array_equal(np.isnan(out_np), np.isnan(ort_out)) - np.testing.assert_allclose( - out_np[~np.isnan(ort_out)], ort_out[~np.isnan(ort_out)], rtol=1e-7, atol=1e-5 - ) + np.testing.assert_array_equal(np.isnan(out_np), np.isnan(ref_out)) + if not np.isnan(ref_out): + np.testing.assert_allclose(out_np, ref_out, rtol=1e-7, atol=1e-5) @pytest.mark.parametrize("op_name", ["Softmax", "LogSoftmax", "Hardmax"])