From 679b6987b67c6c93d1e878dfec67fe5e06feef4a Mon Sep 17 00:00:00 2001
From: laurent
Date: Fri, 30 Jun 2023 16:42:53 +0100
Subject: [PATCH] Early conversion for the llama weights.

---
 candle-core/examples/llama/var_store.rs | 60 +++++++------------------
 candle-core/examples/llama/weights.rs   |  4 +-
 2 files changed, 19 insertions(+), 45 deletions(-)

diff --git a/candle-core/examples/llama/var_store.rs b/candle-core/examples/llama/var_store.rs
index 0106e941..1a22bd89 100644
--- a/candle-core/examples/llama/var_store.rs
+++ b/candle-core/examples/llama/var_store.rs
@@ -1,7 +1,7 @@
 use super::*;
-use candle::{DType, Device, Result, Shape, Tensor, WithDType};
+use candle::{DType, Device, Result, Shape, Tensor};
 use std::collections::HashMap;
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
 
 #[allow(dead_code)]
 #[derive(Clone)]
@@ -14,51 +14,28 @@ struct NamedVar {
 #[derive(Clone)]
 pub struct VarBuilder {
     path: Vec<String>,
-    vars: std::rc::Rc<std::cell::RefCell<Vec<NamedVar>>>,
-    default_dtype: DType,
     default_device: Device,
-    tensors: Arc<Option<HashMap<String, Tensor>>>,
-}
-
-#[allow(dead_code)]
-pub struct VarStore {
-    vars: Vec<NamedVar>,
+    tensors: Arc<Mutex<HashMap<String, Tensor>>>,
 }
 
 impl VarBuilder {
-    pub fn new<B: WithDType>(device: &Device, tensors: Option<HashMap<String, Tensor>>) -> Self {
-        let vars = std::rc::Rc::new(std::cell::RefCell::new(vec![]));
+    pub fn new(device: &Device, tensors: HashMap<String, Tensor>) -> Self {
         Self {
             path: vec![],
-            vars,
-            default_dtype: B::DTYPE,
-            tensors: Arc::new(tensors),
+            tensors: Arc::new(Mutex::new(tensors)),
             default_device: device.clone(),
         }
     }
 
-    pub fn len(&self) -> usize {
-        self.vars.borrow().len()
-    }
-
-    pub fn var(&self, s: &str) -> Result<Tensor> {
+    pub fn get_and_remove(&self, s: &str) -> Result<Tensor> {
         let path = format!("{}.{s}", self.path.join("."));
-        let parameter = match self.tensors.as_ref() {
-            None => panic!("Cannot find tensors"),
-            Some(tensors) => match tensors.get(&path) {
-                Some(tensor) => tensor.to_device(&self.default_device)?,
-                None => panic!("cannot find tensor for {path}"),
-            },
+        let mut tensors = self.tensors.as_ref().lock().unwrap();
+        let parameter = match tensors.remove(&path) {
+            Some(tensor) => tensor.to_device(&self.default_device)?,
+            None => panic!("cannot find tensor for {path}"),
         };
         Ok(parameter)
     }
-
-    pub fn into_store(self) -> VarStore {
-        let vars = self.vars.borrow();
-        VarStore {
-            vars: vars.to_vec(),
-        }
-    }
 }
 
 impl std::ops::Div<&str> for &VarBuilder {
@@ -69,8 +46,6 @@ impl std::ops::Div<&str> for &VarBuilder {
         path.push(rhs.to_string());
         VarBuilder {
             path,
-            vars: self.vars.clone(),
-            default_dtype: self.default_dtype,
             default_device: self.default_device.clone(),
             tensors: self.tensors.clone(),
         }
@@ -87,21 +62,21 @@ impl std::ops::Div<&str> for VarBuilder {
 
 impl Embedding {
     fn load_npy(vb: VarBuilder) -> Result<Self> {
-        let embeddings = vb.var("weight")?;
+        let embeddings = vb.get_and_remove("weight")?.to_dtype(DTYPE)?;
         Ok(Self { embeddings })
     }
 }
 
 impl Linear {
     fn load_npy(vb: VarBuilder) -> Result<Self> {
-        let weight = vb.var("weight")?.t()?;
+        let weight = vb.get_and_remove("weight")?.to_dtype(DTYPE)?.t()?;
         Ok(Self { weight })
     }
 }
 
 impl RmsNorm {
     fn load_npy(vb: VarBuilder) -> Result<Self> {
-        let scale = vb.var("scale")?;
+        let scale = vb.get_and_remove("scale")?.to_dtype(DTYPE)?;
         Ok(Self::new(scale))
     }
 }
@@ -144,7 +119,7 @@ impl Llama {
         filename: &str,
         cache: &Cache,
         config: &Config,
-    ) -> Result<Self> {
+    ) -> anyhow::Result<Self> {
         let weight_path = std::path::Path::new(filename);
         let weights = if weight_path.exists() {
             println!("loading weights from {weight_path:?}");
@@ -152,12 +127,11 @@
             let tensors = Tensor::read_npz(weight_path)?;
             println!("loaded weights in {:?}", start_load.elapsed());
             let tensors: std::collections::HashMap<String, Tensor> = tensors.into_iter().collect();
-            Some(tensors)
+            tensors
         } else {
-            println!("cannot find {weight_path:?}, using zero weights");
-            None
+            anyhow::bail!("cannot find {weight_path:?}")
         };
-        let vb = VarBuilder::new::<f32>(device, weights);
+        let vb = VarBuilder::new(device, weights);
 
         let wte = Embedding::load_npy(&vb / "transformer" / "wte")?;
         let lm_head = Linear::load_npy(&vb / "lm_head")?;
diff --git a/candle-core/examples/llama/weights.rs b/candle-core/examples/llama/weights.rs
index 73609d51..5eff8e21 100644
--- a/candle-core/examples/llama/weights.rs
+++ b/candle-core/examples/llama/weights.rs
@@ -18,7 +18,7 @@ fn convert(view: TensorView<'_>, device: &Device) -> Result<Tensor> {
         // was correctly aligned.
         let data: &[f16] =
             unsafe { std::slice::from_raw_parts(v.as_ptr() as *const f16, v.len() / 2) };
-        Tensor::from_slice(data, view.shape(), device)
+        Tensor::from_slice(data, view.shape(), device)?.to_dtype(DTYPE)
     } else {
         let mut c = Vec::with_capacity(v.len() / 2);
         let mut i = 0;
@@ -26,7 +26,7 @@ fn convert(view: TensorView<'_>, device: &Device) -> Result<Tensor> {
             c.push(f16::from_le_bytes([v[i], v[i + 1]]));
             i += 2;
         }
-        Tensor::from_slice(&c, view.shape(), device)
+        Tensor::from_slice(&c, view.shape(), device)?.to_dtype(DTYPE)
     }
     dt => todo!("Unhandled dtype {dt:?}"),
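
Usage sketch (illustrative only, not part of the patch): after this change the weight map is built eagerly, wrapped in an Arc<Mutex<...>>, and each tensor is consumed exactly once via get_and_remove. The demo function, the inserted key, and the zero tensor below are made up for illustration; it assumes the example's VarBuilder is in scope along with `use candle::{DType, Device, Result, Tensor};` and the crate's Tensor::zeros(shape, dtype, device) constructor as of this commit.

fn demo() -> Result<()> {
    // Hypothetical weight map standing in for the Tensor::read_npz output.
    let mut tensors = std::collections::HashMap::new();
    tensors.insert(
        "transformer.wte.weight".to_string(),
        Tensor::zeros((4, 2), DType::F16, &Device::Cpu)?,
    );

    let vb = VarBuilder::new(&Device::Cpu, tensors);
    // `/` pushes a path segment; `get_and_remove` moves the tensor out of
    // the shared map, so each weight is held in host memory only once.
    let _wte = (&vb / "transformer" / "wte").get_and_remove("weight")?;
    Ok(())
}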