| """Position-wise Feed-Forward Network. | |
| This module implements the FFN sublayer used in Transformer blocks: | |
| - Standard FFN: Two linear layers with activation (GELU/ReLU) | |
| - Gated FFN: SwiGLU (LLaMA-style) or Gated-GELU (T5/FLAN-T5 style) | |
| Author: Oliver Perrin | |
| Date: 2025-10-23 | |
| """ | |
| from typing import Literal, Optional | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.init as init | |


class FeedForward(nn.Module):
    """
    FFN(x) = max(0, xW₁ + b₁)W₂ + b₂
    Or with GELU:       FFN(x) = GELU(xW₁ + b₁)W₂ + b₂
    Or with SwiGLU:     FFN(x) = (Swish(xW_gate) * xW_up)W_down
    Or with gated-gelu: FFN(x) = (GELU(xW_gate) * xW_up)W_down  (T5/FLAN-T5 style)
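
    Example (illustrative sketch; the dimensions below are arbitrary and not
    taken from the original code):
        >>> ffn = FeedForward(d_model=512, d_ff=2048, activation="swiglu")
        >>> x = torch.randn(2, 16, 512)
        >>> ffn(x).shape
        torch.Size([2, 16, 512])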
| """ | |

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        dropout: float = 0.1,
        activation: Literal["gelu", "relu", "swiglu", "gated-gelu"] = "gelu",
        quantization: Optional[str] = None,
    ):
        super().__init__()
        self.activation_type = activation

        # Select Linear layer type based on quantization
        Linear = nn.Linear
        kwargs = {}
        if quantization == "4bit":
            try:
                import bitsandbytes as bnb

                Linear = bnb.nn.Linear4bit  # type: ignore
                kwargs = {"compute_dtype": torch.bfloat16, "quant_type": "nf4"}
            except (ImportError, AttributeError):
                print("bitsandbytes not installed or incompatible, falling back to nn.Linear")
        elif quantization == "8bit":
            try:
                import bitsandbytes as bnb

                Linear = bnb.nn.Linear8bitLt  # type: ignore
            except (ImportError, AttributeError):
                print("bitsandbytes not installed or incompatible, falling back to nn.Linear")

        if activation in ("swiglu", "gated-gelu"):
            # Gated FFN requires 3 linear layers: Gate, Up, Down
            # - swiglu uses SiLU (Swish) activation (LLaMA style)
            # - gated-gelu uses GELU activation (T5/FLAN-T5 style)
            self.linear_gate = Linear(d_model, d_ff, **kwargs)  # Gate projection (wi_0)
            self.linear1 = Linear(d_model, d_ff, **kwargs)  # Up projection (wi_1)
            self.linear2 = Linear(d_ff, d_model, **kwargs)  # Down projection (wo)
            if activation == "swiglu":
                self.activation = nn.SiLU()  # Swish activation
            else:  # gated-gelu
                self.activation = nn.GELU()  # T5 uses gelu_new, which is very close
            # Init gate
            if not quantization:
                init.xavier_uniform_(self.linear_gate.weight)
                init.zeros_(self.linear_gate.bias)
        else:
            self.linear1 = Linear(d_model, d_ff, **kwargs)  # w_1
            self.activation = nn.GELU() if activation == "gelu" else nn.ReLU()
            self.linear2 = Linear(d_ff, d_model, **kwargs)  # w_2

        self.dropout = nn.Dropout(dropout)

        # Weight Initialization
        if not quantization:
            init.xavier_uniform_(self.linear1.weight)
            init.zeros_(self.linear1.bias)
            init.xavier_uniform_(self.linear2.weight)
            init.zeros_(self.linear2.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (batch, seq_len, d_model)
        returns: (batch, seq_len, d_model)
        """
        if self.activation_type in ("swiglu", "gated-gelu"):
            # Gated FFN: (activation(xW_gate) * xW_up) W_down
            gate = self.activation(self.linear_gate(x))  # (batch, seq_len, d_ff)
            up = self.linear1(x)  # (batch, seq_len, d_ff)
            x = gate * up
            x = self.dropout(x)
            x = self.linear2(x)  # (batch, seq_len, d_model)
        else:
            x = self.linear1(x)  # (batch, seq_len, d_ff)
            x = self.activation(x)  # activation
            x = self.dropout(x)  # dropout
            x = self.linear2(x)  # (batch, seq_len, d_model)
        return x
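

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the original module).
    # The dimensions below are arbitrary; the check only shows that every
    # activation variant preserves the (batch, seq_len, d_model) shape.
    x = torch.randn(2, 16, 512)
    for act in ("gelu", "relu", "swiglu", "gated-gelu"):
        ffn = FeedForward(d_model=512, d_ff=2048, activation=act)
        out = ffn(x)
        assert out.shape == x.shape, f"{act}: unexpected shape {tuple(out.shape)}"
        print(f"{act:>10}: {tuple(out.shape)}")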