# Source file: candle/candle-examples/examples/llama/convert_checkpoint.py
# Snapshot dated 2023-07-11 19:32:10 +01:00 (200 lines, 7.5 KiB, Python).

# Adapted from:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py
# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
import argparse
import gc
import json
import math
import os
import shutil
import warnings
import torch
import numpy as np
"""
Sample usage (writes `llama.npz` into the output directory):
```
python convert_checkpoint.py \
--input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path
```
"""
# Lookup table from model-size label to an intermediate size.
# NOTE(review): not referenced anywhere else in the visible part of this file.
INTERMEDIATE_SIZE_MAP = {"7B": 11008, "13B": 13824, "30B": 17920, "65B": 22016}

# How many `consolidated.XX.pth` checkpoint files `write_model` reads for each
# model-size label.
NUM_SHARDS = {"7B": 1, "13B": 2, "30B": 4, "65B": 8}
def compute_intermediate_size(n, multiple_of=256):
    """Return ``ceil(8 * n / 3)`` rounded up to a multiple of ``multiple_of``.

    Args:
        n: Base size to scale (presumably the model hidden dimension --
            confirm against callers).
        multiple_of: Granularity the result is rounded up to. Defaults to
            256, which reproduces the previously hard-coded behaviour.

    Returns:
        The smallest multiple of ``multiple_of`` that is greater than or
        equal to ``ceil(8 * n / 3)``, as an ``int``.
    """
    scaled = math.ceil(n * 8 / 3)
    # Adding (multiple_of - 1) before the floor division rounds up instead of down.
    return int(scaled + multiple_of - 1) // multiple_of * multiple_of
