Load the weights for llama.

This commit is contained in:
laurent
2023-06-26 07:23:59 +01:00
parent 7a3101f15f
commit d867155ef2
3 changed files with 29 additions and 9 deletions

View File

@ -8,7 +8,7 @@ from pathlib import Path
def tr(v): def tr(v):
return np.ascontiguousarray(np.transpose(v)) return np.ascontiguousarray(np.transpose(v))
def convert_state_dict(state_dict: Dict[str, torch.Tensor], dtype: torch.dtype = torch.float16) -> Dict[str, torch.Tensor]: def convert_state_dict(state_dict: Dict[str, torch.Tensor], dtype: torch.dtype = torch.float32) -> Dict[str, torch.Tensor]:
print("start conv") print("start conv")
def get_and_remove(key, transpose=False): def get_and_remove(key, transpose=False):
@ -53,7 +53,7 @@ def convert_state_dict(state_dict: Dict[str, torch.Tensor], dtype: torch.dtype =
converted[f"transformer.h.{layer_idx}.rms_2.scale"] = get_and_remove(f"layers.{layer_idx}.ffn_norm.weight") converted[f"transformer.h.{layer_idx}.rms_2.scale"] = get_and_remove(f"layers.{layer_idx}.ffn_norm.weight")
return converted return converted
def convert_weights(llama_ckpt, *, output_npz: Path = Path("llama.npz"), dtype: str = "float16") -> None: def convert_weights(llama_ckpt, *, output_npz: Path = Path("llama.npz"), dtype: str = "float32") -> None:
dt = getattr(torch, dtype, None) dt = getattr(torch, dtype, None)
if not isinstance(dt, torch.dtype): if not isinstance(dt, torch.dtype):
raise ValueError(f"{dtype} is not a valid dtype.") raise ValueError(f"{dtype} is not a valid dtype.")

View File

@ -426,10 +426,19 @@ fn main() -> Result<()> {
.get_ids() .get_ids()
.to_vec(); .to_vec();
println!("loading weights"); let weight_path = std::path::Path::new("llama-f32.npz");
let weights = if weight_path.exists() {
println!("loading weights from {weight_path:?}");
let start_load = std::time::Instant::now(); let start_load = std::time::Instant::now();
let vb = VarBuilder::new::<f32>(); // TODO: load the weights from llama.npz let tensors = Tensor::read_npz(weight_path)?;
println!("loaded weights in {:?}", start_load.elapsed()); println!("loaded weights in {:?}", start_load.elapsed());
let tensors: std::collections::HashMap<String, Tensor> = tensors.into_iter().collect();
Some(tensors)
} else {
println!("cannot find {weight_path:?}, using zero weights");
None
};
let vb = VarBuilder::new::<f32>(weights);
println!("building the model"); println!("building the model");
let config = Config::config_7b(); let config = Config::config_7b();

View File

@ -1,4 +1,6 @@
use candle::{DType, Device, Result, Shape, Tensor, WithDType}; use candle::{DType, Device, Result, Shape, Tensor, WithDType};
use std::collections::HashMap;
use std::sync::Arc;
#[allow(dead_code)] #[allow(dead_code)]
#[derive(Clone)] #[derive(Clone)]
@ -13,6 +15,7 @@ pub struct VarBuilder {
path: Vec<String>, path: Vec<String>,
vars: std::rc::Rc<std::cell::RefCell<Vec<NamedVar>>>, vars: std::rc::Rc<std::cell::RefCell<Vec<NamedVar>>>,
default_dtype: DType, default_dtype: DType,
tensors: Arc<Option<HashMap<String, Tensor>>>,
} }
#[allow(dead_code)] #[allow(dead_code)]
@ -21,12 +24,13 @@ pub struct VarStore {
} }
impl VarBuilder { impl VarBuilder {
pub fn new<B: WithDType>() -> Self { pub fn new<B: WithDType>(tensors: Option<HashMap<String, Tensor>>) -> Self {
let vars = std::rc::Rc::new(std::cell::RefCell::new(vec![])); let vars = std::rc::Rc::new(std::cell::RefCell::new(vec![]));
Self { Self {
path: vec![], path: vec![],
vars, vars,
default_dtype: B::DTYPE, default_dtype: B::DTYPE,
tensors: Arc::new(tensors),
} }
} }
@ -38,13 +42,19 @@ impl VarBuilder {
let shape = shape.into(); let shape = shape.into();
let path = format!("{}.{s}", self.path.join(".")); let path = format!("{}.{s}", self.path.join("."));
let mut vars = self.vars.borrow_mut(); let mut vars = self.vars.borrow_mut();
let parameter = Tensor::zeros(&shape, self.default_dtype, &Device::Cpu); let parameter = match self.tensors.as_ref() {
None => Tensor::zeros(&shape, self.default_dtype, &Device::Cpu)?,
Some(tensors) => match tensors.get(&path) {
Some(tensor) => tensor.clone(),
None => panic!("cannot find tensor for {path}"),
},
};
vars.push(NamedVar { vars.push(NamedVar {
path, path,
dtype: self.default_dtype, dtype: self.default_dtype,
shape, shape,
}); });
parameter Ok(parameter)
} }
pub fn into_store(self) -> VarStore { pub fn into_store(self) -> VarStore {
@ -65,6 +75,7 @@ impl<S: ToString> std::ops::Div<S> for &VarBuilder {
path, path,
vars: self.vars.clone(), vars: self.vars.clone(),
default_dtype: self.default_dtype, default_dtype: self.default_dtype,
tensors: self.tensors.clone(),
} }
} }
} }