Skip to content

Commit ac7e997

Browse files
authored
BUG: torch.arange: workaround for missing dtype implementations (#405)
* BUG: torch.arange: workaround for missing dtype implementations reviewed at #405
1 parent a88067a commit ac7e997

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -616,8 +616,12 @@ def arange(start: float,
616616
dtype = torch.int64
617617
else:
618618
dtype = torch.float32
619-
return torch.empty(0, dtype=dtype, device=device, **kwargs)
620-
return torch.arange(start, stop, step, dtype=dtype, device=device, **kwargs)
619+
return torch.empty(0, device=device, **kwargs).to(dtype)
620+
try:
621+
return torch.arange(start, stop, step, dtype=dtype, device=device, **kwargs)
622+
# torch 2.7 raises RuntimeError, 2.9 emits NotImplementedError
623+
except (NotImplementedError, RuntimeError):
624+
return torch.arange(start, stop, step, device=device, **kwargs).to(dtype)
621625

622626
# torch.eye does not accept None as a default for the second argument and
623627
# doesn't support off-diagonals (https://github.com/pytorch/pytorch/issues/70910)

0 commit comments

Comments
 (0)