"""Position-wise Feed-Forward Network.
This module implements the FFN sublayer used in Transformer blocks:
- Standard FFN: Two linear layers with activation (GELU/ReLU)
- Gated FFN: SwiGLU (LLaMA-style) or Gated-GELU (T5/FLAN-T5 style)
Author: Oliver Perrin
Date: 2025-10-23
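
Example (illustrative hyperparameters):
    >>> standard = FeedForward(d_model=512, d_ff=2048, activation="gelu")
    >>> gated = FeedForward(d_model=512, d_ff=2048, activation="swiglu")

The optional quantization argument ("4bit" or "8bit") uses bitsandbytes
linear layers when available and falls back to nn.Linear otherwise.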
"""
from typing import Literal, Optional
import torch
import torch.nn as nn
import torch.nn.init as init


class FeedForward(nn.Module):
    """
    FFN(x) = max(0, xW₁ + b₁)W₂ + b₂
    Or with GELU: FFN(x) = GELU(xW₁ + b₁)W₂ + b₂
    Or with SwiGLU: FFN(x) = (Swish(xW_gate) * xW_up)W_down
    Or with gated-gelu: FFN(x) = (GELU(xW_gate) * xW_up)W_down (T5/FLAN-T5 style)
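
    Example (illustrative; a quick shape check):
        >>> ffn = FeedForward(d_model=512, d_ff=2048, activation="swiglu")
        >>> ffn(torch.randn(2, 16, 512)).shape
        torch.Size([2, 16, 512])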
"""
def __init__(
self,
d_model: int,
d_ff: int,
dropout: float = 0.1,
activation: Literal["gelu", "relu", "swiglu", "gated-gelu"] = "gelu",
quantization: Optional[str] = None,
):
super().__init__()
self.activation_type = activation

        # Select the Linear layer type based on the requested quantization
        Linear = nn.Linear
        kwargs = {}
        if quantization == "4bit":
            try:
                import bitsandbytes as bnb
                Linear = bnb.nn.Linear4bit  # type: ignore
                kwargs = {"compute_dtype": torch.bfloat16, "quant_type": "nf4"}
            except (ImportError, AttributeError):
                print("bitsandbytes not installed or incompatible, falling back to nn.Linear")
                quantization = None  # fall back, so the Xavier init below still runs
        elif quantization == "8bit":
            try:
                import bitsandbytes as bnb
                Linear = bnb.nn.Linear8bitLt  # type: ignore
            except (ImportError, AttributeError):
                print("bitsandbytes not installed or incompatible, falling back to nn.Linear")
                quantization = None  # fall back, so the Xavier init below still runs

        if activation in ("swiglu", "gated-gelu"):
            # Gated FFN requires 3 linear layers: Gate, Up, Down
            # - swiglu uses SiLU (Swish) activation (LLaMA style)
            # - gated-gelu uses GELU activation (T5/FLAN-T5 style)
            self.linear_gate = Linear(d_model, d_ff, **kwargs)  # Gate projection (wi_0)
            self.linear1 = Linear(d_model, d_ff, **kwargs)  # Up projection (wi_1)
            self.linear2 = Linear(d_ff, d_model, **kwargs)  # Down projection (wo)
            if activation == "swiglu":
                self.activation = nn.SiLU()  # Swish activation
            else:  # gated-gelu
                self.activation = nn.GELU()  # GELU activation (T5 uses gelu_new, which is very close)
            # Init gate
            if not quantization:
                init.xavier_uniform_(self.linear_gate.weight)
                init.zeros_(self.linear_gate.bias)
        else:
            self.linear1 = Linear(d_model, d_ff, **kwargs)  # w_1
            self.activation = nn.GELU() if activation == "gelu" else nn.ReLU()
            self.linear2 = Linear(d_ff, d_model, **kwargs)  # w_2

        self.dropout = nn.Dropout(dropout)

        # Weight Initialization
        if not quantization:
            init.xavier_uniform_(self.linear1.weight)
            init.zeros_(self.linear1.bias)
            init.xavier_uniform_(self.linear2.weight)
            init.zeros_(self.linear2.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (batch, seq_len, d_model)
        returns: (batch, seq_len, d_model)
        """
        if self.activation_type in ("swiglu", "gated-gelu"):
            # Gated FFN: (activation(xW_gate) * xW_up) W_down
            gate = self.activation(self.linear_gate(x))
            up = self.linear1(x)
            x = gate * up
            x = self.dropout(x)
            x = self.linear2(x)
        else:
            x = self.linear1(x)  # (batch, seq_len, d_ff)
            x = self.activation(x)  # activation
            x = self.dropout(x)  # dropout
            x = self.linear2(x)  # (batch, seq_len, d_model)
        return x
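

# Minimal smoke test (illustrative only; hyperparameters are arbitrary). It runs
# the standard and gated variants without quantization, since the bitsandbytes
# paths need that package (and a CUDA device) to be exercised meaningfully.
if __name__ == "__main__":
    x = torch.randn(2, 16, 512)
    for act in ("gelu", "relu", "swiglu", "gated-gelu"):
        ffn = FeedForward(d_model=512, d_ff=2048, activation=act)
        out = ffn(x)
        assert out.shape == x.shape, (act, out.shape)
        print(f"{act:>10}: output shape {tuple(out.shape)}")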