Skip to content

Commit 98f9681

Browse files
committed
Handle the device passed as string
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
1 parent 3728812 commit 98f9681

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

transformer_engine/pytorch/quantized_tensor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,14 +274,16 @@ def make_empty(
274274
shape: Iterable[int],
275275
*,
276276
dtype: torch.dtype = torch.float32,
277-
device: Optional[torch.device] = None,
277+
device: Optional[Union[torch.device, str]] = None,
278278
requires_grad: bool = False,
279279
pin_memory: bool = False,
280280
) -> QuantizedTensor:
281281
"""Construct quantized tensor with uninitialized data"""
282282

283283
if device is None:
284284
device = torch.device("cuda")
285+
# Handle the device passed as string
286+
device = torch.device(device)
285287
result = tex.create_empty_quantized_tensor(
286288
self,
287289
list(shape),

0 commit comments

Comments
 (0)