-
Notifications
You must be signed in to change notification settings - Fork 32
Replace deprecated jax context manager and JIT-traced non-array arguments #260
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
dfm
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!! Small comment inline, otherwise looks good to me!
src/tinygp/test_utils.py
Outdated
| ) | ||
|
|
||
|
|
||
| def _as_context_manager(obj): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is more complicated than it needs to be. Let's just switch to using jax.enable_x64 everywhere where we were using the experimental version!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree. If you are okay with dropping Python 3.10 I will make that change.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's do it - thanks!
Hi, thank you for tinygp.
This PR attempts to replace the
jax.experimental.enable_x64context manager which is deprecated injax>=0.9.0.While testing this change on current GitHub runners some other JAX-related issues surfaced:
See details
As far as I can tell, the problem was the None attribute of the
Meanclass when initiated with a callable. This PR should fix this by setting itsvalueattribute to a valid (dummy) JAX array.