def read_json(path):
    """Load and return the JSON document stored at ``path``.

    Args:
        path: Filesystem path of the JSON file to read.

    Returns:
        The decoded JSON value (whatever ``json.load`` produces).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Decode explicitly as UTF-8 rather than relying on the platform's
    # locale-dependent default encoding.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
def write_json(text, path):
    """Serialize ``text`` as JSON and write it to ``path``.

    Args:
        text: The JSON-serializable object to dump (despite the name it is
            passed straight to ``json.dump``, so it need not be a string).
        path: Filesystem path of the file to create or overwrite.

    Raises:
        OSError: If the file cannot be opened for writing.
        TypeError: If ``text`` contains a value ``json.dump`` cannot encode.
    """
    # Pin the encoding so the output does not depend on the platform's
    # locale-dependent default.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(text, f)
def write_model(model_path, input_base_path, model_size):
    """Convert a LLaMA checkpoint directory into a single ``llama.npz`` file.

    Reads ``params.json`` and the ``consolidated.XX.pth`` shard file(s) from
    ``input_base_path``, renames the tensors to ``model.*`` / ``lm_head.*``
    keys, stitches sharded weights back together, and saves everything with
    ``np.savez`` to ``<model_path>/llama.npz``.

    Args:
        model_path: Output directory; created if it does not exist.
        input_base_path: Directory holding ``params.json`` and the
            ``consolidated.XX.pth`` file(s).
        model_size: Key into ``NUM_SHARDS`` (e.g. ``"7B"``); selects how many
            shard files are read. ``"7B"`` takes the unsharded code path.

    Raises:
        KeyError: If ``model_size`` is not in ``NUM_SHARDS``, or if
            ``params.json`` / a checkpoint lacks an expected key.
        OSError: If an input file cannot be read or the output cannot be
            written.
    """
    os.makedirs(model_path, exist_ok=True)

    params = read_json(os.path.join(input_base_path, "params.json"))
    num_shards = NUM_SHARDS[model_size]
    n_layers = params["n_layers"]
    n_heads = params["n_heads"]
    n_heads_per_shard = n_heads // num_shards
    dim = params["dim"]
    dims_per_head = dim // n_heads
    base = 10000.0
    inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))

    def permute(w):
        # Permute for sliced rotary: view each head's rows as
        # (dims_per_head // 2, 2), swap those two axes, and flatten back
        # to (dim, dim).
        return w.view(n_heads, dims_per_head // 2, 2, dim).transpose(1, 2).reshape(dim, dim)

    def load_shard(index):
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only run this on checkpoints from a trusted source.
        shard_path = os.path.join(input_base_path, f"consolidated.{index:02d}.pth")
        return torch.load(shard_path, map_location="cpu")

    print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
    # Load weights
    unsharded = model_size == "7B"
    if unsharded:
        # Not sharded
        # (The sharded implementation would also work, but this is simpler.)
        loaded = load_shard(0)
    else:
        # Sharded
        loaded = [load_shard(i) for i in range(num_shards)]

    def cat_shards(key, axis):
        # Concatenate the per-shard slices of one tensor along `axis`.
        return torch.cat([shard[key] for shard in loaded], dim=axis)

    def cat_heads(key):
        # Attention q/k/v weights: view every shard's slice as
        # (heads, head_dim, dim), stack the heads of all shards, and flatten
        # back to (dim, dim).
        per_shard = [shard[key].view(n_heads_per_shard, dims_per_head, dim) for shard in loaded]
        return torch.cat(per_shard, dim=0).reshape(dim, dim)

    all_dicts = {}
    for layer_i in range(n_layers):
        src = f"layers.{layer_i}"
        dst = f"model.layers.{layer_i}"
        if unsharded:
            # Unsharded
            state_dict = {
                f"{dst}.self_attn.q_proj.weight": permute(loaded[f"{src}.attention.wq.weight"]),
                f"{dst}.self_attn.k_proj.weight": permute(loaded[f"{src}.attention.wk.weight"]),
                f"{dst}.self_attn.v_proj.weight": loaded[f"{src}.attention.wv.weight"],
                f"{dst}.self_attn.o_proj.weight": loaded[f"{src}.attention.wo.weight"],
                f"{dst}.mlp.gate_proj.weight": loaded[f"{src}.feed_forward.w1.weight"],
                f"{dst}.mlp.down_proj.weight": loaded[f"{src}.feed_forward.w2.weight"],
                f"{dst}.mlp.up_proj.weight": loaded[f"{src}.feed_forward.w3.weight"],
                f"{dst}.input_layernorm.weight": loaded[f"{src}.attention_norm.weight"],
                f"{dst}.post_attention_layernorm.weight": loaded[f"{src}.ffn_norm.weight"],
            }
        else:
            # Sharded
            # Note that attention.w{q,k,v,o}, feed_forward.w[1,2,3], attention_norm.weight and
            # ffn_norm.weight share the same storage object, saving attention_norm and ffn_norm
            # will save other weights too, which is redundant as other weights will be stitched
            # from multiple shards. To avoid that, they are cloned.
            state_dict = {
                f"{dst}.input_layernorm.weight": loaded[0][f"{src}.attention_norm.weight"].clone(),
                f"{dst}.post_attention_layernorm.weight": loaded[0][f"{src}.ffn_norm.weight"].clone(),
                f"{dst}.self_attn.q_proj.weight": permute(cat_heads(f"{src}.attention.wq.weight")),
                f"{dst}.self_attn.k_proj.weight": permute(cat_heads(f"{src}.attention.wk.weight")),
                f"{dst}.self_attn.v_proj.weight": cat_heads(f"{src}.attention.wv.weight"),
                f"{dst}.self_attn.o_proj.weight": cat_shards(f"{src}.attention.wo.weight", 1),
                f"{dst}.mlp.gate_proj.weight": cat_shards(f"{src}.feed_forward.w1.weight", 0),
                f"{dst}.mlp.down_proj.weight": cat_shards(f"{src}.feed_forward.w2.weight", 1),
                f"{dst}.mlp.up_proj.weight": cat_shards(f"{src}.feed_forward.w3.weight", 0),
            }
        state_dict[f"{dst}.self_attn.rotary_emb.inv_freq"] = inv_freq
        all_dicts.update(state_dict)

    if unsharded:
        # Unsharded
        state_dict = {
            "model.embed_tokens.weight": loaded["tok_embeddings.weight"],
            "model.norm.weight": loaded["norm.weight"],
            "lm_head.weight": loaded["output.weight"],
        }
    else:
        state_dict = {
            "model.norm.weight": loaded[0]["norm.weight"],
            "model.embed_tokens.weight": cat_shards("tok_embeddings.weight", 1),
            "lm_head.weight": cat_shards("output.weight", 0),
        }
    all_dicts.update(state_dict)

    def to_numpy(tensor):
        # NumPy has no bfloat16 dtype, so Tensor.numpy() raises on such
        # tensors; widening to float32 first is lossless.
        if tensor.dtype == torch.bfloat16:
            tensor = tensor.to(torch.float32)
        return tensor.numpy()

    all_dicts = {k: to_numpy(v) for k, v in all_dicts.items()}
    np.savez(os.path.join(model_path, "llama.npz"), **all_dicts)
def main():
    """Parse the command line and convert the selected LLaMA checkpoint.

    The checkpoint is read from ``<input_dir>/<model_size>`` and the result
    is written into ``output_dir`` by ``write_model``. All three options are
    mandatory; argparse exits with a usage error if one is missing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_dir",
        required=True,
        help="Location of LLaMA weights, which contains tokenizer.model and model folders",
    )
    parser.add_argument(
        "--model_size",
        required=True,
        choices=["7B", "13B", "30B", "65B"],
    )
    parser.add_argument(
        "--output_dir",
        required=True,
        help="Location to write the converted llama.npz file",
    )
    args = parser.parse_args()
    write_model(
        model_path=args.output_dir,
        # Each model size lives in its own sub-directory of the input dir.
        input_base_path=os.path.join(args.input_dir, args.model_size),
        model_size=args.model_size,
    )


if __name__ == "__main__":
    main()