Skip to content

Commit f5cacb3

Browse files
committed
Fix tensor_utils.pack() to correctly pack uint64 scalars.
When packing a given NumPy array into a Tensor message, we use the array's dtype to determine which `Packer` to use. For some types (e.g. NumPy strings) we need to determine the base type rather than the concrete type (i.e. `np.str_`) using the `dtype.type` attribute. This can cause some oddities with integer types, e.g: ``` np.array(2**64-1).dtype == np.uint64 np.array(2**64-1).dtype.type == np.ulonglong np.array(2**64-1 dtype=np.uint64).dtype.type == np.uint64 ``` Therefore we re-cast as a dtype (i.e. `dtype(array.dtype.type)`) before looking up the packer. PiperOrigin-RevId: 410004520 Change-Id: I92cb480bfd0949d021c01d8b41345132cda730dd
1 parent 71076a6 commit f5cacb3

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

dm_env_rpc/v1/tensor_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def unpack(self, proto: TensorOrTensorSpecValue) -> np.ndarray:
143143
}
144144

145145
_TYPE_TO_PACKER = {
146-
packer.np_type: packer for packer in _PACKERS
146+
np.dtype(packer.np_type): packer for packer in _PACKERS
147147
}
148148

149149
_DM_ENV_RPC_DTYPE_TO_NUMPY_DTYPE = {
@@ -216,7 +216,7 @@ def get_packer(np_type: np.dtype) -> Packer:
216216
Raises:
217217
TypeError: If the provided NumPy type has no known packer.
218218
"""
219-
packer = _TYPE_TO_PACKER.get(np_type)
219+
packer = _TYPE_TO_PACKER.get(np.dtype(np_type))
220220
if not packer:
221221
raise TypeError(f'Unknown NumPy type "{np_type}" has no known packer.')
222222
return packer

dm_env_rpc/v1/tensor_utils_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def test_pack_scalar_any_proto(self):
8383
(25, np.int64, 'int64s'),
8484
(25, np.uint32, 'uint32s'),
8585
(25, np.uint64, 'uint64s'),
86+
(2**64-1, np.uint64, 'uint64s'),
8687
(True, np.bool, 'bools'),
8788
(False, np.bool, 'bools'),
8889
('foo', np.str, 'strings'),

0 commit comments

Comments
 (0)