Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions python/paddle/nn/functional/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@
from paddle import _C_ops, base, in_dynamic_mode
from paddle.static.nn.control_flow import Assert
from paddle.utils import deprecated
from paddle.utils.decorator_utils import param_one_alias, param_two_alias
from paddle.utils.decorator_utils import (
legacy_reduction_decorator,
param_one_alias,
param_two_alias,
)

from ...base.data_feeder import check_type, check_variable_and_dtype
from ...base.framework import (
Expand Down Expand Up @@ -610,7 +614,7 @@ def edit_distance(
return edit_distance_out, sequence_num


@param_one_alias(["label", "target"])
@legacy_reduction_decorator
def binary_cross_entropy(
input: Tensor,
label: Tensor,
Expand Down Expand Up @@ -682,8 +686,8 @@ def binary_cross_entropy(
>>> print(output)
Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
0.65537095)

"""

if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"The value of 'reduction' in binary_cross_entropy should be 'sum', "
Expand Down Expand Up @@ -1475,7 +1479,7 @@ def l1_loss(
return paddle.abs(paddle.subtract(x=input, y=label, name=name))


@param_one_alias(["label", "target"])
@legacy_reduction_decorator
def nll_loss(
input: Tensor,
label: Tensor,
Expand Down Expand Up @@ -1536,8 +1540,8 @@ def nll_loss(
>>> print(result)
Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
1.07202101)

"""

if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"The value of 'reduction' in nll_loss should be 'sum', 'mean' or "
Expand Down Expand Up @@ -1875,7 +1879,7 @@ def kl_div(
return loss


@param_one_alias(["label", "target"])
@legacy_reduction_decorator
def mse_loss(
input: Tensor,
label: Tensor,
Expand Down Expand Up @@ -1928,7 +1932,6 @@ def mse_loss(
>>> print(output)
Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
0.04000002)

"""

if reduction not in ['sum', 'mean', 'none']:
Expand Down Expand Up @@ -4179,7 +4182,7 @@ def triplet_margin_loss(
return loss


@param_one_alias(["label", "target"])
@legacy_reduction_decorator
def multi_margin_loss(
input: Tensor,
label: Tensor,
Expand Down Expand Up @@ -4222,7 +4225,6 @@ def multi_margin_loss(
If given, has to be a Tensor of shape (C,) and the data type is float32, float64.
Default is ``'None'`` .


reduction (str, optional):Indicate how to calculate the loss by batch_size.
the candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
If :attr:`reduction` is ``'none'``, the unreduced loss is returned;
Expand All @@ -4248,8 +4250,8 @@ def multi_margin_loss(
>>> print(loss)
Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True,
[0. , 0.66666663, 0. ])

"""

if reduction not in ['sum', 'mean', 'none']:
raise ValueError(
"'reduction' in 'multi_margin_loss' should be 'sum', 'mean' or 'none', "
Expand Down
206 changes: 148 additions & 58 deletions python/paddle/utils/decorator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,8 +683,13 @@ def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
_SA0_RD1 = {'size_average': 0, 'reduce': 1}
_SA1_RD2 = {'size_average': 1, 'reduce': 2}
_SA1_RD3 = {'size_average': 1, 'reduce': 3}
_SA2_RD3 = {'size_average': 2, 'reduce': 3}
_SA3_RD4 = {'size_average': 3, 'reduce': 4}
_SA3_RD5 = {'size_average': 3, 'reduce': 5}
_SA4_RD5 = {'size_average': 4, 'reduce': 5}
_SA4_RD6 = {'size_average': 4, 'reduce': 6}
_SA5_RD6 = {'size_average': 5, 'reduce': 6}
_SA7_RD8 = {'size_average': 7, 'reduce': 8}
_SA2_RD4 = {'size_average': 2, 'reduce': 4}

LEGACY_POS: dict[str, dict[str, int]] = {
Expand Down Expand Up @@ -715,6 +720,22 @@ def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
'PoissonNLLLoss': _SA2_RD4,
'MultiMarginLoss': _SA3_RD4,
'TripletMarginLoss': _SA4_RD5,
'binary_cross_entropy': _SA3_RD4,
'mse_loss': _SA2_RD3,
'nll_loss': _SA3_RD5,
'multi_margin_loss': _SA5_RD6,
}

# Per-API count of leading positional arguments to drop before inspecting
# positional slots for a 'reduction' string (consumed by
# _should_skip_positional_reduce via LEGACY_SPECIAL_TRIM_PREFIX.get(api, 0)).
LEGACY_SPECIAL_TRIM_PREFIX: dict[str, int] = {
    'CrossEntropyLoss': 0,
    'KLDivLoss': 0,
}

# Keyword alias table: for each functional API, maps an accepted alias
# keyword (e.g. 'target') to the canonical parameter name ('label').
# Consumed by _convert_legacy_alias, which renames kwargs in place.
LEGACY_ALIAS: dict[str, dict[str, str]] = {
    'binary_cross_entropy': {'target': 'label'},
    'mse_loss': {'target': 'label'},
    'multi_margin_loss': {'target': 'label'},
    'nll_loss': {'target': 'label'},
}


Expand All @@ -726,61 +747,154 @@ def compute_legacy_reduction(reduce_val, size_average_val):
return 'sum' if size_average_val is False else 'mean'


def get_legacy_reduce_and_size_average(cls_name, args, kwargs):
def warn_legacy_reduction(api_name, reduce_val, size_avg_val):
    """Warn that 'size_average'/'reduce' are deprecated and return the
    equivalent ``reduction`` string.

    The suggested mode is derived from the legacy pair via
    ``compute_legacy_reduction``; ``api_name`` is accepted for interface
    symmetry with the other legacy helpers.
    """
    reduction_mode = compute_legacy_reduction(reduce_val, size_avg_val)
    message = (
        "size_average and reduce args will be deprecated, please use "
        f"reduction='{reduction_mode}' instead."
    )
    warnings.warn(message, category=Warning, stacklevel=2)
    return reduction_mode


def _is_legacy_reduction_value(value):
return value is None or type(value) is bool


def _convert_legacy_alias(api_name, kwargs):
    """Rename alias keywords in *kwargs* to their canonical names, in place.

    Uses the module-level ``LEGACY_ALIAS`` table for ``api_name``; a no-op
    when the API has no aliases or ``kwargs`` is empty.  Raises
    ``ValueError`` if both the alias and the canonical name were supplied.
    """
    alias_map = LEGACY_ALIAS.get(api_name)
    if alias_map is None:
        return
    if not kwargs:
        return

    for alias, original in alias_map.items():
        if alias not in kwargs:
            continue
        if original in kwargs:
            raise ValueError(
                f"Cannot specify both '{original}' and its alias '{alias}'"
            )
        kwargs[original] = kwargs.pop(alias)


def _should_skip_positional_reduce(api_name, args):
    """Return True when the positional slot normally holding the legacy
    'reduce' flag actually carries a 'reduction' string for a special API.

    The args are first trimmed by the API's prefix length from
    ``LEGACY_SPECIAL_TRIM_PREFIX``; only the CrossEntropy and KLDiv
    families have such a special positional layout.
    """
    trimmed = args[LEGACY_SPECIAL_TRIM_PREFIX.get(api_name, 0):]
    # Pick the positional index and the accepted reduction strings for the
    # API family; any other API never skips.
    if api_name in ('CrossEntropyLoss', 'cross_entropy'):
        slot, modes = 2, {'mean', 'sum', 'none'}
    elif api_name in ('KLDivLoss', 'kl_div'):
        slot, modes = 0, {'mean', 'sum', 'none', 'batchmean'}
    else:
        return False
    return len(trimmed) > slot and trimmed[slot] in modes


def get_legacy_reduce_and_size_average(api_name, args, kwargs, special=False):
reduce_val = ''
size_avg_val = ''
pos = LEGACY_POS.get(cls_name)
idx = pos.get('size_average')
pos = LEGACY_POS.get(api_name)
if pos is None:
return reduce_val, size_avg_val, args

new_args = list(args)
pop_indices = []

size_avg_idx = pos.get('size_average')
if 'size_average' in kwargs:
size_avg_val = kwargs.pop('size_average')
elif len(args) > idx:
v = args[idx]
if type(v) is bool:
size_avg_val = v
idx = pos.get('reduce')
if not _is_legacy_reduction_value(size_avg_val):
raise TypeError(
f"{api_name}() expected 'size_average' to be bool or None, "
f"but got {type(size_avg_val).__name__}."
)
elif (
len(new_args) > size_avg_idx
and _is_legacy_reduction_value(new_args[size_avg_idx])
):
size_avg_val = new_args[size_avg_idx]
pop_indices.append(size_avg_idx)

reduce_idx = pos.get('reduce')
if 'reduce' in kwargs:
reduce_val = kwargs.pop('reduce')
elif len(args) > idx:
v = args[idx]
if type(v) is bool:
reduce_val = v
return reduce_val, size_avg_val
if not _is_legacy_reduction_value(reduce_val):
raise TypeError(
f"{api_name}() expected 'reduce' to be bool or None, but "
f"got {type(reduce_val).__name__}."
)
elif (
len(new_args) > reduce_idx
and (not special or not _should_skip_positional_reduce(api_name, new_args))
and _is_legacy_reduction_value(new_args[reduce_idx])
):
reduce_val = new_args[reduce_idx]
pop_indices.append(reduce_idx)

for idx in sorted(pop_indices, reverse=True):
del new_args[idx]

def raise_deprecated_error(cls_name, reduce_val, size_avg_val):
suggested = compute_legacy_reduction(reduce_val, size_avg_val)
reduce_val = None if reduce_val == '' else reduce_val
size_avg_val = None if size_avg_val == '' else size_avg_val
raise ValueError(
f"[Deprecated] '{cls_name}' no longer supports 'reduce' or 'size_average'."
f"\nDetected: reduce={reduce_val}, size_average={size_avg_val}"
f"\nPlease use: reduction='{suggested}' instead."
)
return reduce_val, size_avg_val, tuple(new_args)


def legacy_reduction_decorator(
init_func: Callable[_InputT, _RetT],
def _legacy_reduction_decorator_impl(
init_func: Callable[_InputT, _RetT], special: bool
) -> Callable[_InputT, _RetT]:
"""
Function decorator for __init__: intercept deprecated 'reduce' and 'size_average'.
"""
sig = inspect.signature(init_func)
param_names = list(sig.parameters)
is_method = len(param_names) > 0 and param_names[0] == 'self'

@functools.wraps(init_func)
def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
Copy link
Copy Markdown
Contributor

@zhwesky2010 zhwesky2010 Apr 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不要改这么多,人工开发找到最简最优的方案写法

直接把这里的raise_deprecated_error展开,并设置kwarg['reduce']=suggested 就可以了吧?

Image

# avoid subclass calling parent class init, causing cls_name to be inaccurate
cls_name = init_func.__qualname__.split(".")[0]
reduce_val, size_avg_val = get_legacy_reduce_and_size_average(
cls_name, args[1:], kwargs
api_name = init_func.__qualname__.split(".")[0]
_convert_legacy_alias(api_name, kwargs)

use_args = args[1:] if is_method else args
reduce_val, size_avg_val, use_args = get_legacy_reduce_and_size_average(
api_name, use_args, kwargs, special
)
if reduce_val != '' or size_avg_val != '':
raise_deprecated_error(cls_name, reduce_val, size_avg_val)
bound_kwargs = {}
public_param_names = param_names[1:] if is_method else param_names
if len(use_args) > len(public_param_names):
raise TypeError(
f"{api_name}() takes {len(public_param_names)} positional "
f"arguments but {len(use_args)} were given"
)
for name, value in zip(public_param_names, use_args):
bound_kwargs[name] = value
for name, value in kwargs.items():
if name in bound_kwargs:
raise TypeError(
f"{api_name}() got multiple values for argument '{name}'"
)
bound_kwargs[name] = value
bound_kwargs['reduction'] = warn_legacy_reduction(
api_name, reduce_val, size_avg_val
)
if is_method:
return init_func(args[0], **bound_kwargs)
return init_func(**bound_kwargs)

return init_func(*args, **kwargs)

wrapper.__signature__ = inspect.signature(init_func)
wrapper.__signature__ = sig
return wrapper


def legacy_reduction_decorator(
    init_func: Callable[_InputT, _RetT],
) -> Callable[_InputT, _RetT]:
    """
    Function decorator for __init__: intercept deprecated 'reduce' and
    'size_average' and map them onto 'reduction'.

    Delegates to the shared implementation with ``special=False`` (no
    positional-'reduction' detection).
    """
    decorated = _legacy_reduction_decorator_impl(init_func, special=False)
    return decorated


def legacy_reduction_special_decorator(
init_func: Callable[_InputT, _RetT],
) -> Callable[_InputT, _RetT]:
Expand All @@ -789,31 +903,7 @@ def legacy_reduction_special_decorator(
based on the general legacy_reduction_decorator logic.
"""

@functools.wraps(init_func)
def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
cls_name = init_func.__qualname__.split(".")[0]
use_args = args[1:]
reduce_val, size_avg_val = get_legacy_reduce_and_size_average(
cls_name, use_args, kwargs
)
if reduce_val != '' or size_avg_val != '':
if not (
(
cls_name == 'CrossEntropyLoss'
and len(use_args) > 2
and use_args[2] in {'mean', 'sum', 'none'}
)
or (
cls_name == 'KLDivLoss'
and len(use_args) > 0
and use_args[0] in {'mean', 'sum', 'none', 'batchmean'}
)
):
raise_deprecated_error(cls_name, reduce_val, size_avg_val)
return init_func(*args, **kwargs)

wrapper.__signature__ = inspect.signature(init_func)
return wrapper
return _legacy_reduction_decorator_impl(init_func, special=True)


def index_add_decorator() -> Callable[
Expand Down
Loading