Skip to content

Commit 5ea5fcb

Browse files
Combine bits via xor when bitcasting from larger to smaller type (#734)
* Combine bits via xor when bitcasting from larger to smaller type Previously, when casting e.g. a float64 to an int32, numbers that were close in float64 could be mapped to identical int32's. Since these int's are used as keys to generate random sequences, this is problematic, as it results in identical noise being generated in subsequent timesteps. This commit fixes this by not throwing away bits when the input type is larger than the requested output type. Instead, the larger number is bitcast to multiple values in the smaller type, which are then combined using xor. * Add assertion to check for shape and dtype * Add pytest for bitcasting to smaller type This checks the intended behaviour of mapping nearby numbers to distinct values when downcasting to a smaller dtype.
1 parent b3a1885 commit 5ea5fcb

File tree

2 files changed

+29
-15
lines changed

2 files changed

+29
-15
lines changed

diffrax/_misc.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections.abc import Callable
2-
from typing import Any, cast
2+
from typing import cast
33

44
import jax
55
import jax.core
@@ -13,22 +13,17 @@
1313
from ._custom_types import BoolScalarLike, RealScalarLike
1414

1515

16-
_itemsize_kind_type: dict[tuple[int, str], Any] = {
17-
(1, "i"): jnp.int8,
18-
(2, "i"): jnp.int16,
19-
(4, "i"): jnp.int32,
20-
(8, "i"): jnp.int64,
21-
(2, "f"): jnp.float16,
22-
(4, "f"): jnp.float32,
23-
(8, "f"): jnp.float64,
24-
}
25-
26-
2716
def force_bitcast_convert_type(val, new_type):
2817
val = jnp.asarray(val)
29-
intermediate_type = _itemsize_kind_type[new_type.dtype.itemsize, val.dtype.kind]
30-
val = val.astype(intermediate_type)
31-
return lax.bitcast_convert_type(val, new_type)
18+
result = lax.bitcast_convert_type(val, new_type)
19+
20+
# If downcasting (larger -> smaller type), bitcast returns multiple values.
21+
# Combine them via XOR to ensure nearby input values map to different outputs.
22+
if result.shape != val.shape:
23+
result = jnp.bitwise_xor.reduce(result, axis=-1)
24+
assert val.shape == result.shape
25+
assert result.dtype == new_type
26+
return result
3227

3328

3429
def _fill_forward(

test/test_misc.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,22 @@ def test_fill_forward():
99
out_ = jnp.array([jnp.nan, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0])
1010
fill_in = diffrax._misc.fill_forward(in_[:, None])
1111
assert tree_allclose(fill_in, out_[:, None], equal_nan=True)
12+
13+
14+
def test_force_bitcast_convert_type():
15+
val_1 = jnp.float64(1e6)
16+
val_2 = jnp.float64(1e6 + 1e-4)
17+
18+
# Val_1 and val_2 are different as float64,
19+
# but would be the same if naively downcast to float32.
20+
assert val_1 != val_2
21+
assert val_1.astype(jnp.int32) == val_2.astype(jnp.int32)
22+
23+
val_1_cast = diffrax._misc.force_bitcast_convert_type(val_1, jnp.int32)
24+
val_2_cast = diffrax._misc.force_bitcast_convert_type(val_2, jnp.int32)
25+
26+
assert val_1_cast.dtype == jnp.int32
27+
assert val_2_cast.dtype == jnp.int32
28+
29+
# Bitcasted values should be different in the smaller type
30+
assert val_1_cast != val_2_cast

0 commit comments

Comments
 (0)