We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 3728812 commit 98f9681Copy full SHA for 98f9681
transformer_engine/pytorch/quantized_tensor.py
@@ -274,14 +274,16 @@ def make_empty(
274
shape: Iterable[int],
275
*,
276
dtype: torch.dtype = torch.float32,
277
- device: Optional[torch.device] = None,
+ device: Optional[Union[torch.device, str]] = None,
278
requires_grad: bool = False,
279
pin_memory: bool = False,
280
) -> QuantizedTensor:
281
"""Construct quantized tensor with uninitialized data"""
282
283
if device is None:
284
device = torch.device("cuda")
285
+ # Handle the device passed as string
286
+ device = torch.device(device)
287
result = tex.create_empty_quantized_tensor(
288
self,
289
list(shape),
0 commit comments