-
Notifications
You must be signed in to change notification settings - Fork 273
fix(gh-2036): MyPy Errors in numpyro.distributions.constraints Module
#2085
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@fehiepsi, can you look into these changes? @juanitorduz you helped us a lot in solving #2066, you might be interested in these changes too. |
|
This is a tricky one but there is great progress :) I created a pull request to your branch @Qazalbash with a potential solution Qazalbash#3 . MyPy seems happy about it, but please see if make sense for you |
…ts in typing modules
…`Constraint` class
…tConstraint` class
…ne imports in typing modules" This reverts commit 78d8e93.
… bitwise operators
…lasses to use `NonScalarArray`
…iscrete` and `event_dim` in `Constraint` class
|
Only errors left here are coming from the statments I tried to address this issue by making typing protocols generic. I later reverted it. They are available at c12b514. |
fehiepsi
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @Qazalbash! Sorry for the delay! I thought the PR was WIP.
numpyro/distributions/constraints.py
Outdated
| corr_matrix: ConstraintT = _CorrMatrix() | ||
| dependent: ConstraintT = _Dependent() | ||
|
|
||
| boolean: Union[Constraint, ConstraintT] = _Boolean() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure why Union is required. It seems confusing to me. Actually I feel I don't understand the typing stuff: why Constraint is not a ConstraintT type, when to use Constraint when to use ConstraintT. Could you clarify?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The problem is TransformT expects a ConstraintT type. ConstraintT uses NumLike type. Some constraints are typed with NonScalarArray, which is not compatible with ConstraintT (violating the Liskov substitution principle).
I tried to address this issue, see the last paragraph of this comment.
|
Hi @fehiepsi, I am a little busy with my grad school applications, can I update you after Dec 15? |
|
Absolutely, please take your time! |
|
@Qazalbash I just pushed some changes to remove the dependency on ConstraintT and TransformT. I found using both ConstraintT and Constraint at the same time causes some confusion. So far mypy seems to be fine with the changes locally. Please let me know if you have other opinion. |
|
@fehiepsi, thanks for looking into it. I am testing it. Apparently, there was a doctest failing on the previous commits, but not on the recent one you pushed. |
|
Changes LGTM, thanks @fehiepsi. |
|
yay! great teamwork! |
This PR contains the resolution of mypy errors passed by #2032, in the
numpyro.distributions.constraintsmodule.I have tried to replicate the same solution proposed by @fehiepsi in #2066, i.e., the use of generics (see from 69c1ed5 till 1d6b24d).
I have slightly modified the logic. Some notes on them,
ArrayLikecontainscomplexand there is no partial order over complex numbers. MyPy was throwing errors for>,<,<=, and>=operators. They have been replaced with the equivalentjax.numpyfunction.jax.numpy.mod.jax.numpy.logical_andandjax.numpy.logical_or.__eq__method has been changed toobjectbecause the method can take any type of object; its implementation is to classify if the objects are the same or not. I have addedif not isinstance(other, ...): return Falsestatement at some places due to MyPy's errors.I tackled with following problems that, in my understanding, require some discussion,
event_dimandis_discreteare read-only properties that are modified by some constraints; therefore, I made them private attributes and introduced getter and setter methods for each.__eq__method expects return type to be abool, but we are returning arrays of booleans. I have marked them to be ignored by MyPy.ConstraintTat the end of the module, because theConstraintTobject expects aNumLikeobject, but some constraints support onlyNonScalarArray. This issue is solvable via a generic typing protocol (see changes of c12b514). It can also be seen withTransformTand subclasses ofTransform, for statementtransform_obj: TransformT = TransformClass(...), MyPy will throw an error, ifTransformClassuses anything other thanNumLike. This issue is also addressable via a generic typing protocol.jax>=0.7.2has introducedTypedNdArrayto represent constants in jaxpr (ref Include Typed<type> in ArrayLike jax-ml/jax#31989, Add Typed... types to ArrayLike. jax-ml/jax#32227). It is also a part ofArrayLiketype, and has noreshapemethod.These are all the major outlines of this PR. I will update the description if I recall any.