diff --git a/examples/llama/convert_checkpoint.py b/examples/llama/convert_checkpoint.py
index 631aab35..245c167c 100644
--- a/examples/llama/convert_checkpoint.py
+++ b/examples/llama/convert_checkpoint.py
@@ -8,7 +8,7 @@ from pathlib import Path
 def tr(v):
     return np.ascontiguousarray(np.transpose(v))
 
-def convert_state_dict(state_dict: Dict[str, torch.Tensor], dtype: torch.dtype = torch.float16) -> Dict[str, torch.Tensor]:
+def convert_state_dict(state_dict: Dict[str, torch.Tensor], dtype: torch.dtype = torch.float32) -> Dict[str, torch.Tensor]:
     print("start conv")
 
     def get_and_remove(key, transpose=False):
@@ -53,7 +53,7 @@ def convert_state_dict(state_dict: Dict[str, torch.Tensor], dtype: torch.dtype =
         converted[f"transformer.h.{layer_idx}.rms_2.scale"] = get_and_remove(f"layers.{layer_idx}.ffn_norm.weight")
     return converted
 
-def convert_weights(llama_ckpt, *, output_npz: Path = Path("llama.npz"), dtype: str = "float16") -> None:
+def convert_weights(llama_ckpt, *, output_npz: Path = Path("llama.npz"), dtype: str = "float32") -> None:
     dt = getattr(torch, dtype, None)
     if not isinstance(dt, torch.dtype):
         raise ValueError(f"{dtype} is not a valid dtype.")
diff --git a/examples/llama/main.rs b/examples/llama/main.rs
index df267002..d0dd0d19 100644
--- a/examples/llama/main.rs
+++ b/examples/llama/main.rs
@@ -426,10 +426,19 @@ fn main() -> Result<()> {
         .get_ids()
         .to_vec();
 
-    println!("loading weights");
-    let start_load = std::time::Instant::now();
-    let vb = VarBuilder::new::<f32>(); // TODO: load the weights from llama.npz
-    println!("loaded weights in {:?}", start_load.elapsed());
+    let weight_path = std::path::Path::new("llama-f32.npz");
+    let weights = if weight_path.exists() {
+        println!("loading weights from {weight_path:?}");
+        let start_load = std::time::Instant::now();
+        let tensors = Tensor::read_npz(weight_path)?;
+        println!("loaded weights in {:?}", start_load.elapsed());
+        let tensors: std::collections::HashMap<String, Tensor> = tensors.into_iter().collect();
+        Some(tensors)
+    } else {
+        println!("cannot find {weight_path:?}, using zero weights");
+        None
+    };
+    let vb = VarBuilder::new::<f32>(weights);
 
     println!("building the model");
     let config = Config::config_7b();
diff --git a/examples/llama/var_store.rs b/examples/llama/var_store.rs
index cff1e37a..fb3d0c61 100644
--- a/examples/llama/var_store.rs
+++ b/examples/llama/var_store.rs
@@ -1,4 +1,6 @@
 use candle::{DType, Device, Result, Shape, Tensor, WithDType};
+use std::collections::HashMap;
+use std::sync::Arc;
 
 #[allow(dead_code)]
 #[derive(Clone)]
@@ -13,6 +15,7 @@ pub struct VarBuilder {
     path: Vec<String>,
     vars: std::rc::Rc<std::cell::RefCell<Vec<NamedVar>>>,
     default_dtype: DType,
+    tensors: Arc<Option<HashMap<String, Tensor>>>,
 }
 
 #[allow(dead_code)]
@@ -21,12 +24,13 @@ pub struct VarStore {
 }
 
 impl VarBuilder {
-    pub fn new<B: WithDType>() -> Self {
+    pub fn new<B: WithDType>(tensors: Option<HashMap<String, Tensor>>) -> Self {
         let vars = std::rc::Rc::new(std::cell::RefCell::new(vec![]));
         Self {
             path: vec![],
             vars,
             default_dtype: B::DTYPE,
+            tensors: Arc::new(tensors),
         }
     }
 
@@ -38,13 +42,19 @@ impl VarBuilder {
         let shape = shape.into();
         let path = format!("{}.{s}", self.path.join("."));
         let mut vars = self.vars.borrow_mut();
-        let parameter = Tensor::zeros(&shape, self.default_dtype, &Device::Cpu);
+        let parameter = match self.tensors.as_ref() {
+            None => Tensor::zeros(&shape, self.default_dtype, &Device::Cpu)?,
+            Some(tensors) => match tensors.get(&path) {
+                Some(tensor) => tensor.clone(),
+                None => panic!("cannot find tensor for {path}"),
+            },
+        };
         vars.push(NamedVar {
             path,
             dtype: self.default_dtype,
             shape,
         });
-        parameter
+        Ok(parameter)
     }
 
     pub fn into_store(self) -> VarStore {
@@ -65,6 +75,7 @@ impl<S: ToString> std::ops::Div<S> for &VarBuilder {
             path,
             vars: self.vars.clone(),
             default_dtype: self.default_dtype,
+            tensors: self.tensors.clone(),
         }
     }
 }