Skip to content

Conversation

@Qazalbash
Copy link
Contributor

@Qazalbash Qazalbash commented Oct 21, 2025

This PR contains the resolution of mypy errors passed by #2032, in the numpyro.distributions.constraints module.

I have tried to replicate the same solution proposed by @fehiepsi in #2066, i.e., the use of generics (see from 69c1ed5 till 1d6b24d).

I have slightly modified the logic. Some notes on them,

  1. ArrayLike contains complex and there is no partial order over complex numbers. MyPy was throwing errors for >, <, <=, and >= operators. They have been replaced with the equivalent jax.numpy function.
  2. There is no mod operation between arrays and integers; it has been replaced with jax.numpy.mod.
  3. Bitwise operations have been replaced with jax.numpy.logical_and and jax.numpy.logical_or.
  4. The type of the argument in the __eq__ method has been changed to object because the method can take any type of object; its implementation is to classify if the objects are the same or not. I have added if not isinstance(other, ...): return False statement at some places due to MyPy's errors.

I tackled with following problems that, in my understanding, require some discussion,

  1. event_dim and is_discrete are read-only properties that are modified by some constraints; therefore, I made them private attributes and introduced getter and setter methods for each.
  2. __eq__ method expects return type to be a bool, but we are returning arrays of booleans. I have marked them to be ignored by MyPy.
  3. We can not ducktype the constraint object with ConstraintT at the end of the module, because the ConstraintT object expects a NumLike object, but some constraints support only NonScalarArray. This issue is solvable via a generic typing protocol (see changes of c12b514). It can also be seen with TransformT and subclasses of Transform, for statement transform_obj: TransformT = TransformClass(...), MyPy will throw an error, if TransformClass uses anything other than NumLike. This issue is also addressable via a generic typing protocol.
  4. jax>=0.7.2 has introduced TypedNdArray to represent constants in jaxpr (ref Include Typed<type> in ArrayLike jax-ml/jax#31989, Add Typed... types to ArrayLike. jax-ml/jax#32227). It is also a part of ArrayLike type, and has no reshape method.

These are all the major outlines of this PR. I will update the description if I recall any.

@Qazalbash
Copy link
Contributor Author

@fehiepsi, can you look into these changes?


@juanitorduz you helped us a lot in solving #2066, you might be interested in these changes too.

@juanitorduz
Copy link
Collaborator

This is a tricky one but there is great progress :) I created a pull request to your branch @Qazalbash with a potential solution Qazalbash#3 . MyPy seems happy about it, but please see if make sense for you

@Qazalbash
Copy link
Contributor Author

Only errors left here are coming from the statments constrain_obj: ConstraintT = ConstraintClass(...), when ConstraintClass uses NonScalarArray type. Because ConstraintT has NumLike and expects similar from ConstraintClass. Same problem can be seen in TransformT.

I tried to address this issue by making typing protocols generic. I later reverted it. They are available at c12b514.

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @Qazalbash! Sorry for the delay! I thought the PR was WIP.

corr_matrix: ConstraintT = _CorrMatrix()
dependent: ConstraintT = _Dependent()

boolean: Union[Constraint, ConstraintT] = _Boolean()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure why Union is required. It seems confusing to me. Actually I feel I don't understand the typing stuff: why Constraint is not a ConstraintT type, when to use Constraint when to use ConstraintT. Could you clarify?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem is TransformT expects a ConstraintT type. ConstraintT uses NumLike type. Some constraints are typed with NonScalarArray, which is not compatible with ConstraintT (violating the Liskov substitution principle).

I tried to address this issue, see the last paragraph of this comment.

@Qazalbash
Copy link
Contributor Author

Hi @fehiepsi, I am a little busy with my grad school applications, can I update you after Dec 15?

@fehiepsi
Copy link
Member

fehiepsi commented Dec 1, 2025

Absolutely, please take your time!

@fehiepsi
Copy link
Member

@Qazalbash I just pushed some changes to remove the dependency on ConstraintT and TransformT. I found using both ConstraintT and Constraint at the same time causes some confusion. So far mypy seems to be fine with the changes locally. Please let me know if you have other opinion.

@Qazalbash
Copy link
Contributor Author

Qazalbash commented Jan 20, 2026

@fehiepsi, thanks for looking into it. I am testing it.

Apparently, there was a doctest failing on the previous commits, but not on the recent one you pushed.

@Qazalbash
Copy link
Contributor Author

Changes LGTM, thanks @fehiepsi.

@fehiepsi fehiepsi merged commit 4fd6c73 into pyro-ppl:master Jan 20, 2026
8 checks passed
@juanitorduz
Copy link
Collaborator

yay! great teamwork!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants