Skip to content

Commit 380843e

Browse files
committed
Allow string deviceIds in is_cpu_device()
1 parent 62cea15 commit 380843e

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

sdkit/utils/device_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,10 @@ def ipc_collect():
7878

7979

8080
def is_cpu_device(device) -> bool: # used for cpu offloading etc
81-
"Expects a torch.device as the argument"
81+
"Expects a torch.device or string as the argument"
8282

83-
return device.type in ("cpu", "mps")
83+
device_type = device.split(":")[0] if isinstance(device, str) else device.type
84+
return device_type in ("cpu", "mps")
8485

8586

8687
def has_half_precision_bug(device_name) -> bool:

0 commit comments

Comments
 (0)