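"""
Convert a MiniMax-M2 FP8 safetensors checkpoint to BF16.

Reads an FP8 checkpoint directory (model.safetensors.index.json plus
*.safetensors shards), dequantizes every FP8 weight using its matching
*_scale_inv tensor, and writes BF16 shards and an updated index to the
output directory.
"""
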
import os
import json
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm
import torch
from safetensors.torch import load_file, save_file


def weight_dequant_fp8(weight_fp8, scale_inv):
    """
    Dequantize FP8 weights to BF16 using scale_inv.

    Args:
        weight_fp8: FP8 tensor
        scale_inv: Inverse scale tensor (F32)

    Returns:
        BF16 tensor
    """
    # Convert FP8 to float32 first
    weight_f32 = weight_fp8.to(torch.float32)

    # Apply inverse scaling.
    # scale_inv shape is typically [out_features_blocks, in_features_blocks];
    # it needs to be broadcast properly to match the weight dimensions.
    if scale_inv.dim() == 2:
        out_blocks, in_blocks = scale_inv.shape
        block_size_out = weight_fp8.shape[0] // out_blocks
        block_size_in = weight_fp8.shape[1] // in_blocks
        # Repeat each per-block scale across its block to match the weight shape
        scale_inv_expanded = scale_inv.repeat_interleave(block_size_out, dim=0)
        scale_inv_expanded = scale_inv_expanded.repeat_interleave(block_size_in, dim=1)
        weight_f32 = weight_f32 * scale_inv_expanded
    else:
        weight_f32 = weight_f32 * scale_inv

    # Convert to BF16
    return weight_f32.to(torch.bfloat16)
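
# Illustrative sketch of the dequantization above. The shapes and the implied
# 128x128 block size are assumptions for the example, not values read from a
# real checkpoint:
#   w_fp8 = torch.randn(256, 512).to(torch.float8_e4m3fn)  # FP8 weight
#   s_inv = torch.ones(2, 4, dtype=torch.float32)           # one scale per block
#   w_bf16 = weight_dequant_fp8(w_fp8, s_inv)               # -> (256, 512), torch.bfloat16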


def main(fp8_path, bf16_path):
    torch.set_default_dtype(torch.bfloat16)
    os.makedirs(bf16_path, exist_ok=True)

    model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
    with open(model_index_file, "r") as f:
        model_index = json.load(f)
    weight_map = model_index["weight_map"]

    # Cache for loaded safetensor files
    loaded_files = {}
    fp8_weight_names = []

    # Helper function to get a tensor from the correct shard file
    def get_tensor(tensor_name):
        if tensor_name not in weight_map:
            return None
        file_name = weight_map[tensor_name]
        if file_name not in loaded_files:
            file_path = os.path.join(fp8_path, file_name)
            loaded_files[file_name] = load_file(file_path, device="cuda")
        return loaded_files[file_name][tensor_name]
    safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
    safetensor_files = [f for f in safetensor_files if not f.endswith(".index.json")]
    safetensor_files.sort()
    print(f"Found {len(safetensor_files)} safetensor files to convert")

    for safetensor_file in tqdm(safetensor_files, desc="Converting files"):
        file_name = os.path.basename(safetensor_file)
        current_state_dict = load_file(safetensor_file, device="cuda")
        loaded_files[file_name] = current_state_dict
        new_state_dict = {}
        for weight_name, weight in current_state_dict.items():
            # Skip scale_inv tensors
            if weight_name.endswith("_scale_inv"):
                continue
            # Check if this is an FP8 weight (F8_E4M3 has element_size of 1)
            if weight.dtype == torch.float8_e4m3fn or weight.element_size() == 1:
                scale_inv_name = f"{weight_name}_scale_inv"
                scale_inv = get_tensor(scale_inv_name)
                if scale_inv is not None:
                    fp8_weight_names.append(weight_name)
                    new_state_dict[weight_name] = weight_dequant_fp8(weight, scale_inv)
                else:
                    print(f"Warning: Missing scale_inv tensor for {weight_name}, keeping as-is")
                    new_state_dict[weight_name] = weight
            else:
                # Already BF16 or F32, keep as-is
                new_state_dict[weight_name] = weight
        # Save converted weights
        new_safetensor_file = os.path.join(bf16_path, file_name)
        save_file(new_state_dict, new_safetensor_file)

        # Memory management: cache at most the 2 most recently loaded shards,
        # evicting the oldest-loaded one
        if len(loaded_files) > 2:
            oldest_file = next(iter(loaded_files))
            del loaded_files[oldest_file]
            torch.cuda.empty_cache()
    # Update model index - remove all _scale_inv entries
    print("Updating model index...")
    new_weight_map = {}
    for weight_name, file_name in weight_map.items():
        if not weight_name.endswith("_scale_inv"):
            new_weight_map[weight_name] = file_name

    new_model_index = {
        "metadata": model_index.get("metadata", {}),
        "weight_map": new_weight_map,
    }
    new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
    with open(new_model_index_file, "w") as f:
        json.dump(new_model_index, f, indent=2)

    print(f"Conversion complete! Converted {len(fp8_weight_names)} FP8 weights to BF16")
    print(f"Output saved to: {bf16_path}")


if __name__ == "__main__":
    parser = ArgumentParser(description="Convert MiniMax-M2 from FP8 to BF16")
    parser.add_argument("--input-fp8-hf-path", type=str, required=True,
                        help="Path to the FP8 model directory")
    parser.add_argument("--output-bf16-hf-path", type=str, required=True,
                        help="Path to save the BF16 model")
    args = parser.parse_args()
    main(args.input_fp8_hf_path, args.output_bf16_hf_path)
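
# Example invocation (the script filename and both paths are placeholders):
#   python convert_fp8_to_bf16.py \
#       --input-fp8-hf-path /path/to/MiniMax-M2-FP8 \
#       --output-bf16-hf-path /path/to/MiniMax-M2-BF16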