Repair use of deprecated parameter#1671
Conversation
|
Incidentally, would you reconsider #1? Maintaining downstream shims for all of these functions is becoming increasingly tedious, especially since there appears to be a fairly small upstream change that would eliminate a lot of duplicated workaround code. At least in my case, I’ve ended up maintaining an entire layer of wrappers/shims around Optax to recover this functionality: From reading through the discussion, it seems I’m not the only one who has independently arrived at this kind of solution. |
rdyro
left a comment
There was a problem hiding this comment.
Thanks! Could you add a small tests for this change?
I'm working on a PR for this, but it's hard to make this change while passing all internal tests. This might take some time :( |
Absolutely no problem!! If this is on the horizon, and I'll one day be able to delete my shims, that makes me extremely happy ❤️ By the way, as far as I can tell, my shims do pass your internal tests already--at least from the time I copied those tests.
Happy to, and done! |
| with self.assertRaises(ValueError): | ||
| dp_agg.update(mean_grads, state, self.params) | ||
|
|
||
| def test_dpsgd_accepts_key(self): |
There was a problem hiding this comment.
could you parametrized on both seed key: int = 0 and key: jax.Array = jax.random.key(0)
I think jax.random.key might need to called in the test scope rather than in the decorator, thanks!
666c4a2
into
google-deepmind:main
Note: This function is untested on main.