We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 62cea15 commit 380843eCopy full SHA for 380843e
sdkit/utils/device_utils.py
@@ -78,9 +78,10 @@ def ipc_collect():
78
79
80
def is_cpu_device(device) -> bool: # used for cpu offloading etc
81
- "Expects a torch.device as the argument"
+ "Expects a torch.device or string as the argument"
82
83
- return device.type in ("cpu", "mps")
+ device_type = device.split(":")[0] if isinstance(device, str) else device.type
84
+ return device_type in ("cpu", "mps")
85
86
87
def has_half_precision_bug(device_name) -> bool:
0 commit comments