diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 88fbea7ff0..e8f2335ed3 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -143,6 +143,11 @@ Metrics .. autoclass:: PSNRMetric :members: +`Mean absolute percentage error` +--------------------------------- +.. autoclass:: MAPEMetric + :members: + `Structural similarity index measure` ------------------------------------- .. autoclass:: monai.metrics.regression.SSIMMetric diff --git a/monai/metrics/__init__.py b/monai/metrics/__init__.py index ae20903cfd..78117d9db6 100644 --- a/monai/metrics/__init__.py +++ b/monai/metrics/__init__.py @@ -28,6 +28,7 @@ from .panoptic_quality import PanopticQualityMetric, compute_panoptic_quality from .regression import ( MAEMetric, + MAPEMetric, MSEMetric, MultiScaleSSIMMetric, PSNRMetric, diff --git a/monai/metrics/regression.py b/monai/metrics/regression.py index 317d624e73..ad5fe73888 100644 --- a/monai/metrics/regression.py +++ b/monai/metrics/regression.py @@ -143,6 +143,39 @@ def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor return compute_mean_error_metrics(y_pred, y, func=self.abs_func) +class MAPEMetric(RegressionMetric): + r"""Compute Mean Absolute Percentage Error between two tensors using function: + + .. math:: + \operatorname {MAPE}\left(Y, \hat{Y}\right) =\frac {100}{n}\sum _{i=1}^{n}\left|\frac{y_i-\hat{y_i}}{y_i}\right|. + + More info: https://en.wikipedia.org/wiki/Mean_absolute_percentage_error + + Input `y_pred` is compared with ground truth `y`. + Both `y_pred` and `y` are expected to be real-valued, where `y_pred` is output from a regression model. + Note: to avoid division by zero, the denominator ``|y|`` is clamped to a minimum of `epsilon`. + + Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`. 
+ + Args: + reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values, + available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. + get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans). + epsilon: minimum value the denominator ``|y|`` is clamped to, avoiding division by zero. Defaults to 1e-7. + + """ + + def __init__( + self, reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False, epsilon: float = 1e-7 + ) -> None: + super().__init__(reduction=reduction, get_not_nans=get_not_nans) + self.epsilon = epsilon + + def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return compute_mape_metric(y_pred, y, epsilon=self.epsilon) + + class RMSEMetric(RegressionMetric): r"""Compute Root Mean Squared Error between two tensors using function: @@ -220,6 +253,23 @@ def compute_mean_error_metrics(y_pred: torch.Tensor, y: torch.Tensor, func: Call return torch.mean(flt(func(y - y_pred)), dim=-1, keepdim=True) +def compute_mape_metric(y_pred: torch.Tensor, y: torch.Tensor, epsilon: float = 1e-7) -> torch.Tensor: + """ + Compute Mean Absolute Percentage Error. 
+ + Args: + y_pred: predicted values + y: ground truth values + epsilon: small value to avoid division by zero + + Returns: + MAPE value as percentage + """ + flt = partial(torch.flatten, start_dim=1) + percentage_error = torch.abs(y - y_pred) / torch.clamp(torch.abs(y), min=epsilon) * 100.0 + return torch.mean(flt(percentage_error), dim=-1, keepdim=True) + + class KernelType(StrEnum): GAUSSIAN = "gaussian" UNIFORM = "uniform" diff --git a/tests/metrics/test_compute_regression_metrics.py b/tests/metrics/test_compute_regression_metrics.py index c407ab6ba6..6f9d7b6c42 100644 --- a/tests/metrics/test_compute_regression_metrics.py +++ b/tests/metrics/test_compute_regression_metrics.py @@ -17,7 +17,7 @@ import numpy as np import torch -from monai.metrics import MAEMetric, MSEMetric, PSNRMetric, RMSEMetric +from monai.metrics import MAEMetric, MAPEMetric, MSEMetric, PSNRMetric, RMSEMetric from monai.utils import set_determinism @@ -44,6 +44,11 @@ def psnrmetric_np(max_val, y_pred, y): return np.mean(20 * np.log10(max_val) - 10 * np.log10(mse)) +def mapemetric_np(y_pred, y, epsilon=1e-7): + percentage_error = np.abs(y - y_pred) / np.clip(np.abs(y), a_min=epsilon, a_max=None) * 100.0 + return np.mean(flatten(percentage_error)) + + class TestRegressionMetrics(unittest.TestCase): def test_shape_reduction(self): @@ -51,7 +56,7 @@ def test_shape_reduction(self): device = "cuda" if torch.cuda.is_available() else "cpu" # regression metrics to check - metrics = [MSEMetric, MAEMetric, RMSEMetric, partial(PSNRMetric, max_val=1.0)] + metrics = [MSEMetric, MAEMetric, MAPEMetric, RMSEMetric, partial(PSNRMetric, max_val=1.0)] # define variations in batch/base_dims/spatial_dims batch_dims = [1, 2, 4, 16] @@ -94,8 +99,8 @@ def test_compare_numpy(self): device = "cuda" if torch.cuda.is_available() else "cpu" # regression metrics to check + truth metric function in numpy - metrics = [MSEMetric, MAEMetric, RMSEMetric, partial(PSNRMetric, max_val=1.0)] - metrics_np = [msemetric_np, 
maemetric_np, rmsemetric_np, partial(psnrmetric_np, max_val=1.0)] + metrics = [MSEMetric, MAEMetric, MAPEMetric, RMSEMetric, partial(PSNRMetric, max_val=1.0)] + metrics_np = [msemetric_np, maemetric_np, mapemetric_np, rmsemetric_np, partial(psnrmetric_np, max_val=1.0)] # define variations in batch/base_dims/spatial_dims batch_dims = [1, 2, 4, 16] @@ -117,14 +122,14 @@ def test_compare_numpy(self): out_tensor = mt.aggregate(reduction="mean") out_np = mt_fn_np(y_pred=in_tensor_a.cpu().numpy(), y=in_tensor_b.cpu().numpy()) - np.testing.assert_allclose(out_tensor.cpu().numpy(), out_np, atol=1e-4) + np.testing.assert_allclose(out_tensor.cpu().numpy(), out_np, atol=1e-3, rtol=1e-4) def test_ill_shape(self): set_determinism(seed=123) device = "cuda" if torch.cuda.is_available() else "cpu" # regression metrics to check + truth metric function in numpy - metrics = [MSEMetric, MAEMetric, RMSEMetric, partial(PSNRMetric, max_val=1.0)] + metrics = [MSEMetric, MAEMetric, MAPEMetric, RMSEMetric, partial(PSNRMetric, max_val=1.0)] basedim = 10 # too small shape @@ -143,8 +148,8 @@ def test_ill_shape(self): def test_same_input(self): set_determinism(seed=123) device = "cuda" if torch.cuda.is_available() else "cpu" - metrics = [MSEMetric, MAEMetric, RMSEMetric, partial(PSNRMetric, max_val=1.0)] - results = [0.0, 0.0, 0.0, float("inf")] + metrics = [MSEMetric, MAEMetric, MAPEMetric, RMSEMetric, partial(PSNRMetric, max_val=1.0)] + results = [0.0, 0.0, 0.0, 0.0, float("inf")] # define variations in batch/base_dims/spatial_dims batch_dims = [1, 2, 4, 16] @@ -168,8 +173,8 @@ def test_same_input(self): def test_diff_input(self): set_determinism(seed=123) device = "cuda" if torch.cuda.is_available() else "cpu" - metrics = [MSEMetric, MAEMetric, RMSEMetric, partial(PSNRMetric, max_val=1.0)] - results = [1.0, 1.0, 1.0, 0.0] + metrics = [MSEMetric, MAEMetric, MAPEMetric, RMSEMetric, partial(PSNRMetric, max_val=1.0)] + results = [1.0, 1.0, 100.0, 1.0, 0.0] # define variations in 
batch/base_dims/spatial_dims batch_dims = [1, 2, 4, 16]