From 37cad858698e519435c916421cc97b4f6b7fe53e Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Tue, 11 Jul 2023 19:32:10 +0100
Subject: [PATCH] Resurrect the llama npy support. (#140)

---
 candle-core/src/error.rs                      |   3 +
 candle-core/src/npy.rs                        |  10 +-
 candle-core/src/utils.rs                      |   7 +
 .../examples/llama/convert_checkpoint.py      | 241 ++++++++++++++----
 candle-examples/examples/llama/main.rs        |  10 +-
 candle-nn/src/var_builder.rs                  |  83 ++++--
 6 files changed, 264 insertions(+), 90 deletions(-)

diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index caad3e1f..27fd11bb 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -139,6 +139,9 @@ pub enum Error {
         rhs_stride: Vec<usize>,
         mnk: (usize, usize, usize),
     },
+
+    #[error("cannot find tensor {path}")]
+    CannotFindTensor { path: String },
 }
 
 pub type Result<T> = std::result::Result<T, Error>;
diff --git a/candle-core/src/npy.rs b/candle-core/src/npy.rs
index c0608519..7cf6d381 100644
--- a/candle-core/src/npy.rs
+++ b/candle-core/src/npy.rs
@@ -1,10 +1,10 @@
-//! Numpy support for literals.
+//! Numpy support for tensors.
 //!
 //! The spec for the npy format can be found in
 //! [npy-format](https://docs.scipy.org/doc/numpy-1.14.2/neps/npy-format.html).
-//! The functions from this module can be used to read literals from npy/npz files
-//! or write literals to these files. A npy file contains a single literal (unnamed)
-//! whereas a npz file can contain multiple named literals. npz files are also compressed.
+//! The functions from this module can be used to read tensors from npy/npz files
+//! or write tensors to these files. A npy file contains a single tensor (unnamed)
+//! whereas a npz file can contain multiple named tensors. npz files are also compressed.
 //!
 //! These two formats are easy to use in Python using the numpy library.
 //!
@@ -232,7 +232,7 @@ impl Tensor {
         }
     }
 
-    /// Reads a npy file and return the stored multi-dimensional array as a literal.
+    /// Reads a npy file and return the stored multi-dimensional array as a tensor.
     pub fn read_npy<T: AsRef<Path>>(path: T) -> Result<Self> {
         let mut reader = File::open(path.as_ref())?;
         let header = read_header(&mut reader)?;
diff --git a/candle-core/src/utils.rs b/candle-core/src/utils.rs
index 4b1e941b..b5621e56 100644
--- a/candle-core/src/utils.rs
+++ b/candle-core/src/utils.rs
@@ -10,3 +10,10 @@ pub fn get_num_threads() -> usize {
         Some(_) | None => num_cpus::get(),
     }
 }
+
+pub fn has_mkl() -> bool {
+    #[cfg(feature = "mkl")]
+    return true;
+    #[cfg(not(feature = "mkl"))]
+    return false;
+}
diff --git a/candle-examples/examples/llama/convert_checkpoint.py b/candle-examples/examples/llama/convert_checkpoint.py
index 245c167c..1b44a04a 100644
--- a/candle-examples/examples/llama/convert_checkpoint.py
+++ b/candle-examples/examples/llama/convert_checkpoint.py
@@ -1,68 +1,199 @@
-# Adapted from https://github.com/Lightning-AI/lit-llama/blob/main/scripts/convert_checkpoint.py
-import sys
+# Adapted from:
+# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py
+# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
+import argparse
+import gc
+import json
+import math
+import os
+import shutil
+import warnings
+
 import torch
 import numpy as np
-from typing import Dict
-from pathlib import Path
 
-def tr(v):
-    return np.ascontiguousarray(np.transpose(v))
+"""
+Sample usage:
 
-def convert_state_dict(state_dict: Dict[str, torch.Tensor], dtype: torch.dtype = torch.float32) -> Dict[str, torch.Tensor]:
-    print("start conv")
+```
+python src/transformers/models/llama/convert_llama_weights_to_hf.py \
+    --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path
+```
+"""
 
-    def get_and_remove(key, transpose=False):
-        v = state_dict[key].to(dtype).numpy()
-        if transpose:
-            v = tr(v)
-        del state_dict[key]
-        return v
+INTERMEDIATE_SIZE_MAP = {
+    "7B": 11008,
+    "13B": 13824,
+    "30B": 17920,
+    "65B": 22016,
+}
+NUM_SHARDS = {
+    "7B": 1,
+    "13B": 2,
+    "30B": 4,
+    "65B": 8,
+}
 
-    converted = {}
-    converted["transformer.wte.weight"] = get_and_remove("tok_embeddings.weight")
-    converted["lm_head.weight"] = get_and_remove("output.weight", transpose=True)
-    converted["transformer.ln_f.scale"] = get_and_remove("norm.weight")
-    for layer_idx in sorted(set([k.split(".")[1] for k in state_dict if k.startswith("layers")])):
-        print(layer_idx)
+def compute_intermediate_size(n):
+    return int(math.ceil(n * 8 / 3) + 255) // 256 * 256
 
-        # attention
-        # the wq, wk, wv from the FB model are stacked in our model as c_attn
-        converted[f"transformer.h.{layer_idx}.attn.c_attn.weight"] = tr(np.concatenate(
-            (
-                get_and_remove(f"layers.{layer_idx}.attention.wq.weight"),
-                get_and_remove(f"layers.{layer_idx}.attention.wk.weight"),
-                get_and_remove(f"layers.{layer_idx}.attention.wv.weight"),
+def read_json(path):
+    with open(path, "r") as f:
+        return json.load(f)
+
+
+def write_json(text, path):
+    with open(path, "w") as f:
+        json.dump(text, f)
+
+
+def write_model(model_path, input_base_path, model_size):
+    os.makedirs(model_path, exist_ok=True)
+
+    params = read_json(os.path.join(input_base_path, "params.json"))
+    num_shards = NUM_SHARDS[model_size]
+    n_layers = params["n_layers"]
+    n_heads = params["n_heads"]
+    n_heads_per_shard = n_heads // num_shards
+    dim = params["dim"]
+    dims_per_head = dim // n_heads
+    base = 10000.0
+    inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
+
+    # permute for sliced rotary
+    def permute(w):
+        return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim)
+
+    print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
+    # Load weights
+    if model_size == "7B":
+        # Not sharded
+        # (The sharded implementation would also work, but this is simpler.)
+        loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu")
+    else:
+        # Sharded
+        loaded = [
+            torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
+            for i in range(num_shards)
+        ]
+    param_count = 0
+    all_dicts = {}
+    for layer_i in range(n_layers):
+        if model_size == "7B":
+            # Unsharded
+            state_dict = {
+                f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
+                    loaded[f"layers.{layer_i}.attention.wq.weight"]
+                ),
+                f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
+                    loaded[f"layers.{layer_i}.attention.wk.weight"]
+                ),
+                f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"],
+                f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"],
+                f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"],
+                f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"],
+                f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"],
+                f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"layers.{layer_i}.attention_norm.weight"],
+                f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"layers.{layer_i}.ffn_norm.weight"],
+            }
+        else:
+            # Sharded
+            # Note that attention.w{q,k,v,o}, feed_forward.w[1,2,3], attention_norm.weight and ffn_norm.weight share
+            # the same storage object, so saving attention_norm and ffn_norm would also save the other weights, which
+            # is redundant as the other weights are stitched together from multiple shards. To avoid that, they are cloned.
+
+            state_dict = {
+                f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][
+                    f"layers.{layer_i}.attention_norm.weight"
+                ].clone(),
+                f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0][
+                    f"layers.{layer_i}.ffn_norm.weight"
+                ].clone(),
+            }
+            state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
+                torch.cat(
+                    [
+                        loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim)
+                        for i in range(num_shards)
+                    ],
+                    dim=0,
+                ).reshape(dim, dim)
             )
-        ))
-        converted[f"transformer.h.{layer_idx}.attn.c_proj.weight"] = tr(get_and_remove(
-            f"layers.{layer_idx}.attention.wo.weight"
-        ))
-        # mlp
-        converted[f"transformer.h.{layer_idx}.mlp.c_fc1.weight"] = get_and_remove(
-            f"layers.{layer_idx}.feed_forward.w1.weight", transpose=True,
+            state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
+                torch.cat(
+                    [
+                        loaded[i][f"layers.{layer_i}.attention.wk.weight"].view(n_heads_per_shard, dims_per_head, dim)
+                        for i in range(num_shards)
+                    ],
+                    dim=0,
+                ).reshape(dim, dim)
            )
-        converted[f"transformer.h.{layer_idx}.mlp.c_proj.weight"] = get_and_remove(
-            f"layers.{layer_idx}.feed_forward.w2.weight", transpose=True,
-        )
-        converted[f"transformer.h.{layer_idx}.mlp.c_fc2.weight"] = get_and_remove(
-            f"layers.{layer_idx}.feed_forward.w3.weight", transpose=True,
-        )
-        # rms norm
-        converted[f"transformer.h.{layer_idx}.rms_1.scale"] = get_and_remove(f"layers.{layer_idx}.attention_norm.weight")
-        converted[f"transformer.h.{layer_idx}.rms_2.scale"] = get_and_remove(f"layers.{layer_idx}.ffn_norm.weight")
-    return converted
+            state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
+                [
+                    loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(n_heads_per_shard, dims_per_head, dim)
+                    for i in range(num_shards)
+                ],
+                dim=0,
+            ).reshape(dim, dim)
+
+            state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
+                [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1
+            )
+            state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
+                [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0
+            )
+            state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
+                [loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1
+            )
+            state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
+                [loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0
+            )
+
+        state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
+        all_dicts |= state_dict
+
+    if model_size == "7B":
+        # Unsharded
+        state_dict = {
+            "model.embed_tokens.weight": loaded["tok_embeddings.weight"],
+            "model.norm.weight": loaded["norm.weight"],
+            "lm_head.weight": loaded["output.weight"],
+        }
+    else:
+        state_dict = {
+            "model.norm.weight": loaded[0]["norm.weight"],
+            "model.embed_tokens.weight": torch.cat(
+                [loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1
+            ),
+            "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0),
+        }
+    all_dicts |= state_dict
+    all_dicts = {k: v.numpy() for k, v in all_dicts.items()}
+    np.savez(os.path.join(model_path, "llama.npz"), **all_dicts)
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--input_dir",
+        help="Location of LLaMA weights, which contains tokenizer.model and model folders",
+    )
+    parser.add_argument(
+        "--model_size",
+        choices=["7B", "13B", "30B", "65B"],
+    )
+    parser.add_argument(
+        "--output_dir",
+        help="Location to write HF model and tokenizer",
+    )
+    args = parser.parse_args()
+    write_model(
+        model_path=args.output_dir,
+        input_base_path=os.path.join(args.input_dir, args.model_size),
+        model_size=args.model_size,
+    )
-def convert_weights(llama_ckpt, *, output_npz: Path = Path("llama.npz"), dtype: str = "float32") -> None:
-    dt = getattr(torch, dtype, None)
-    if not isinstance(dt, torch.dtype):
-        raise ValueError(f"{dtype} is not a valid dtype.")
-    checkpoint = torch.load(llama_ckpt, map_location="cpu")
-    converted = convert_state_dict(checkpoint, dtype=dt)
-    del checkpoint
-    np.savez(output_npz, **converted)
 
 if __name__ == "__main__":
-    if len(sys.argv) != 2:
-        raise ValueError(f"usage: convert_checkpoint.py ..../LLaMA/7B/consolidated.00.pth")
-    convert_weights(sys.argv[1])
+    main()
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index 75cea7ff..6ac4458e 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -144,8 +144,14 @@ fn main() -> Result<()> {
     let config = Config::config_7b();
     let cache = model::Cache::new(!args.no_kv_cache, &config, &device);
     let (llama, tokenizer_filename) = match args.npy {
-        Some(_) => {
-            todo!("fix numpy handling if we continue supporting it")
+        Some(filename) => {
+            let tensors = Tensor::read_npz(filename)?
+                .into_iter()
+                .map(|(n, t)| Ok((n, t.to_dtype(DTYPE)?)))
+                .collect::<Result<std::collections::HashMap<String, Tensor>>>()?;
+            let vb = VarBuilder::from_tensors(tensors, DTYPE, &device);
+            let tokenizer = std::path::PathBuf::from("llama-tokenizer.json");
+            (Llama::load(vb, &cache, &config)?, tokenizer)
         }
         None => {
             let api = Api::new()?;
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index d71b5822..6d79bddd 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -1,15 +1,20 @@
-use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
+use candle::{safetensors::SafeTensors, DType, Device, Error, Shape, Tensor};
 use std::collections::HashMap;
 use std::sync::Arc;
 
-struct SafeTensorWithRouting<'a> {
-    routing: HashMap<String, usize>,
-    safetensors: Vec<SafeTensors<'a>>,
+// TODO: Maybe we would want the storage to be generic, e.g. with Box to avoid too many
+// generics.
+enum Tensors<'a> {
+    SafeTensorWithRouting {
+        routing: HashMap<String, usize>,
+        safetensors: Vec<SafeTensors<'a>>,
+    },
+    TensorMap(HashMap<String, Tensor>),
+    Zeros,
 }
 
 struct TensorData<'a> {
-    // TODO: Make this part generic, probably via some Box to avoid too much generics.
-    safetensors: Option<SafeTensorWithRouting<'a>>,
+    tensors: Tensors<'a>,
     pub dtype: DType,
     pub device: Device,
 }
@@ -22,12 +27,12 @@ impl<'a> TensorData<'a> {
                 routing.insert(k.to_string(), index);
             }
         }
-        let safetensors = SafeTensorWithRouting {
+        let tensors = Tensors::SafeTensorWithRouting {
            routing,
            safetensors,
        };
        Self {
-            safetensors: Some(safetensors),
+            tensors,
            device: device.clone(),
            dtype,
        }
@@ -35,7 +40,15 @@
     fn zeros(dtype: DType, device: &Device) -> Self {
         Self {
-            safetensors: None,
+            tensors: Tensors::Zeros,
             device: device.clone(),
             dtype,
         }
     }
+
+    fn from_tensors(tensors: HashMap<String, Tensor>, dtype: DType, device: &Device) -> Self {
+        Self {
+            tensors: Tensors::TensorMap(tensors),
+            device: device.clone(),
+            dtype,
+        }
+    }
@@ -67,6 +80,14 @@ impl<'a> VarBuilder<'a> {
         }
     }
 
+    pub fn from_tensors(ts: HashMap<String, Tensor>, dtype: DType, device: &Device) -> Self {
+        let data = TensorData::from_tensors(ts, dtype, device);
+        Self {
+            data: Arc::new(data),
+            path: vec![],
+        }
+    }
+
     pub fn push_prefix(&self, s: &str) -> Self {
         let mut path = self.path.clone();
         path.push(s.to_string());
@@ -94,31 +115,37 @@ impl<'a> VarBuilder<'a> {
     pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> candle::Result<Tensor> {
         let data = self.data.as_ref();
         let s: Shape = s.into();
-        match &self.data.safetensors {
-            None => Tensor::zeros(s, data.dtype, &data.device),
-            Some(SafeTensorWithRouting {
+        let path = if self.path.is_empty() {
+            tensor_name.to_string()
+        } else {
+            [&self.path.join("."), tensor_name].join(".")
+        };
+        let tensor = match &self.data.tensors {
+            Tensors::Zeros => Tensor::zeros(&s, data.dtype, &data.device)?.contiguous()?,
+            Tensors::TensorMap(ts) => ts
+                .get(&path)
+                .ok_or_else(|| Error::CannotFindTensor {
+                    path: path.to_string(),
+                })?
+                .clone(),
+            Tensors::SafeTensorWithRouting {
                 routing,
                 safetensors,
-            }) => {
-                let path = if self.path.is_empty() {
-                    tensor_name.to_string()
-                } else {
-                    [&self.path.join("."), tensor_name].join(".")
-                };
+            } => {
                 // Unwrap or 0 just to let the proper error flow.
                 let index = routing.get(&path).unwrap_or(&0);
-                let tensor = safetensors[*index]
+                safetensors[*index]
                     .tensor(&path, &data.device)?
-                    .to_dtype(data.dtype)?;
-                if *tensor.shape() != s {
-                    Err(candle::Error::UnexpectedShape {
-                        msg: format!("shape mismatch for {path}"),
-                        expected: s,
-                        got: tensor.shape().clone(),
-                    })?
-                }
-                Ok(tensor)
+                    .to_dtype(data.dtype)?
             }
+        };
+        if tensor.shape() != &s {
+            Err(candle::Error::UnexpectedShape {
+                msg: format!("shape mismatch for {path}"),
+                expected: s,
+                got: tensor.shape().clone(),
+            })?
         }
+        Ok(tensor)
     }
 }
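
A note on the constants near the top of convert_checkpoint.py: INTERMEDIATE_SIZE_MAP is simply compute_intermediate_size applied to the usual LLaMA-1 hidden sizes. A minimal sketch checking that correspondence; the hidden sizes 4096/5120/6656/8192 are assumptions of this example, not values read from the patch:

```python
# Sketch: INTERMEDIATE_SIZE_MAP agrees with compute_intermediate_size for the
# assumed LLaMA-1 hidden sizes: 7B=4096, 13B=5120, 30B=6656, 65B=8192.
import math

def compute_intermediate_size(n):
    return int(math.ceil(n * 8 / 3) + 255) // 256 * 256

for dim, expected in [(4096, 11008), (5120, 13824), (6656, 17920), (8192, 22016)]:
    assert compute_intermediate_size(dim) == expected, (dim, expected)
```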
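The "# permute for sliced rotary" helper in write_model is easiest to see on toy shapes: within each head it moves the first member of every adjacent row pair into the first half of the head and the second member into the second half. A small illustration; n_heads=2 and dim=8 are made-up sizes, whereas the real helper closes over the checkpoint's dimensions:

```python
# Sketch: show which original row of w ends up at each position after permute.
import torch

n_heads, dim = 2, 8  # toy sizes for illustration only

def permute(w):
    return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim)

w = torch.arange(dim * dim).reshape(dim, dim)
order = permute(w)[:, 0] // dim  # original row index now sitting at each row position
print(order.tolist())  # [0, 2, 1, 3, 4, 6, 5, 7]: adjacent pairs are split per head
```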
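Once write_model has produced llama.npz, the tensor names stored in it must match the paths the model definition requests from the VarBuilder built by VarBuilder::from_tensors. A quick inspection sketch, assuming the converter has already been run and llama.npz sits in the current directory:

```python
# Sketch: list the tensors stored in llama.npz together with their shapes and dtypes.
import numpy as np

tensors = np.load("llama.npz")
for name in sorted(tensors.files):
    print(f"{name:60} {tensors[name].shape} {tensors[name].dtype}")
```

The llama example then consumes this file through the existing --npy flag; per main.rs above it also expects a llama-tokenizer.json in the current working directory.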