Skip to content

Commit af37655

Browse files
authored
Always call jnp.finfo on array dtypes, not arrays themselves (#2147)
* Always call jnp.finfo on array dtypes, not arrays themselves * Fix instance in safe_normalize * Remove .lock
1 parent 78ebed1 commit af37655

File tree

3 files changed

+5
-3
lines changed

3 files changed

+5
-3
lines changed

numpyro/distributions/continuous.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,9 @@ def sample(
348348
assert is_prng_key(key)
349349
shape = sample_shape + self.batch_shape
350350
samples = random.dirichlet(key, self.concentration, shape=shape)
351-
return jnp.clip(samples, jnp.finfo(samples).tiny, 1 - jnp.finfo(samples).eps)
351+
return jnp.clip(
352+
samples, jnp.finfo(samples.dtype).tiny, 1 - jnp.finfo(samples.dtype).eps
353+
)
352354

353355
@validate_sample
354356
def log_prob(self, value: ArrayLike) -> ArrayLike:

numpyro/distributions/discrete.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1056,7 +1056,7 @@ def entropy(self) -> ArrayLike:
10561056
logq = -jax.nn.softplus(self.logits)
10571057
logp = -jax.nn.softplus(-self.logits)
10581058
p = jax.scipy.special.expit(self.logits)
1059-
p_clip = jnp.clip(p, jnp.finfo(p).tiny)
1059+
p_clip = jnp.clip(p, jnp.finfo(p.dtype).tiny)
10601060
return -(1 - p) * logq / p_clip - logp
10611061

10621062

numpyro/distributions/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -703,7 +703,7 @@ def safe_normalize(x, *, p=2):
703703
assert isinstance(p, (float, int))
704704
assert p >= 0
705705
norm = jnp.linalg.norm(x, p, axis=-1, keepdims=True)
706-
x = x / jnp.clip(norm, jnp.finfo(x).tiny)
706+
x = x / jnp.clip(norm, jnp.finfo(x.dtype).tiny)
707707
# Avoid the singularity.
708708
mask = jnp.all(x == 0, axis=-1, keepdims=True)
709709
x = jnp.where(mask, x.shape[-1] ** (-1 / p), x)

0 commit comments

Comments
 (0)