Skip to content

Commit 6be430a

Browse files
committed
Handle the device passed as string
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
1 parent 9ae5d33 commit 6be430a

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

transformer_engine/pytorch/quantized_tensor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,14 +286,16 @@ def make_empty(
286286
shape: Iterable[int],
287287
*,
288288
dtype: torch.dtype = torch.float32,
289-
device: Optional[torch.device] = None,
289+
device: Optional[Union[torch.device, str]] = None,
290290
requires_grad: bool = False,
291291
pin_memory: bool = False,
292292
) -> QuantizedTensor:
293293
"""Construct quantized tensor with uninitialized data"""
294294

295295
if device is None:
296296
device = torch.device("cuda")
297+
# Handle the device passed as string
298+
device = torch.device(device)
297299
result = tex.create_empty_quantized_tensor(
298300
self,
299301
list(shape),

0 commit comments

Comments
 (0)