Subclass MLX and JAX from abstract types.py and shared types_test_base.py#9
Subclass MLX and JAX from abstract types.py and shared types_test_base.py#9JulianSlzr wants to merge 2 commits intogoogle:mlxfrom
types.py and shared types_test_base.py#9Conversation
Co-authored-by: David Braun <2096055+DBraun@users.noreply.github.com>
| """Returns a copy of the config with updated fields.""" | ||
|
|
||
|
|
||
| class Steppable(metaclass=abc.ABCMeta): |
There was a problem hiding this comment.
probably need these as well:
def get_accumulated_input_latency(self, input_latency: int) -> int:
def get_accumulated_output_latency(self, output_latency: int) -> int:
and WDYT of receptive_field?
| import numpy as np | ||
|
|
||
| # We'll need a way for subclasses to provide specific factory methods/backends | ||
| class SequenceTest(parameterized.TestCase): |
There was a problem hiding this comment.
SequenceLayerTest is the base test suite for almost all tests
it has some very commonly used assertions like
assertSequencesClose
assertSequencesNotClose
assertSequencesEqual
assertSequencesNotEqual
assertAllEqual
assertAllClose
assertNotAllEqual
assertNotAllClose
wondering if we should have a base abstract SequenceLayerTest class that each abstract test class here inherits , then platform-specific subclasses can mixin a class that defines all the methods needed by the abstract SequenceLayerTest
|
|
||
| def get_backend(self): | ||
| return jnp | ||
|
|
There was a problem hiding this comment.
WDYT instead of
def sequence_type(self):
return sl.Sequence
def masked_sequence_type(self):
return sl.MaskedSequence
to make it easier to access from_lengths from_values, etc.
Having a shared abstract parent leads to uniform interfaces and shared tests while letting per-backend implementations be flexible. This demonstrate the change for
types.pyfor MLX and JAX.