Resurrect the llama npy support. (#140)

Laurent Mazare
2023-07-11 19:32:10 +01:00
committed by GitHub
parent 760f1d7055
commit 37cad85869
6 changed files with 264 additions and 90 deletions

View File

@@ -139,6 +139,9 @@ pub enum Error {
         rhs_stride: Vec<usize>,
         mnk: (usize, usize, usize),
     },
+    #[error("cannot find tensor {path}")]
+    CannotFindTensor { path: String },
 }
 
 pub type Result<T> = std::result::Result<T, Error>;

View File

@@ -1,10 +1,10 @@
-//! Numpy support for literals.
+//! Numpy support for tensors.
 //!
 //! The spec for the npy format can be found in
 //! [npy-format](https://docs.scipy.org/doc/numpy-1.14.2/neps/npy-format.html).
-//! The functions from this module can be used to read literals from npy/npz files
-//! or write literals to these files. A npy file contains a single literal (unnamed)
-//! whereas a npz file can contain multiple named literals. npz files are also compressed.
+//! The functions from this module can be used to read tensors from npy/npz files
+//! or write tensors to these files. A npy file contains a single tensor (unnamed)
+//! whereas a npz file can contain multiple named tensors. npz files are also compressed.
 //!
 //! These two formats are easy to use in Python using the numpy library.
 //!
@@ -232,7 +232,7 @@ impl Tensor {
         }
     }
 
-    /// Reads a npy file and return the stored multi-dimensional array as a literal.
+    /// Reads a npy file and return the stored multi-dimensional array as a tensor.
     pub fn read_npy<T: AsRef<Path>>(path: T) -> Result<Self> {
         let mut reader = File::open(path.as_ref())?;
         let header = read_header(&mut reader)?;
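As a minimal sketch of the reading API this doc comment describes (assuming the crate is referenced as `candle` and using the `Tensor::read_npy` signature shown in this hunk; the file name is a placeholder):

```rust
use candle::{Result, Tensor};

fn main() -> Result<()> {
    // A .npy file holds a single unnamed tensor, e.g. one written from Python
    // with `numpy.save("single.npy", arr)`.
    let t = Tensor::read_npy("single.npy")?;
    println!("loaded a tensor with shape {:?}", t.shape());
    Ok(())
}
```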

View File

@@ -10,3 +10,10 @@ pub fn get_num_threads() -> usize {
         Some(_) | None => num_cpus::get(),
     }
 }
+
+pub fn has_mkl() -> bool {
+    #[cfg(feature = "mkl")]
+    return true;
+    #[cfg(not(feature = "mkl"))]
+    return false;
+}
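A possible call site for the new helper, as a sketch only (the `candle::utils::has_mkl` path is an assumption about how this module is exposed):

```rust
fn main() {
    // has_mkl() compiles down to a constant true/false depending on the `mkl` feature.
    if candle::utils::has_mkl() {
        println!("this build was compiled with MKL support");
    } else {
        println!("this build uses the default kernels");
    }
}
```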

View File

@@ -1,68 +1,199 @@
-# Adapted from https://github.com/Lightning-AI/lit-llama/blob/main/scripts/convert_checkpoint.py
-import sys
-
-import torch
-import numpy as np
-from typing import Dict
-from pathlib import Path
-
-
-def tr(v):
-    return np.ascontiguousarray(np.transpose(v))
-
-
-def convert_state_dict(state_dict: Dict[str, torch.Tensor], dtype: torch.dtype = torch.float32) -> Dict[str, torch.Tensor]:
-    print("start conv")
-
-    def get_and_remove(key, transpose=False):
-        v = state_dict[key].to(dtype).numpy()
-        if transpose:
-            v = tr(v)
-        del state_dict[key]
-        return v
-
-    converted = {}
-    converted["transformer.wte.weight"] = get_and_remove("tok_embeddings.weight")
-    converted["lm_head.weight"] = get_and_remove("output.weight", transpose=True)
-    converted["transformer.ln_f.scale"] = get_and_remove("norm.weight")
-
-    for layer_idx in sorted(set([k.split(".")[1] for k in state_dict if k.startswith("layers")])):
-        print(layer_idx)
-        # attention
-        # the wq, wk, wv from the FB model are stacked in our model as c_attn
-        converted[f"transformer.h.{layer_idx}.attn.c_attn.weight"] = tr(np.concatenate(
-            (
-                get_and_remove(f"layers.{layer_idx}.attention.wq.weight"),
-                get_and_remove(f"layers.{layer_idx}.attention.wk.weight"),
-                get_and_remove(f"layers.{layer_idx}.attention.wv.weight"),
-            )
-        ))
-        converted[f"transformer.h.{layer_idx}.attn.c_proj.weight"] = tr(get_and_remove(
-            f"layers.{layer_idx}.attention.wo.weight"
-        ))
-        # mlp
-        converted[f"transformer.h.{layer_idx}.mlp.c_fc1.weight"] = get_and_remove(
-            f"layers.{layer_idx}.feed_forward.w1.weight", transpose=True,
-        )
-        converted[f"transformer.h.{layer_idx}.mlp.c_proj.weight"] = get_and_remove(
-            f"layers.{layer_idx}.feed_forward.w2.weight", transpose=True,
-        )
-        converted[f"transformer.h.{layer_idx}.mlp.c_fc2.weight"] = get_and_remove(
-            f"layers.{layer_idx}.feed_forward.w3.weight", transpose=True,
-        )
-        # rms norm
-        converted[f"transformer.h.{layer_idx}.rms_1.scale"] = get_and_remove(f"layers.{layer_idx}.attention_norm.weight")
-        converted[f"transformer.h.{layer_idx}.rms_2.scale"] = get_and_remove(f"layers.{layer_idx}.ffn_norm.weight")
-    return converted
-
-
-def convert_weights(llama_ckpt, *, output_npz: Path = Path("llama.npz"), dtype: str = "float32") -> None:
-    dt = getattr(torch, dtype, None)
-    if not isinstance(dt, torch.dtype):
-        raise ValueError(f"{dtype} is not a valid dtype.")
-    checkpoint = torch.load(llama_ckpt, map_location="cpu")
-    converted = convert_state_dict(checkpoint, dtype=dt)
-    del checkpoint
-    np.savez(output_npz, **converted)
-
-
-if __name__ == "__main__":
-    if len(sys.argv) != 2:
-        raise ValueError(f"usage: convert_checkpoint.py ..../LLaMA/7B/consolidated.00.pth")
-    convert_weights(sys.argv[1])
+# Adapted from:
+# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py
+# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
+import argparse
+import gc
+import json
+import math
+import os
+import shutil
+import warnings
+
+import torch
+import numpy as np
+
+"""
+Sample usage:
+
+```
+python src/transformers/models/llama/convert_llama_weights_to_hf.py \
+    --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path
+```
+"""
+
+INTERMEDIATE_SIZE_MAP = {
+    "7B": 11008,
+    "13B": 13824,
+    "30B": 17920,
+    "65B": 22016,
+}
+NUM_SHARDS = {
+    "7B": 1,
+    "13B": 2,
+    "30B": 4,
+    "65B": 8,
+}
+
+
+def compute_intermediate_size(n):
+    return int(math.ceil(n * 8 / 3) + 255) // 256 * 256
+
+
+def read_json(path):
+    with open(path, "r") as f:
+        return json.load(f)
+
+
+def write_json(text, path):
+    with open(path, "w") as f:
+        json.dump(text, f)
+
+
+def write_model(model_path, input_base_path, model_size):
+    os.makedirs(model_path, exist_ok=True)
+
+    params = read_json(os.path.join(input_base_path, "params.json"))
+    num_shards = NUM_SHARDS[model_size]
+    n_layers = params["n_layers"]
+    n_heads = params["n_heads"]
+    n_heads_per_shard = n_heads // num_shards
+    dim = params["dim"]
+    dims_per_head = dim // n_heads
+    base = 10000.0
+    inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
+
+    # permute for sliced rotary
+    def permute(w):
+        return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim)
+
+    print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
+    # Load weights
+    if model_size == "7B":
+        # Not sharded
+        # (The sharded implementation would also work, but this is simpler.)
+        loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu")
+    else:
+        # Sharded
+        loaded = [
+            torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
+            for i in range(num_shards)
+        ]
+    param_count = 0
+    all_dicts = {}
+    for layer_i in range(n_layers):
+        if model_size == "7B":
+            # Unsharded
+            state_dict = {
+                f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
+                    loaded[f"layers.{layer_i}.attention.wq.weight"]
+                ),
+                f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
+                    loaded[f"layers.{layer_i}.attention.wk.weight"]
+                ),
+                f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"],
+                f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"],
+                f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"],
+                f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"],
+                f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"],
+                f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"layers.{layer_i}.attention_norm.weight"],
+                f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"layers.{layer_i}.ffn_norm.weight"],
+            }
+        else:
+            # Sharded
+            # Note that attention.w{q,k,v,o}, feed_fordward.w[1,2,3], attention_norm.weight and ffn_norm.weight share
+            # the same storage object, saving attention_norm and ffn_norm will save other weights too, which is
+            # redundant as other weights will be stitched from multiple shards. To avoid that, they are cloned.
+            state_dict = {
+                f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][
+                    f"layers.{layer_i}.attention_norm.weight"
+                ].clone(),
+                f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0][
+                    f"layers.{layer_i}.ffn_norm.weight"
+                ].clone(),
+            }
+            state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
+                torch.cat(
+                    [
+                        loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim)
+                        for i in range(num_shards)
+                    ],
+                    dim=0,
+                ).reshape(dim, dim)
+            )
+            state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
+                torch.cat(
+                    [
+                        loaded[i][f"layers.{layer_i}.attention.wk.weight"].view(n_heads_per_shard, dims_per_head, dim)
+                        for i in range(num_shards)
+                    ],
+                    dim=0,
+                ).reshape(dim, dim)
+            )
+            state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
+                [
+                    loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(n_heads_per_shard, dims_per_head, dim)
+                    for i in range(num_shards)
+                ],
+                dim=0,
+            ).reshape(dim, dim)
+
+            state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
+                [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1
+            )
+            state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
+                [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0
+            )
+            state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
+                [loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1
+            )
+            state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
+                [loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0
+            )
+
+        state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
+        all_dicts |= state_dict
+
+    if model_size == "7B":
+        # Unsharded
+        state_dict = {
+            "model.embed_tokens.weight": loaded["tok_embeddings.weight"],
+            "model.norm.weight": loaded["norm.weight"],
+            "lm_head.weight": loaded["output.weight"],
+        }
+    else:
+        state_dict = {
+            "model.norm.weight": loaded[0]["norm.weight"],
+            "model.embed_tokens.weight": torch.cat(
+                [loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1
+            ),
+            "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0),
+        }
+    all_dicts |= state_dict
+    all_dicts = {k: v.numpy() for k, v in all_dicts.items()}
+    np.savez(os.path.join(model_path, "llama.npz"), **all_dicts)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--input_dir",
+        help="Location of LLaMA weights, which contains tokenizer.model and model folders",
+    )
+    parser.add_argument(
+        "--model_size",
+        choices=["7B", "13B", "30B", "65B"],
+    )
+    parser.add_argument(
+        "--output_dir",
+        help="Location to write HF model and tokenizer",
+    )
+    args = parser.parse_args()
+    write_model(
+        model_path=args.output_dir,
+        input_base_path=os.path.join(args.input_dir, args.model_size),
+        model_size=args.model_size,
+    )
+
+
+if __name__ == "__main__":
+    main()
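Once the script has produced `llama.npz`, its contents can be sanity-checked from Rust before wiring it into the example below. A minimal sketch, assuming the crate is referenced as `candle` and relying on the same `Tensor::read_npz` call that the example change below uses:

```rust
use candle::{Result, Tensor};

fn main() -> Result<()> {
    // Print every tensor stored in the converted archive together with its shape and dtype,
    // e.g. `model.layers.0.self_attn.q_proj.weight`.
    for (name, tensor) in Tensor::read_npz("llama.npz")? {
        println!("{name}: {:?} {:?}", tensor.shape(), tensor.dtype());
    }
    Ok(())
}
```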

View File

@@ -144,8 +144,14 @@ fn main() -> Result<()> {
     let config = Config::config_7b();
     let cache = model::Cache::new(!args.no_kv_cache, &config, &device);
     let (llama, tokenizer_filename) = match args.npy {
-        Some(_) => {
-            todo!("fix numpy handling if we continue supporting it")
+        Some(filename) => {
+            let tensors = Tensor::read_npz(filename)?
+                .into_iter()
+                .map(|(n, t)| Ok((n, t.to_dtype(DTYPE)?)))
+                .collect::<Result<std::collections::HashMap<String, Tensor>>>()?;
+            let vb = VarBuilder::from_tensors(tensors, DTYPE, &device);
+            let tokenizer = std::path::PathBuf::from("llama-tokenizer.json");
+            (Llama::load(vb, &cache, &config)?, tokenizer)
         }
         None => {
             let api = Api::new()?;

View File

@@ -1,15 +1,20 @@
-use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
+use candle::{safetensors::SafeTensors, DType, Device, Error, Shape, Tensor};
 use std::collections::HashMap;
 use std::sync::Arc;
 
-struct SafeTensorWithRouting<'a> {
-    routing: HashMap<String, usize>,
-    safetensors: Vec<SafeTensors<'a>>,
+// TODO: Maybe we would want the storage to be generic, e.g. with Box<dyn> to avoid too many
+// generics.
+enum Tensors<'a> {
+    SafeTensorWithRouting {
+        routing: HashMap<String, usize>,
+        safetensors: Vec<SafeTensors<'a>>,
+    },
+    TensorMap(HashMap<String, Tensor>),
+    Zeros,
 }
 
 struct TensorData<'a> {
-    // TODO: Make this part generic, probably via some Box<dyn> to avoid too much generics.
-    safetensors: Option<SafeTensorWithRouting<'a>>,
+    tensors: Tensors<'a>,
     pub dtype: DType,
     pub device: Device,
 }
@@ -22,12 +27,12 @@ impl<'a> TensorData<'a> {
                 routing.insert(k.to_string(), index);
             }
         }
-        let safetensors = SafeTensorWithRouting {
+        let tensors = Tensors::SafeTensorWithRouting {
            routing,
            safetensors,
        };
        Self {
-            safetensors: Some(safetensors),
+            tensors,
            device: device.clone(),
            dtype,
        }
@@ -35,7 +40,15 @@ impl<'a> TensorData<'a> {
     fn zeros(dtype: DType, device: &Device) -> Self {
         Self {
-            safetensors: None,
+            tensors: Tensors::Zeros,
+            device: device.clone(),
+            dtype,
+        }
+    }
+
+    fn from_tensors(tensors: HashMap<String, Tensor>, dtype: DType, device: &Device) -> Self {
+        Self {
+            tensors: Tensors::TensorMap(tensors),
             device: device.clone(),
             dtype,
         }
     }
@@ -67,6 +80,14 @@ impl<'a> VarBuilder<'a> {
         }
     }
 
+    pub fn from_tensors(ts: HashMap<String, Tensor>, dtype: DType, device: &Device) -> Self {
+        let data = TensorData::from_tensors(ts, dtype, device);
+        Self {
+            data: Arc::new(data),
+            path: vec![],
+        }
+    }
+
     pub fn push_prefix(&self, s: &str) -> Self {
         let mut path = self.path.clone();
         path.push(s.to_string());
@@ -94,23 +115,31 @@ impl<'a> VarBuilder<'a> {
     pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> candle::Result<Tensor> {
         let data = self.data.as_ref();
         let s: Shape = s.into();
-        match &self.data.safetensors {
-            None => Tensor::zeros(s, data.dtype, &data.device),
-            Some(SafeTensorWithRouting {
-                routing,
-                safetensors,
-            }) => {
-                let path = if self.path.is_empty() {
-                    tensor_name.to_string()
-                } else {
-                    [&self.path.join("."), tensor_name].join(".")
-                };
-                // Unwrap or 0 just to let the proper error flow.
-                let index = routing.get(&path).unwrap_or(&0);
-                let tensor = safetensors[*index]
-                    .tensor(&path, &data.device)?
-                    .to_dtype(data.dtype)?;
-                if *tensor.shape() != s {
+        let path = if self.path.is_empty() {
+            tensor_name.to_string()
+        } else {
+            [&self.path.join("."), tensor_name].join(".")
+        };
+        let tensor = match &self.data.tensors {
+            Tensors::Zeros => Tensor::zeros(&s, data.dtype, &data.device)?.contiguous()?,
+            Tensors::TensorMap(ts) => ts
+                .get(&path)
+                .ok_or_else(|| Error::CannotFindTensor {
+                    path: path.to_string(),
+                })?
+                .clone(),
+            Tensors::SafeTensorWithRouting {
+                routing,
+                safetensors,
+            } => {
+                // Unwrap or 0 just to let the proper error flow.
+                let index = routing.get(&path).unwrap_or(&0);
+                safetensors[*index]
+                    .tensor(&path, &data.device)?
+                    .to_dtype(data.dtype)?
+            }
+        };
+        if tensor.shape() != &s {
             Err(candle::Error::UnexpectedShape {
                 msg: format!("shape mismatch for {path}"),
                 expected: s,
@@ -119,6 +148,4 @@ impl<'a> VarBuilder<'a> {
         }
         Ok(tensor)
     }
-        }
-    }
 }
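Putting the pieces of this file together, here is a small sketch of the new tensor-map path (assuming the `VarBuilder` defined above is in scope and the crate is referenced as `candle`; the tensor names and shapes are made up for illustration):

```rust
use candle::{DType, Device, Tensor};
use std::collections::HashMap;

fn demo() -> candle::Result<()> {
    let device = Device::Cpu;
    // An in-memory map of named tensors, e.g. what Tensor::read_npz produces.
    let mut tensors = HashMap::new();
    tensors.insert("layer.weight".to_string(), Tensor::zeros((2, 3), DType::F32, &device)?);

    let vb = VarBuilder::from_tensors(tensors, DType::F32, &device);
    // Prefixes are joined with '.', so this resolves to the "layer.weight" entry.
    let w = vb.push_prefix("layer").get((2, 3), "weight")?;
    assert_eq!(w.dims(), &[2, 3]);

    // A name that is not in the map surfaces as the new Error::CannotFindTensor.
    assert!(vb.get((2, 3), "missing.weight").is_err());
    Ok(())
}
```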