mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Resurrect the llama npy support. (#140)
This commit is contained in:
@ -139,6 +139,9 @@ pub enum Error {
|
||||
rhs_stride: Vec<usize>,
|
||||
mnk: (usize, usize, usize),
|
||||
},
|
||||
|
||||
#[error("cannot find tensor {path}")]
|
||||
CannotFindTensor { path: String },
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
@ -1,10 +1,10 @@
|
||||
//! Numpy support for literals.
|
||||
//! Numpy support for tensors.
|
||||
//!
|
||||
//! The spec for the npy format can be found in
|
||||
//! [npy-format](https://docs.scipy.org/doc/numpy-1.14.2/neps/npy-format.html).
|
||||
//! The functions from this module can be used to read literals from npy/npz files
|
||||
//! or write literals to these files. A npy file contains a single literal (unnamed)
|
||||
//! whereas a npz file can contain multiple named literals. npz files are also compressed.
|
||||
//! The functions from this module can be used to read tensors from npy/npz files
|
||||
//! or write tensors to these files. A npy file contains a single tensor (unnamed)
|
||||
//! whereas a npz file can contain multiple named tensors. npz files are also compressed.
|
||||
//!
|
||||
//! These two formats are easy to use in Python using the numpy library.
|
||||
//!
|
||||
@ -232,7 +232,7 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Reads a npy file and return the stored multi-dimensional array as a literal.
|
||||
/// Reads a npy file and return the stored multi-dimensional array as a tensor.
|
||||
pub fn read_npy<T: AsRef<Path>>(path: T) -> Result<Self> {
|
||||
let mut reader = File::open(path.as_ref())?;
|
||||
let header = read_header(&mut reader)?;
|
||||
|
@ -10,3 +10,10 @@ pub fn get_num_threads() -> usize {
|
||||
Some(_) | None => num_cpus::get(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn has_mkl() -> bool {
|
||||
#[cfg(feature = "mkl")]
|
||||
return true;
|
||||
#[cfg(not(feature = "mkl"))]
|
||||
return false;
|
||||
}
|
||||
|
@ -1,68 +1,199 @@
|
||||
# Adapted from https://github.com/Lightning-AI/lit-llama/blob/main/scripts/convert_checkpoint.py
|
||||
import sys
|
||||
# Adapted from:
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py
|
||||
# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from typing import Dict
|
||||
from pathlib import Path
|
||||
|
||||
def tr(v):
|
||||
return np.ascontiguousarray(np.transpose(v))
|
||||
"""
|
||||
Sample usage:
|
||||
|
||||
def convert_state_dict(state_dict: Dict[str, torch.Tensor], dtype: torch.dtype = torch.float32) -> Dict[str, torch.Tensor]:
|
||||
print("start conv")
|
||||
```
|
||||
python src/transformers/models/llama/convert_llama_weights_to_hf.py \
|
||||
--input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path
|
||||
```
|
||||
"""
|
||||
|
||||
def get_and_remove(key, transpose=False):
|
||||
v = state_dict[key].to(dtype).numpy()
|
||||
if transpose:
|
||||
v = tr(v)
|
||||
del state_dict[key]
|
||||
return v
|
||||
INTERMEDIATE_SIZE_MAP = {
|
||||
"7B": 11008,
|
||||
"13B": 13824,
|
||||
"30B": 17920,
|
||||
"65B": 22016,
|
||||
}
|
||||
NUM_SHARDS = {
|
||||
"7B": 1,
|
||||
"13B": 2,
|
||||
"30B": 4,
|
||||
"65B": 8,
|
||||
}
|
||||
|
||||
converted = {}
|
||||
converted["transformer.wte.weight"] = get_and_remove("tok_embeddings.weight")
|
||||
converted["lm_head.weight"] = get_and_remove("output.weight", transpose=True)
|
||||
converted["transformer.ln_f.scale"] = get_and_remove("norm.weight")
|
||||
|
||||
for layer_idx in sorted(set([k.split(".")[1] for k in state_dict if k.startswith("layers")])):
|
||||
print(layer_idx)
|
||||
def compute_intermediate_size(n):
|
||||
return int(math.ceil(n * 8 / 3) + 255) // 256 * 256
|
||||
|
||||
# attention
|
||||
# the wq, wk, wv from the FB model are stacked in our model as c_attn
|
||||
converted[f"transformer.h.{layer_idx}.attn.c_attn.weight"] = tr(np.concatenate(
|
||||
(
|
||||
get_and_remove(f"layers.{layer_idx}.attention.wq.weight"),
|
||||
get_and_remove(f"layers.{layer_idx}.attention.wk.weight"),
|
||||
get_and_remove(f"layers.{layer_idx}.attention.wv.weight"),
|
||||
|
||||
def read_json(path):
|
||||
with open(path, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def write_json(text, path):
|
||||
with open(path, "w") as f:
|
||||
json.dump(text, f)
|
||||
|
||||
|
||||
def write_model(model_path, input_base_path, model_size):
|
||||
os.makedirs(model_path, exist_ok=True)
|
||||
|
||||
params = read_json(os.path.join(input_base_path, "params.json"))
|
||||
num_shards = NUM_SHARDS[model_size]
|
||||
n_layers = params["n_layers"]
|
||||
n_heads = params["n_heads"]
|
||||
n_heads_per_shard = n_heads // num_shards
|
||||
dim = params["dim"]
|
||||
dims_per_head = dim // n_heads
|
||||
base = 10000.0
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
|
||||
|
||||
# permute for sliced rotary
|
||||
def permute(w):
|
||||
return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim)
|
||||
|
||||
print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
|
||||
# Load weights
|
||||
if model_size == "7B":
|
||||
# Not sharded
|
||||
# (The sharded implementation would also work, but this is simpler.)
|
||||
loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu")
|
||||
else:
|
||||
# Sharded
|
||||
loaded = [
|
||||
torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
|
||||
for i in range(num_shards)
|
||||
]
|
||||
param_count = 0
|
||||
all_dicts = {}
|
||||
for layer_i in range(n_layers):
|
||||
if model_size == "7B":
|
||||
# Unsharded
|
||||
state_dict = {
|
||||
f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
|
||||
loaded[f"layers.{layer_i}.attention.wq.weight"]
|
||||
),
|
||||
f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
|
||||
loaded[f"layers.{layer_i}.attention.wk.weight"]
|
||||
),
|
||||
f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"],
|
||||
f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"],
|
||||
f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"],
|
||||
f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"],
|
||||
f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"],
|
||||
f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"layers.{layer_i}.attention_norm.weight"],
|
||||
f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"layers.{layer_i}.ffn_norm.weight"],
|
||||
}
|
||||
else:
|
||||
# Sharded
|
||||
# Note that attention.w{q,k,v,o}, feed_fordward.w[1,2,3], attention_norm.weight and ffn_norm.weight share
|
||||
# the same storage object, saving attention_norm and ffn_norm will save other weights too, which is
|
||||
# redundant as other weights will be stitched from multiple shards. To avoid that, they are cloned.
|
||||
|
||||
state_dict = {
|
||||
f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][
|
||||
f"layers.{layer_i}.attention_norm.weight"
|
||||
].clone(),
|
||||
f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0][
|
||||
f"layers.{layer_i}.ffn_norm.weight"
|
||||
].clone(),
|
||||
}
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
|
||||
torch.cat(
|
||||
[
|
||||
loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim)
|
||||
for i in range(num_shards)
|
||||
],
|
||||
dim=0,
|
||||
).reshape(dim, dim)
|
||||
)
|
||||
))
|
||||
converted[f"transformer.h.{layer_idx}.attn.c_proj.weight"] = tr(get_and_remove(
|
||||
f"layers.{layer_idx}.attention.wo.weight"
|
||||
))
|
||||
# mlp
|
||||
converted[f"transformer.h.{layer_idx}.mlp.c_fc1.weight"] = get_and_remove(
|
||||
f"layers.{layer_idx}.feed_forward.w1.weight", transpose=True,
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
|
||||
torch.cat(
|
||||
[
|
||||
loaded[i][f"layers.{layer_i}.attention.wk.weight"].view(n_heads_per_shard, dims_per_head, dim)
|
||||
for i in range(num_shards)
|
||||
],
|
||||
dim=0,
|
||||
).reshape(dim, dim)
|
||||
)
|
||||
converted[f"transformer.h.{layer_idx}.mlp.c_proj.weight"] = get_and_remove(
|
||||
f"layers.{layer_idx}.feed_forward.w2.weight", transpose=True,
|
||||
)
|
||||
converted[f"transformer.h.{layer_idx}.mlp.c_fc2.weight"] = get_and_remove(
|
||||
f"layers.{layer_idx}.feed_forward.w3.weight", transpose=True,
|
||||
)
|
||||
# rms norm
|
||||
converted[f"transformer.h.{layer_idx}.rms_1.scale"] = get_and_remove(f"layers.{layer_idx}.attention_norm.weight")
|
||||
converted[f"transformer.h.{layer_idx}.rms_2.scale"] = get_and_remove(f"layers.{layer_idx}.ffn_norm.weight")
|
||||
return converted
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
|
||||
[
|
||||
loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(n_heads_per_shard, dims_per_head, dim)
|
||||
for i in range(num_shards)
|
||||
],
|
||||
dim=0,
|
||||
).reshape(dim, dim)
|
||||
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
|
||||
[loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
|
||||
[loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
|
||||
[loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1
|
||||
)
|
||||
state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
|
||||
[loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0
|
||||
)
|
||||
|
||||
state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
|
||||
all_dicts |= state_dict
|
||||
|
||||
if model_size == "7B":
|
||||
# Unsharded
|
||||
state_dict = {
|
||||
"model.embed_tokens.weight": loaded["tok_embeddings.weight"],
|
||||
"model.norm.weight": loaded["norm.weight"],
|
||||
"lm_head.weight": loaded["output.weight"],
|
||||
}
|
||||
else:
|
||||
state_dict = {
|
||||
"model.norm.weight": loaded[0]["norm.weight"],
|
||||
"model.embed_tokens.weight": torch.cat(
|
||||
[loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1
|
||||
),
|
||||
"lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0),
|
||||
}
|
||||
all_dicts |= state_dict
|
||||
all_dicts = {k: v.numpy() for k, v in all_dicts.items()}
|
||||
np.savez(os.path.join(model_path, "llama.npz"), **all_dicts)
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--input_dir",
|
||||
help="Location of LLaMA weights, which contains tokenizer.model and model folders",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_size",
|
||||
choices=["7B", "13B", "30B", "65B"],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
help="Location to write HF model and tokenizer",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
write_model(
|
||||
model_path=args.output_dir,
|
||||
input_base_path=os.path.join(args.input_dir, args.model_size),
|
||||
model_size=args.model_size,
|
||||
)
|
||||
|
||||
def convert_weights(llama_ckpt, *, output_npz: Path = Path("llama.npz"), dtype: str = "float32") -> None:
|
||||
dt = getattr(torch, dtype, None)
|
||||
if not isinstance(dt, torch.dtype):
|
||||
raise ValueError(f"{dtype} is not a valid dtype.")
|
||||
checkpoint = torch.load(llama_ckpt, map_location="cpu")
|
||||
converted = convert_state_dict(checkpoint, dtype=dt)
|
||||
del checkpoint
|
||||
np.savez(output_npz, **converted)
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 2:
|
||||
raise ValueError(f"usage: convert_checkpoint.py ..../LLaMA/7B/consolidated.00.pth")
|
||||
convert_weights(sys.argv[1])
|
||||
main()
|
||||
|
@ -144,8 +144,14 @@ fn main() -> Result<()> {
|
||||
let config = Config::config_7b();
|
||||
let cache = model::Cache::new(!args.no_kv_cache, &config, &device);
|
||||
let (llama, tokenizer_filename) = match args.npy {
|
||||
Some(_) => {
|
||||
todo!("fix numpy handling if we continue supporting it")
|
||||
Some(filename) => {
|
||||
let tensors = Tensor::read_npz(filename)?
|
||||
.into_iter()
|
||||
.map(|(n, t)| Ok((n, t.to_dtype(DTYPE)?)))
|
||||
.collect::<Result<std::collections::HashMap<String, Tensor>>>()?;
|
||||
let vb = VarBuilder::from_tensors(tensors, DTYPE, &device);
|
||||
let tokenizer = std::path::PathBuf::from("llama-tokenizer.json");
|
||||
(Llama::load(vb, &cache, &config)?, tokenizer)
|
||||
}
|
||||
None => {
|
||||
let api = Api::new()?;
|
||||
|
@ -1,15 +1,20 @@
|
||||
use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
|
||||
use candle::{safetensors::SafeTensors, DType, Device, Error, Shape, Tensor};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
struct SafeTensorWithRouting<'a> {
|
||||
routing: HashMap<String, usize>,
|
||||
safetensors: Vec<SafeTensors<'a>>,
|
||||
// TODO: Maybe we would want the storage to be generic, e.g. with Box<dyn> to avoid too many
|
||||
// generics.
|
||||
enum Tensors<'a> {
|
||||
SafeTensorWithRouting {
|
||||
routing: HashMap<String, usize>,
|
||||
safetensors: Vec<SafeTensors<'a>>,
|
||||
},
|
||||
TensorMap(HashMap<String, Tensor>),
|
||||
Zeros,
|
||||
}
|
||||
|
||||
struct TensorData<'a> {
|
||||
// TODO: Make this part generic, probably via some Box<dyn> to avoid too much generics.
|
||||
safetensors: Option<SafeTensorWithRouting<'a>>,
|
||||
tensors: Tensors<'a>,
|
||||
pub dtype: DType,
|
||||
pub device: Device,
|
||||
}
|
||||
@ -22,12 +27,12 @@ impl<'a> TensorData<'a> {
|
||||
routing.insert(k.to_string(), index);
|
||||
}
|
||||
}
|
||||
let safetensors = SafeTensorWithRouting {
|
||||
let tensors = Tensors::SafeTensorWithRouting {
|
||||
routing,
|
||||
safetensors,
|
||||
};
|
||||
Self {
|
||||
safetensors: Some(safetensors),
|
||||
tensors,
|
||||
device: device.clone(),
|
||||
dtype,
|
||||
}
|
||||
@ -35,7 +40,15 @@ impl<'a> TensorData<'a> {
|
||||
|
||||
fn zeros(dtype: DType, device: &Device) -> Self {
|
||||
Self {
|
||||
safetensors: None,
|
||||
tensors: Tensors::Zeros,
|
||||
device: device.clone(),
|
||||
dtype,
|
||||
}
|
||||
}
|
||||
|
||||
fn from_tensors(tensors: HashMap<String, Tensor>, dtype: DType, device: &Device) -> Self {
|
||||
Self {
|
||||
tensors: Tensors::TensorMap(tensors),
|
||||
device: device.clone(),
|
||||
dtype,
|
||||
}
|
||||
@ -67,6 +80,14 @@ impl<'a> VarBuilder<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_tensors(ts: HashMap<String, Tensor>, dtype: DType, device: &Device) -> Self {
|
||||
let data = TensorData::from_tensors(ts, dtype, device);
|
||||
Self {
|
||||
data: Arc::new(data),
|
||||
path: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn push_prefix(&self, s: &str) -> Self {
|
||||
let mut path = self.path.clone();
|
||||
path.push(s.to_string());
|
||||
@ -94,31 +115,37 @@ impl<'a> VarBuilder<'a> {
|
||||
pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> candle::Result<Tensor> {
|
||||
let data = self.data.as_ref();
|
||||
let s: Shape = s.into();
|
||||
match &self.data.safetensors {
|
||||
None => Tensor::zeros(s, data.dtype, &data.device),
|
||||
Some(SafeTensorWithRouting {
|
||||
let path = if self.path.is_empty() {
|
||||
tensor_name.to_string()
|
||||
} else {
|
||||
[&self.path.join("."), tensor_name].join(".")
|
||||
};
|
||||
let tensor = match &self.data.tensors {
|
||||
Tensors::Zeros => Tensor::zeros(&s, data.dtype, &data.device)?.contiguous()?,
|
||||
Tensors::TensorMap(ts) => ts
|
||||
.get(&path)
|
||||
.ok_or_else(|| Error::CannotFindTensor {
|
||||
path: path.to_string(),
|
||||
})?
|
||||
.clone(),
|
||||
Tensors::SafeTensorWithRouting {
|
||||
routing,
|
||||
safetensors,
|
||||
}) => {
|
||||
let path = if self.path.is_empty() {
|
||||
tensor_name.to_string()
|
||||
} else {
|
||||
[&self.path.join("."), tensor_name].join(".")
|
||||
};
|
||||
} => {
|
||||
// Unwrap or 0 just to let the proper error flow.
|
||||
let index = routing.get(&path).unwrap_or(&0);
|
||||
let tensor = safetensors[*index]
|
||||
safetensors[*index]
|
||||
.tensor(&path, &data.device)?
|
||||
.to_dtype(data.dtype)?;
|
||||
if *tensor.shape() != s {
|
||||
Err(candle::Error::UnexpectedShape {
|
||||
msg: format!("shape mismatch for {path}"),
|
||||
expected: s,
|
||||
got: tensor.shape().clone(),
|
||||
})?
|
||||
}
|
||||
Ok(tensor)
|
||||
.to_dtype(data.dtype)?
|
||||
}
|
||||
};
|
||||
if tensor.shape() != &s {
|
||||
Err(candle::Error::UnexpectedShape {
|
||||
msg: format!("shape mismatch for {path}"),
|
||||
expected: s,
|
||||
got: tensor.shape().clone(),
|
||||
})?
|
||||
}
|
||||
Ok(tensor)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user