Conversation
```python
# self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
# self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
```
We have changed how we do caching, as seen in the lines below. Do we have a reason to keep these comments?
```python
# self.k_cache[:bsz, start_pos:end_pos] = k
# self.v_cache[:bsz, start_pos:end_pos] = v
```
```python
# assert args.n_routed_experts % world_size == 0
# self.n_routed_experts = args.n_routed_experts
# self.n_local_experts = args.n_routed_experts // world_size
# self.n_activated_experts = args.n_activated_experts
# self.experts_start_idx = rank * self.n_local_experts
# self.experts_end_idx = self.experts_start_idx + self.n_local_experts
self.gate = Gate(model_args)
# self.experts = nn.ModuleList(
#     [
#         Expert(args.dim, args.moe_inter_dim)
#         if self.experts_start_idx <= i < self.experts_end_idx
#         else None
#         for i in range(self.n_routed_experts)
#     ]
# )
```
Do we have reason to leave these as comments rather than removing them?
```python
# if seqlen > 1:
#     mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)
```
Do we have reason to leave these as comments rather than removing them?
```python
else:
    group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
    indices = group_scores.topk(self.topk_groups, dim=-1)[1]
print('i am here')
```
Test print. Please remove. If appropriate, we could use logging at the info level instead.
I see there are a couple of other prints in model functions; I would consider applying the same criteria to those.
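A minimal sketch of that suggestion, assuming a module-level logger; the setup and message text here are illustrative, not the actual fix:

```python
import logging

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Instead of print('i am here'), emit at info level so the message
# can be enabled or silenced through the logging configuration.
logger.info("Gate.forward: entered group-limited routing branch")
```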
```dockerfile
# Install torchax
RUN git clone https://github.com/pytorch/xla.git
WORKDIR /workspaces/xla/torchax
RUN git checkout hanq_torchax1
```
Is this a temporary fix while pytorch/xla@master...hanq_torchax1 is not merged?
```python
tokens = name.split(".")
for i, t in enumerate(tokens):
    if is_integer(t):
```
Why not use something like `type(t) is int`?
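For context, a minimal sketch of what an `is_integer` helper here plausibly does, assuming it checks whether a string token parses as an int; the tokens produced by `name.split(".")` are strings, so `type(t) is int` would always be False:

```python
def is_integer(s: str) -> bool:
    # True when the string token spells a base-10 integer,
    # e.g. the layer index in "layers.3.attn.wq".
    try:
        int(s)
        return True
    except ValueError:
        return False

assert is_integer("3") and not is_integer("attn")
```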
| name0 = "tp0" | ||
| # name1 = "tp1" |
Should we add the definition of "name0"?
```python
# TODO(https://github.com/AI-Hypercomputer/torchprime/issues/75): Fix the failure on torch 2.6,
# then enable the test unconditionally.
@pytest.mark.deepseek
def test_single_device_compile():
```
Is this test no longer useful?
This file contains many sharding helpers. Should we move the sharding tooling written here into a dedicated file?
In that case, adding unit tests for these sharding methods could be helpful.
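If the helpers are extracted, a minimal pytest sketch of such a unit test; the module name `sharding_utils` and the choice of helper are assumptions for illustration:

```python
import pytest

from sharding_utils import is_integer  # hypothetical extracted module

@pytest.mark.parametrize(
    "token,expected",
    [("0", True), ("12", True), ("attn", False), ("", False)],
)
def test_is_integer(token, expected):
    # Integer path components (layer indices) in parameter names like
    # "layers.3.attn.wq" should be recognized; others should not.
    assert is_integer(token) == expected
```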