Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions tests/jax/test_distributed_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,16 +161,21 @@ def test_distributed_gemm(
# Compare results
assert_allclose(gathered_te, gathered_jax, dtype=dtype)

def _te_sum_dense(self, x, weight, bias, contracting_dims, output_sharding):
    """TE GEMM scalar loss for gradient testing.

    Runs Transformer Engine's ``dense`` (GEMM + bias), optionally pins the
    output layout with a sharding constraint, and reduces to a scalar so
    ``jax.value_and_grad`` can differentiate w.r.t. all inputs.

    Args:
        x: Input activation array.
        weight: Weight array contracted against ``x``.
        bias: Bias added by ``dense``.
        contracting_dims: Dot-general style contracting dimensions passed
            through to ``dense`` (static under jit).
        output_sharding: Optional sharding to constrain the GEMM output to
            before the reduction; ``None`` leaves layout to the compiler.

    Returns:
        Scalar sum of the (optionally sharding-constrained) dense output.
    """
    output = dense(x, weight, bias=bias, contracting_dims=contracting_dims)
    if output_sharding is not None:
        # Constrain the layout so gradients are taken under the same
        # sharding the distributed test expects.
        output = jax.lax.with_sharding_constraint(output, output_sharding)
    return jnp.sum(output)

def _jax_sum_dense(self, x, weight, bias, contracting_dims):
def _jax_sum_dense(self, x, weight, bias, contracting_dims, output_sharding):
"""JAX dot function for gradient testing"""
result = (
output = (
jax.lax.dot_general(x, weight, dimension_numbers=(contracting_dims, ((), ()))) + bias
)
return jnp.sum(result)
if output_sharding is not None:
output = jax.lax.with_sharding_constraint(output, output_sharding)
return jnp.sum(output)

@pytest_parametrize_wrapper(
"device_count,mesh_shape,mesh_axes,mesh_resource",
Expand Down Expand Up @@ -213,18 +218,18 @@ def test_te_distributed_dense_grad(
# Test gradients w.r.t. all inputs
te_grad_func = jax.jit(
jax.value_and_grad(self._te_sum_dense, argnums=(0, 1, 2)),
static_argnames=("contracting_dims",),
static_argnames=("contracting_dims", "output_sharding"),
)
jax_grad_func = jax.jit(
jax.value_and_grad(self._jax_sum_dense, argnums=(0, 1, 2)),
static_argnames=("contracting_dims",),
static_argnames=("contracting_dims", "output_sharding"),
)

te_val, te_grads = te_grad_func(
x_sharded, weight_sharded, bias_sharded, contracting_dims
x_sharded, weight_sharded, bias_sharded, contracting_dims, output_sharding
)
jax_val, jax_grads = jax_grad_func(
x_sharded, weight_sharded, bias_sharded, contracting_dims
x_sharded, weight_sharded, bias_sharded, contracting_dims, output_sharding
)

# Compare forward pass
Expand Down
Loading
Loading