Resurrect the llama npy support. (#140)

Laurent Mazare
2023-07-11 19:32:10 +01:00
committed by GitHub
parent 760f1d7055
commit 37cad85869
6 changed files with 264 additions and 90 deletions


@@ -1,68 +1,199 @@
-# Adapted from https://github.com/Lightning-AI/lit-llama/blob/main/scripts/convert_checkpoint.py
-import sys
+# Adapted from:
+# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py
+# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
+import argparse
+import gc
+import json
+import math
+import os
+import shutil
+import warnings
 import torch
 import numpy as np
-from typing import Dict
-from pathlib import Path
-def tr(v):
-    return np.ascontiguousarray(np.transpose(v))
-def convert_state_dict(state_dict: Dict[str, torch.Tensor], dtype: torch.dtype = torch.float32) -> Dict[str, torch.Tensor]:
-    print("start conv")
-    def get_and_remove(key, transpose=False):
-        v = state_dict[key].to(dtype).numpy()
-        if transpose:
-            v = tr(v)
-        del state_dict[key]
-        return v
+"""
+Sample usage:
+```
+python src/transformers/models/llama/convert_llama_weights_to_hf.py \
+    --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path
+```
+"""
+INTERMEDIATE_SIZE_MAP = {
+    "7B": 11008,
+    "13B": 13824,
+    "30B": 17920,
+    "65B": 22016,
+}
+NUM_SHARDS = {
+    "7B": 1,
+    "13B": 2,
+    "30B": 4,
+    "65B": 8,
+}
-    converted = {}
-    converted["transformer.wte.weight"] = get_and_remove("tok_embeddings.weight")
-    converted["lm_head.weight"] = get_and_remove("output.weight", transpose=True)
-    converted["transformer.ln_f.scale"] = get_and_remove("norm.weight")
-    for layer_idx in sorted(set([k.split(".")[1] for k in state_dict if k.startswith("layers")])):
-        print(layer_idx)
+def compute_intermediate_size(n):
+    return int(math.ceil(n * 8 / 3) + 255) // 256 * 256
-        # attention
-        # the wq, wk, wv from the FB model are stacked in our model as c_attn
-        converted[f"transformer.h.{layer_idx}.attn.c_attn.weight"] = tr(np.concatenate(
-            (
-                get_and_remove(f"layers.{layer_idx}.attention.wq.weight"),
-                get_and_remove(f"layers.{layer_idx}.attention.wk.weight"),
-                get_and_remove(f"layers.{layer_idx}.attention.wv.weight"),
+def read_json(path):
+    with open(path, "r") as f:
+        return json.load(f)
+def write_json(text, path):
+    with open(path, "w") as f:
+        json.dump(text, f)
+def write_model(model_path, input_base_path, model_size):
+    os.makedirs(model_path, exist_ok=True)
+    params = read_json(os.path.join(input_base_path, "params.json"))
+    num_shards = NUM_SHARDS[model_size]
+    n_layers = params["n_layers"]
+    n_heads = params["n_heads"]
+    n_heads_per_shard = n_heads // num_shards
+    dim = params["dim"]
+    dims_per_head = dim // n_heads
+    base = 10000.0
+    inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
+    # permute for sliced rotary
+    def permute(w):
+        return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim)
+    print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
+    # Load weights
+    if model_size == "7B":
+        # Not sharded
+        # (The sharded implementation would also work, but this is simpler.)
+        loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu")
+    else:
+        # Sharded
+        loaded = [
+            torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
+            for i in range(num_shards)
+        ]
+    param_count = 0
+    all_dicts = {}
+    for layer_i in range(n_layers):
+        if model_size == "7B":
+            # Unsharded
+            state_dict = {
+                f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
+                    loaded[f"layers.{layer_i}.attention.wq.weight"]
+                ),
+                f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
+                    loaded[f"layers.{layer_i}.attention.wk.weight"]
+                ),
+                f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"],
+                f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"],
+                f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"],
+                f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"],
+                f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"],
+                f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"layers.{layer_i}.attention_norm.weight"],
+                f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"layers.{layer_i}.ffn_norm.weight"],
+            }
+        else:
+            # Sharded
+            # Note that attention.w{q,k,v,o}, feed_forward.w[1,2,3], attention_norm.weight and ffn_norm.weight share
+            # the same storage object; saving attention_norm and ffn_norm will save other weights too, which is
+            # redundant as other weights will be stitched from multiple shards. To avoid that, they are cloned.
+            state_dict = {
+                f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][
+                    f"layers.{layer_i}.attention_norm.weight"
+                ].clone(),
+                f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0][
+                    f"layers.{layer_i}.ffn_norm.weight"
+                ].clone(),
+            }
+            state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
+                torch.cat(
+                    [
+                        loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim)
+                        for i in range(num_shards)
+                    ],
+                    dim=0,
+                ).reshape(dim, dim)
+            )
-            )
-        ))
-        converted[f"transformer.h.{layer_idx}.attn.c_proj.weight"] = tr(get_and_remove(
-            f"layers.{layer_idx}.attention.wo.weight"
-        ))
-        # mlp
-        converted[f"transformer.h.{layer_idx}.mlp.c_fc1.weight"] = get_and_remove(
-            f"layers.{layer_idx}.feed_forward.w1.weight", transpose=True,
-        )
+            state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
+                torch.cat(
+                    [
+                        loaded[i][f"layers.{layer_i}.attention.wk.weight"].view(n_heads_per_shard, dims_per_head, dim)
+                        for i in range(num_shards)
+                    ],
+                    dim=0,
+                ).reshape(dim, dim)
+            )
-        converted[f"transformer.h.{layer_idx}.mlp.c_proj.weight"] = get_and_remove(
-            f"layers.{layer_idx}.feed_forward.w2.weight", transpose=True,
-        )
-        converted[f"transformer.h.{layer_idx}.mlp.c_fc2.weight"] = get_and_remove(
-            f"layers.{layer_idx}.feed_forward.w3.weight", transpose=True,
-        )
-        # rms norm
-        converted[f"transformer.h.{layer_idx}.rms_1.scale"] = get_and_remove(f"layers.{layer_idx}.attention_norm.weight")
-        converted[f"transformer.h.{layer_idx}.rms_2.scale"] = get_and_remove(f"layers.{layer_idx}.ffn_norm.weight")
-    return converted
state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
[
loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(n_heads_per_shard, dims_per_head, dim)
for i in range(num_shards)
],
dim=0,
).reshape(dim, dim)
state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
[loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1
)
state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
[loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0
)
state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
[loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1
)
state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
[loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0
)
state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
all_dicts |= state_dict
if model_size == "7B":
# Unsharded
state_dict = {
"model.embed_tokens.weight": loaded["tok_embeddings.weight"],
"model.norm.weight": loaded["norm.weight"],
"lm_head.weight": loaded["output.weight"],
}
else:
state_dict = {
"model.norm.weight": loaded[0]["norm.weight"],
"model.embed_tokens.weight": torch.cat(
[loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1
),
"lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0),
}
all_dicts |= state_dict
all_dicts = {k: v.numpy() for k, v in all_dicts.items()}
np.savez(os.path.join(model_path, "llama.npz"), **all_dicts)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--input_dir",
help="Location of LLaMA weights, which contains tokenizer.model and model folders",
)
parser.add_argument(
"--model_size",
choices=["7B", "13B", "30B", "65B"],
)
parser.add_argument(
"--output_dir",
help="Location to write HF model and tokenizer",
)
args = parser.parse_args()
write_model(
model_path=args.output_dir,
input_base_path=os.path.join(args.input_dir, args.model_size),
model_size=args.model_size,
)
-def convert_weights(llama_ckpt, *, output_npz: Path = Path("llama.npz"), dtype: str = "float32") -> None:
-    dt = getattr(torch, dtype, None)
-    if not isinstance(dt, torch.dtype):
-        raise ValueError(f"{dtype} is not a valid dtype.")
-    checkpoint = torch.load(llama_ckpt, map_location="cpu")
-    converted = convert_state_dict(checkpoint, dtype=dt)
-    del checkpoint
-    np.savez(output_npz, **converted)
 if __name__ == "__main__":
-    if len(sys.argv) != 2:
-        raise ValueError(f"usage: convert_checkpoint.py ..../LLaMA/7B/consolidated.00.pth")
-    convert_weights(sys.argv[1])
+    main()
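
One detail of the new converter worth calling out is the `permute` helper: the original checkpoints keep the rotary (cos, sin) pairs of `wq`/`wk` in adjacent rows, while the HF-style layout written to `llama.npz` expects the two halves of each head split apart, so the rows are regrouped before saving. A small standalone sketch (toy sizes chosen for illustration, not taken from the script) makes the reordering visible:

```python
import torch

# Toy sizes for illustration only; dims_per_head = dim // n_heads = 4.
n_heads, dim = 2, 8

def permute(w):
    # Same expression as inside write_model above.
    return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim)

# Tag every row of a (dim, dim) matrix with its original row index.
w = torch.arange(dim, dtype=torch.float32).unsqueeze(1).expand(dim, dim).contiguous()

# Within each head of 4 rows, [0, 1, 2, 3] becomes [0, 2, 1, 3]:
print(permute(w)[:, 0])  # tensor([0., 2., 1., 3., 4., 6., 5., 7.])
```

Both the unsharded and the sharded paths route `wq` and `wk` through this helper, so the archive always ends up in the half-split layout.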


@@ -144,8 +144,14 @@ fn main() -> Result<()> {
     let config = Config::config_7b();
     let cache = model::Cache::new(!args.no_kv_cache, &config, &device);
     let (llama, tokenizer_filename) = match args.npy {
-        Some(_) => {
-            todo!("fix numpy handling if we continue supporting it")
+        Some(filename) => {
+            let tensors = Tensor::read_npz(filename)?
+                .into_iter()
+                .map(|(n, t)| Ok((n, t.to_dtype(DTYPE)?)))
+                .collect::<Result<std::collections::HashMap<String, Tensor>>>()?;
+            let vb = VarBuilder::from_tensors(tensors, DTYPE, &device);
+            let tokenizer = std::path::PathBuf::from("llama-tokenizer.json");
+            (Llama::load(vb, &cache, &config)?, tokenizer)
         }
         None => {
             let api = Api::new()?;
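
On the Rust side, the example now reads every array in the archive back with `Tensor::read_npz`, casts it to the model dtype, and hands the resulting map to a `VarBuilder`; the tokenizer path is currently hard-coded to `llama-tokenizer.json`. Before running it, the generated archive can be sanity-checked from Python. A minimal sketch, assuming the 7B configuration (`dim = 4096`) and that `llama.npz` sits in the directory passed as `--output_dir`:

```python
import numpy as np

# Path assumption: llama.npz as written by write_model into --output_dir.
weights = np.load("llama.npz")

print(f"{len(weights.files)} tensors in the archive")
for name in sorted(weights.files)[:8]:
    print(name, weights[name].shape, weights[name].dtype)

# The names must match what the Rust loader expects, e.g. the attention
# projections saved above; for 7B they are square (dim, dim) matrices.
assert weights["model.layers.0.self_attn.q_proj.weight"].shape == (4096, 4096)
assert "model.embed_tokens.weight" in weights.files
assert "lm_head.weight" in weights.files
```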