# Adapted from:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py
# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
r"""Convert original LLaMA checkpoints into a single NumPy ``llama.npz`` archive.

The parameters are read from ``consolidated.XX.pth`` shards, renamed to the
Hugging Face style names (``model.layers.N.self_attn.q_proj.weight`` ...),
stitched back together when the checkpoint is sharded, and written with
``numpy.savez``.

Sample usage:

```
python <this_script>.py --input_dir /path/to/downloaded/llama/weights \
    --model_size 7B --output_dir /output/path
```
"""
import argparse
import gc
import json
import math
import os
import shutil
import warnings

import numpy as np
import torch

INTERMEDIATE_SIZE_MAP = {
    "7B": 11008,
    "13B": 13824,
    "30B": 17920,
    "65B": 22016,
}

# Number of ``consolidated.XX.pth`` files expected for each model size.
NUM_SHARDS = {
    "7B": 1,
    "13B": 2,
    "30B": 4,
    "65B": 8,
}


def compute_intermediate_size(n):
    """Return ``ceil(8 * n / 3)`` rounded up to the next multiple of 256."""
    return int(math.ceil(n * 8 / 3) + 255) // 256 * 256


def read_json(path):
    """Load and return the JSON document stored at *path*."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def write_json(text, path):
    """Serialize *text* (any JSON-serializable object) to *path*."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(text, f)


def _permute_for_rotary(w, n_heads, dim):
    """Permute a ``(dim, dim)`` q/k projection for sliced rotary embeddings."""
    return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim)


def _load_shards(input_base_path, num_shards):
    """Load every ``consolidated.XX.pth`` shard onto the CPU, in shard order."""
    # NOTE(security): torch.load unpickles the file; only run this on
    # checkpoints obtained from a trusted source.
    return [
        torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
        for i in range(num_shards)
    ]


def _merge_shards(shards, key, dim):
    """Return parameter *key* stitched across *shards* along axis *dim*."""
    if len(shards) == 1:
        # Unsharded checkpoint: hand back the tensor itself, no extra copy.
        return shards[0][key]
    return torch.cat([shard[key] for shard in shards], dim=dim)


def _to_numpy(tensor):
    """Convert a CPU tensor to a NumPy array, upcasting bfloat16 to float32."""
    if tensor.dtype == torch.bfloat16:
        # NumPy has no bfloat16 dtype; float32 represents every bfloat16 value.
        tensor = tensor.to(torch.float32)
    return tensor.numpy()


def write_model(model_path, input_base_path, model_size):
    """Convert a LLaMA checkpoint directory into ``<model_path>/llama.npz``.

    Args:
        model_path: Output directory; created if it does not exist.
        input_base_path: Directory holding ``params.json`` and the
            ``consolidated.XX.pth`` shard files.
        model_size: Key of ``NUM_SHARDS`` ("7B", "13B", "30B" or "65B");
            selects how many shard files are read.

    Raises:
        KeyError: If *model_size* is not a key of ``NUM_SHARDS`` or an
            expected parameter is missing from the checkpoint.
    """
    os.makedirs(model_path, exist_ok=True)

    params = read_json(os.path.join(input_base_path, "params.json"))
    num_shards = NUM_SHARDS[model_size]
    n_layers = params["n_layers"]
    n_heads = params["n_heads"]
    dim = params["dim"]
    dims_per_head = dim // n_heads
    base = 10000.0
    inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))

    print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
    shards = _load_shards(input_base_path, num_shards)

    converted = {}
    for layer_i in range(n_layers):
        src = f"layers.{layer_i}"
        dst = f"model.layers.{layer_i}"

        # wq/wk/wv/w1/w3 are split across shards by rows (axis 0), wo/w2 by
        # columns (axis 1). q and k additionally need the rotary permutation.
        converted[f"{dst}.self_attn.q_proj.weight"] = _permute_for_rotary(
            _merge_shards(shards, f"{src}.attention.wq.weight", dim=0), n_heads, dim
        )
        converted[f"{dst}.self_attn.k_proj.weight"] = _permute_for_rotary(
            _merge_shards(shards, f"{src}.attention.wk.weight", dim=0), n_heads, dim
        )
        converted[f"{dst}.self_attn.v_proj.weight"] = _merge_shards(
            shards, f"{src}.attention.wv.weight", dim=0
        )
        converted[f"{dst}.self_attn.o_proj.weight"] = _merge_shards(
            shards, f"{src}.attention.wo.weight", dim=1
        )
        converted[f"{dst}.mlp.gate_proj.weight"] = _merge_shards(
            shards, f"{src}.feed_forward.w1.weight", dim=0
        )
        converted[f"{dst}.mlp.down_proj.weight"] = _merge_shards(
            shards, f"{src}.feed_forward.w2.weight", dim=1
        )
        converted[f"{dst}.mlp.up_proj.weight"] = _merge_shards(
            shards, f"{src}.feed_forward.w3.weight", dim=0
        )
        # Norm weights are taken from the first shard only.
        converted[f"{dst}.input_layernorm.weight"] = shards[0][f"{src}.attention_norm.weight"]
        converted[f"{dst}.post_attention_layernorm.weight"] = shards[0][f"{src}.ffn_norm.weight"]
        converted[f"{dst}.self_attn.rotary_emb.inv_freq"] = inv_freq

    converted["model.embed_tokens.weight"] = _merge_shards(shards, "tok_embeddings.weight", dim=1)
    converted["model.norm.weight"] = shards[0]["norm.weight"]
    converted["lm_head.weight"] = _merge_shards(shards, "output.weight", dim=0)

    arrays = {name: _to_numpy(tensor) for name, tensor in converted.items()}
    np.savez(os.path.join(model_path, "llama.npz"), **arrays)


def main():
    """Parse the command line and run the conversion."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_dir",
        required=True,
        help="Location of LLaMA weights, which contains tokenizer.model and model folders",
    )
    parser.add_argument(
        "--model_size",
        required=True,
        choices=["7B", "13B", "30B", "65B"],
    )
    parser.add_argument(
        "--output_dir",
        required=True,
        help="Location to write the converted llama.npz file",
    )
    args = parser.parse_args()
    write_model(
        model_path=args.output_dir,
        input_base_path=os.path.join(args.input_dir, args.model_size),
        model_size=args.model_size,
    )


if __name__ == "__main__":
    main()