Skip to content

Subclass MLX and JAX from abstract types.py and shared types_test_base.py#9

Open
JulianSlzr wants to merge 2 commits intogoogle:mlxfrom
JulianSlzr:mlx-abstract-types
Open

Subclass MLX and JAX from abstract types.py and shared types_test_base.py#9
JulianSlzr wants to merge 2 commits intogoogle:mlxfrom
JulianSlzr:mlx-abstract-types

Conversation

@JulianSlzr
Copy link
Collaborator

Having a shared abstract parent leads to uniform interfaces and shared tests while letting per-backend implementations be flexible. This demonstrate the change for types.py for MLX and JAX.

"""Returns a copy of the config with updated fields."""


class Steppable(metaclass=abc.ABCMeta):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably need these as well:

def get_accumulated_input_latency(self, input_latency: int) -> int:
def get_accumulated_output_latency(self, output_latency: int) -> int:

and WDYT of receptive_field?

import numpy as np

# We'll need a way for subclasses to provide specific factory methods/backends
class SequenceTest(parameterized.TestCase):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SequenceLayerTest is the base test suite for almost all tests

it has some very commonly used assertions like

assertSequencesClose
assertSequencesNotClose
assertSequencesEqual
assertSequencesNotEqual
assertAllEqual
assertAllClose
assertNotAllEqual
assertNotAllClose

wondering if we should have a base abstract SequenceLayerTest class that each abstract test class here inherits , then platform-specific subclasses can mixin a class that defines all the methods needed by the abstract SequenceLayerTest


def get_backend(self):
return jnp

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WDYT instead of

def sequence_type(self):
  return sl.Sequence

def masked_sequence_type(self):
  return sl.MaskedSequence

to make it easier to access from_lengths from_values, etc.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants