import typing
from collections.abc import Callable
from collections import defaultdict
from typing import Any, Dict, TYPE_CHECKING, Optional, Tuple, List

import torch
import copy

from torch import Tensor
from torch.nn import Module
import torch.nn.functional as F

if TYPE_CHECKING:
    Base = Module[Tensor]
else:
    Base = Module


# Number of experts each token is routed to.
MOE_TOP_K = 2
# Number of ConstantExpert instances appended to every local expert list.
Constant = 2


class CopyExpert(torch.nn.Module):
    """Identity expert: returns its input unchanged."""

    def __init__(self, expert):
        super(CopyExpert, self).__init__()

    def forward(self, inputs):
        return inputs


class ZeroExpert(torch.nn.Module):
    """Expert that discards its input and returns zeros of the same shape."""

    def __init__(self, expert):
        super(ZeroExpert, self).__init__()

    def forward(self, inputs):
        # zeros_like already matches the dtype and device of `inputs`.
        return torch.zeros_like(inputs)


class ConstantExpert(torch.nn.Module):
    """Expert that mixes each token with a learned constant vector."""

    def __init__(self, expert):
        super(ConstantExpert, self).__init__()
        self.constant = torch.nn.Parameter(
            torch.empty(expert.hidden_size))
        torch.nn.init.normal_(self.constant)

        self.wg = torch.nn.Linear(expert.hidden_size, 2, bias=False)
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, inputs):
        # Per-token mixing weights: column 0 scales the input token,
        # column 1 scales the learned constant vector.
        weight = self.wg(inputs)
        weight = self.softmax(weight)
        return torch.einsum('b,bd->bd', [weight[:, 0].type_as(inputs), inputs]) + torch.einsum(
            'b,d->bd', [weight[:, 1].type_as(inputs), self.constant.type_as(inputs)])


def gating(logits: Tensor, moe_use_mixtral_gating=False, moe_use_logits_norm=False, moe_gate_norm_std=1.0) -> Dict[int, List[Tensor]]:
    # Maps each expert id to [indices of tokens routed to it, their gate weights].
    num_experts = logits.size(1)
    if moe_use_mixtral_gating:
        # Mixtral-style gating: take the top-k logits first, then softmax over them.
        if moe_use_logits_norm:
            target_std = moe_gate_norm_std
            logits_std = logits.std(dim=1, keepdim=True)
            logits = logits / (logits_std / target_std)
        gates, indices = torch.topk(logits, k=MOE_TOP_K, dim=1)
        gates = F.softmax(gates, dim=1)
    else:
        # Softmax over all experts first, then take the top-k gates.
        target_std = moe_gate_norm_std
        if moe_use_logits_norm:
            logits_std = logits.std(dim=1, keepdim=True)
            gates = F.softmax(logits / (logits_std / target_std), dim=1)
        else:
            gates = F.softmax(logits, dim=1)

        gates, indices = torch.topk(gates, k=MOE_TOP_K, dim=1)
        # Zero out the gate assigned to the last expert (the ZeroExpert slot)
        # and renormalize the remaining gates.
        gates = torch.where(indices == (num_experts - 1), torch.zeros_like(gates), gates)
        gates /= torch.sum(gates, dim=1, keepdim=True)

    expert_info = defaultdict(list)
    for expert_id in range(num_experts):
        token_ids, score_ids = torch.nonzero(indices == expert_id, as_tuple=True)
        expert_info[expert_id] = [token_ids, gates[token_ids, score_ids]]

    return expert_info


class Router(Module):
    def __init__(self,
                 model_dim: int,
                 num_experts: int,
                 moe_use_mixtral_gating: bool,
                 moe_2layer_gate: bool,
                 moe_use_logits_norm: bool,
                 moe_gate_norm_std: float,
                 ) -> None:
        super().__init__()

        # The gate network is kept in fp32.
        if moe_2layer_gate:
            self.wg = torch.nn.Sequential(
                torch.nn.Linear(model_dim, num_experts * 8, bias=False).float(),
                torch.nn.Tanh(),
                torch.nn.Linear(num_experts * 8, num_experts, bias=False).float()).float()
        else:
            self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()

        # Projects an incoming gate residual into this layer's logit space.
        self.gate_map = torch.nn.Linear(num_experts, num_experts, bias=False)

        self.gate = gating
        self.moe_use_mixtral_gating = moe_use_mixtral_gating
        self.moe_use_logits_norm = moe_use_logits_norm
        self.moe_gate_norm_std = moe_gate_norm_std

    def forward(self, input: torch.Tensor, gate_residual=None) -> Tuple[Dict[int, List[Tensor]], Tensor]:
        # Keep the gate weights in fp32 and mark them with a `router` attribute.
        if isinstance(self.wg, torch.nn.Linear):
            if self.wg.weight.dtype != torch.float32:
                self.wg = self.wg.float()
            setattr(self.wg.weight, 'router', True)
        else:
            if self.wg[0].weight.dtype != torch.float32:
                self.wg = self.wg.float()
            setattr(self.wg[0].weight, "router", True)
            setattr(self.wg[2].weight, "router", True)
        input_fp32 = input.float()
        logits = self.wg(input_fp32)

        if gate_residual is not None:
            gate_residual = self.gate_map(gate_residual.to(self.gate_map.weight.dtype))
            logits += gate_residual

        gate_output = self.gate(logits, self.moe_use_mixtral_gating, self.moe_use_logits_norm, self.moe_gate_norm_std)

        return gate_output, logits


class Experts(torch.nn.Module):
    def __init__(self, expert, num_local_experts=1):
        super(Experts, self).__init__()

        # Local expert list: regular deep copies of `expert`, followed by
        # `Constant` ConstantExperts, then one CopyExpert and one ZeroExpert.
        self.experts = torch.nn.ModuleList(
            [copy.deepcopy(expert) for _ in range(num_local_experts - 2 - Constant)] +
            [ConstantExpert(expert) for _ in range(Constant)] +
            [CopyExpert(expert), ZeroExpert(expert)])

    def forward(self, inputs):
        # Experts are invoked individually by MOELayer; this container is not
        # called directly.
        raise NotImplementedError


class MOELayer(Base):
    def __init__(self,
                 gate: Module,
                 experts: Module,
                 ep_size,
                 num_local_experts: int,
                 moe_use_mixtral_gating: bool,
                 moe_feature_no_mul_topk: bool) -> None:
        super().__init__()
        self.gate = gate
        self.experts = experts
        self.ep_size = ep_size
        self.num_local_experts = num_local_experts
        self.moe_use_mixtral_gating = moe_use_mixtral_gating
        self.moe_feature_no_mul_topk = moe_feature_no_mul_topk

    def forward(self, *input: Tensor, gate_residual=None, **kwargs: Any) -> Tuple[Tensor, Tensor]:
        # Flatten the input to (tokens, d_model) and route every token.
        d_model = input[0].shape[-1]
        reshaped_input = input[0].reshape(-1, d_model)
        output = torch.zeros_like(reshaped_input)
        # The router returns (per-expert routing info, raw logits); the logits
        # are handed back so a later router can consume them as a gate residual.
        expert_info, gate_residual = self.gate(reshaped_input, gate_residual)
        if not (self.moe_use_mixtral_gating or self.moe_feature_no_mul_topk):
            reshaped_input *= MOE_TOP_K
        # Dispatch each expert's tokens, scale by the gate weights, and scatter
        # the results back into the output buffer.
        for expert, token_indices_and_gates in expert_info.items():
            indices, gating = token_indices_and_gates
            gating = gating.unsqueeze(-1)
            tokens = reshaped_input.index_select(dim=0, index=indices)
            expert_output = self.experts.experts[expert](tokens)
            expert_output *= gating
            output.index_add_(dim=0, index=indices, source=expert_output)
        output = output.reshape(input[0].shape)

        return output, gate_residual


class MOE(torch.nn.Module):
    def __init__(self,
                 hidden_size,
                 expert,
                 num_experts=1,
                 ep_size=1,
                 moe_use_mixtral_gating=False,
                 moe_2layer_gate=True,
                 moe_use_logits_norm=False,
                 moe_gate_norm_std=1.0,
                 moe_feature_no_mul_topk=False):
        super(MOE, self).__init__()

        self.ep_size = ep_size
        self.num_experts = num_experts
        self.num_local_experts = num_experts // self.ep_size
        self.moe_use_mixtral_gating = moe_use_mixtral_gating
        self.moe_2layer_gate = moe_2layer_gate
        self.moe_use_logits_norm = moe_use_logits_norm
        self.moe_gate_norm_std = moe_gate_norm_std
        self.moe_feature_no_mul_topk = moe_feature_no_mul_topk

        experts = Experts(expert, self.num_local_experts)
        self.moe = MOELayer(Router(hidden_size,
                                   num_experts,
                                   self.moe_use_mixtral_gating,
                                   self.moe_2layer_gate,
                                   self.moe_use_logits_norm,
                                   self.moe_gate_norm_std),
                            experts,
                            self.ep_size,
                            self.num_local_experts,
                            self.moe_use_mixtral_gating,
                            self.moe_feature_no_mul_topk,
                            )

    def forward(self, hidden_states, used_token=None, gate_residual=None):
        output, gate_residual = self.moe(hidden_states, used_token, gate_residual=gate_residual)
        return output, gate_residual
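

# Minimal usage sketch, assuming the wrapped expert is any module that exposes a
# `hidden_size` attribute and maps (tokens, hidden) -> (tokens, hidden); a plain
# Linear layer stands in for a real feed-forward expert here. Single device
# (ep_size=1), no expert parallelism.
if __name__ == "__main__":
    hidden = 16

    expert = torch.nn.Linear(hidden, hidden)
    expert.hidden_size = hidden  # ConstantExpert reads this attribute.

    # 8 experts total: 4 deep copies of `expert`, 2 ConstantExperts,
    # 1 CopyExpert and 1 ZeroExpert.
    moe = MOE(hidden_size=hidden, expert=expert, num_experts=8, ep_size=1)

    x = torch.randn(4, 10, hidden)  # (batch, sequence, hidden)
    out, gate_residual = moe(x)
    print(out.shape)            # torch.Size([4, 10, 16])
    print(gate_residual.shape)  # torch.Size([40, 8]) -- router logits per token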