We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 9ae5d33 commit 6be430aCopy full SHA for 6be430a
transformer_engine/pytorch/quantized_tensor.py
@@ -286,14 +286,16 @@ def make_empty(
286
shape: Iterable[int],
287
*,
288
dtype: torch.dtype = torch.float32,
289
- device: Optional[torch.device] = None,
+ device: Optional[Union[torch.device, str]] = None,
290
requires_grad: bool = False,
291
pin_memory: bool = False,
292
) -> QuantizedTensor:
293
"""Construct quantized tensor with uninitialized data"""
294
295
if device is None:
296
device = torch.device("cuda")
297
+ # Handle the device passed as string
298
+ device = torch.device(device)
299
result = tex.create_empty_quantized_tensor(
300
self,
301
list(shape),
0 commit comments