import os
import json
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm

import torch
from safetensors.torch import load_file, save_file
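
# Overview: read model.safetensors.index.json from the FP8 checkpoint,
# dequantize every FP8 weight with its matching *_scale_inv tensor, write
# per-shard BF16 safetensors, and emit a new index without the scale entries.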


def weight_dequant_fp8(weight_fp8, scale_inv):
    """
    Dequantize FP8 weights to BF16 using scale_inv.

    Args:
        weight_fp8: FP8 tensor
        scale_inv: Inverse scale tensor (F32)

    Returns:
        BF16 tensor
    """
    # Convert FP8 to float32 first
    weight_f32 = weight_fp8.to(torch.float32)

    # Apply the inverse scales.
    # scale_inv holds one entry per quantization block, typically of shape
    # [out_features_blocks, in_features_blocks], and has to be broadcast
    # back up to the full weight shape.
    if scale_inv.dim() == 2:
        out_blocks, in_blocks = scale_inv.shape
        # The weight dims are assumed to be exact multiples of the block grid;
        # fail loudly instead of mis-scaling if that ever does not hold.
        assert weight_fp8.shape[0] % out_blocks == 0 and weight_fp8.shape[1] % in_blocks == 0, \
            f"weight shape {tuple(weight_fp8.shape)} does not divide scale_inv grid {tuple(scale_inv.shape)}"
        block_size_out = weight_fp8.shape[0] // out_blocks
        block_size_in = weight_fp8.shape[1] // in_blocks

        # Repeat each per-block scale over its block to match the weight shape
        scale_inv_expanded = scale_inv.repeat_interleave(block_size_out, dim=0)
        scale_inv_expanded = scale_inv_expanded.repeat_interleave(block_size_in, dim=1)

        weight_f32 = weight_f32 * scale_inv_expanded
    else:
        # Scalar or 1-D scale: plain broadcasting is enough
        weight_f32 = weight_f32 * scale_inv

    # Convert to BF16
    return weight_f32.to(torch.bfloat16)
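
# Quick sanity check (illustrative; float8_e4m3fn needs a reasonably recent PyTorch):
#   w = torch.randn(256, 256).to(torch.float8_e4m3fn)
#   s = torch.ones(2, 2)  # one scale per 128x128 block, all set to 1.0
#   assert weight_dequant_fp8(w, s).dtype == torch.bfloat16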


def main(fp8_path, bf16_path):
    torch.set_default_dtype(torch.bfloat16)
    os.makedirs(bf16_path, exist_ok=True)

    model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
    with open(model_index_file, "r") as f:
        model_index = json.load(f)

    weight_map = model_index["weight_map"]
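    # weight_map maps each tensor name to the shard file that stores it, e.g.
    # (illustrative entry):
    #   "model.layers.0.self_attn.q_proj.weight" -> "model-00001-of-000xx.safetensors"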

    # Cache for loaded safetensor files
    loaded_files = {}
    fp8_weight_names = []

    # Helper function to get tensor from the correct file
    def get_tensor(tensor_name):
        if tensor_name not in weight_map:
            return None
        file_name = weight_map[tensor_name]
        if file_name not in loaded_files:
            file_path = os.path.join(fp8_path, file_name)
            loaded_files[file_name] = load_file(file_path, device="cuda")
        return loaded_files[file_name][tensor_name]

    safetensor_files = sorted(glob(os.path.join(fp8_path, "*.safetensors")))

    print(f"Found {len(safetensor_files)} safetensor files to convert")

    for safetensor_file in tqdm(safetensor_files, desc="Converting files"):
        file_name = os.path.basename(safetensor_file)
        current_state_dict = load_file(safetensor_file, device="cuda")
        loaded_files[file_name] = current_state_dict

        new_state_dict = {}

        for weight_name, weight in current_state_dict.items():
            # Skip scale_inv tensors
            if weight_name.endswith("_scale_inv"):
                continue

            # Check if this is an FP8 weight; the element_size() == 1 fallback
            # catches FP8 tensors whose dtype is not reported as float8_e4m3fn
            if weight.dtype == torch.float8_e4m3fn or weight.element_size() == 1:
                scale_inv_name = f"{weight_name}_scale_inv"
                scale_inv = get_tensor(scale_inv_name)

                if scale_inv is not None:
                    fp8_weight_names.append(weight_name)
                    new_state_dict[weight_name] = weight_dequant_fp8(weight, scale_inv)
                else:
                    print(f"Warning: Missing scale_inv tensor for {weight_name}, keeping as-is")
                    new_state_dict[weight_name] = weight
            else:
                # Already BF16 or F32, keep as-is
                new_state_dict[weight_name] = weight

        # Save converted weights
        new_safetensor_file = os.path.join(bf16_path, file_name)
        save_file(new_state_dict, new_safetensor_file)

        # Memory management: keep at most the two most recently loaded shards
        # (get_tensor may have pulled in extras; dicts keep insertion order,
        # so the first key is always the oldest)
        while len(loaded_files) > 2:
            oldest_file = next(iter(loaded_files))
            del loaded_files[oldest_file]
            torch.cuda.empty_cache()

    # Update model index - remove all _scale_inv entries
    print("Updating model index...")
    new_weight_map = {}
    for weight_name, file_name in weight_map.items():
        if not weight_name.endswith("_scale_inv"):
            new_weight_map[weight_name] = file_name

    new_model_index = {
        "metadata": model_index.get("metadata", {}),
        "weight_map": new_weight_map
    }

    new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
    with open(new_model_index_file, "w") as f:
        json.dump(new_model_index, f, indent=2)

    print(f"Conversion complete! Converted {len(fp8_weight_names)} FP8 weights to BF16")
    print(f"Output saved to: {bf16_path}")


if __name__ == "__main__":
    parser = ArgumentParser(description="Convert MiniMax-M2 from FP8 to BF16")
    parser.add_argument("--input-fp8-hf-path", type=str, required=True,
                        help="Path to the FP8 model directory")
    parser.add_argument("--output-bf16-hf-path", type=str, required=True,
                        help="Path to save the BF16 model")
    args = parser.parse_args()

    main(args.input_fp8_hf_path, args.output_bf16_hf_path)
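
# Example invocation (script name and paths are illustrative):
#   python fp8_to_bf16.py \
#       --input-fp8-hf-path /path/to/MiniMax-M2-FP8 \
#       --output-bf16-hf-path /path/to/MiniMax-M2-BF16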