Skip to content

Commit 9294aae

Browse files
Copilotvfdev-5
andauthored
Update loss.py type hints to Python 3.10+ syntax (#3501)
Modernize type hints in `ignite/metrics/loss.py` to use PEP 604 union syntax and built-in generics, aligning with Python 3.10+ (now minimum supported version). ## Changes - `Union[str, torch.device]` → `str | torch.device` - `Tuple[...]` → `tuple[...]` - `Dict` → `dict` - Improved `update()` signature from `Sequence[Union[torch.Tensor, Dict]]` to `tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, dict]` for better type safety - Removed obsolete typing imports ```python # Before def update(self, output: Sequence[Union[torch.Tensor, Dict]]) -> None: if len(output) == 2: y_pred, y = cast(Tuple[torch.Tensor, torch.Tensor], output) kwargs: Dict = {} # After def update(self, output: tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, dict]) -> None: if len(output) == 2: y_pred, y = cast(tuple[torch.Tensor, torch.Tensor], output) kwargs: dict = {} ``` Part of ongoing effort to modernize typing across the codebase. <!-- START COPILOT CODING AGENT SUFFIX --> <!-- START COPILOT ORIGINAL PROMPT --> <details> <summary>Original prompt</summary> > Work on this issue: #3481 > Pick a python file from ignite/metrics with old typing hints and replace them with the new typing hints. Send a PR with a single changed file </details> <!-- START COPILOT CODING AGENT TIPS --> --- ✨ Let Copilot coding agent [set things up for you](https://github.com/pytorch/ignite/issues/new?title=✨+Set+up+Copilot+instructions&body=Configure%20instructions%20for%20this%20repository%20as%20documented%20in%20%5BBest%20practices%20for%20Copilot%20coding%20agent%20in%20your%20repository%5D%28https://gh.io/copilot-coding-agent-tips%29%2E%0A%0A%3COnboard%20this%20repo%3E&assignees=copilot) — coding agent works faster and does higher quality work when set up for your repo. --------- Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: vfdev-5 <[email protected]>
1 parent c85dce5 commit 9294aae

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

ignite/metrics/loss.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Callable, cast, Dict, Sequence, Tuple, Union
1+
from typing import Callable, cast
22

33
import torch
44

@@ -77,7 +77,7 @@ def __init__(
7777
loss_fn: Callable,
7878
output_transform: Callable = lambda x: x,
7979
batch_size: Callable = len,
80-
device: Union[str, torch.device] = torch.device("cpu"),
80+
device: str | torch.device = torch.device("cpu"),
8181
skip_unrolling: bool = False,
8282
):
8383
super(Loss, self).__init__(output_transform, device=device, skip_unrolling=skip_unrolling)
@@ -90,12 +90,12 @@ def reset(self) -> None:
9090
self._num_examples = 0
9191

9292
@reinit__is_reduced
93-
def update(self, output: Sequence[Union[torch.Tensor, Dict]]) -> None:
93+
def update(self, output: tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, dict]) -> None:
9494
if len(output) == 2:
95-
y_pred, y = cast(Tuple[torch.Tensor, torch.Tensor], output)
96-
kwargs: Dict = {}
95+
y_pred, y = cast(tuple[torch.Tensor, torch.Tensor], output)
96+
kwargs: dict = {}
9797
else:
98-
y_pred, y, kwargs = cast(Tuple[torch.Tensor, torch.Tensor, Dict], output)
98+
y_pred, y, kwargs = cast(tuple[torch.Tensor, torch.Tensor, dict], output)
9999
average_loss = self._loss_fn(y_pred, y, **kwargs).detach()
100100

101101
if len(average_loss.shape) != 0:

0 commit comments

Comments
 (0)