mirror of https://github.com/huggingface/candle.git, synced 2025-06-18 19:47:12 +00:00
Load the weights for llama.
@@ -8,7 +8,7 @@ from pathlib import Path
 def tr(v):
     return np.ascontiguousarray(np.transpose(v))
 
-def convert_state_dict(state_dict: Dict[str, torch.Tensor], dtype: torch.dtype = torch.float16) -> Dict[str, torch.Tensor]:
+def convert_state_dict(state_dict: Dict[str, torch.Tensor], dtype: torch.dtype = torch.float32) -> Dict[str, torch.Tensor]:
     print("start conv")
 
     def get_and_remove(key, transpose=False):
@@ -53,7 +53,7 @@ def convert_state_dict(state_dict: Dict[str, torch.Tensor], dtype: torch.dtype =
         converted[f"transformer.h.{layer_idx}.rms_2.scale"] = get_and_remove(f"layers.{layer_idx}.ffn_norm.weight")
     return converted
 
-def convert_weights(llama_ckpt, *, output_npz: Path = Path("llama.npz"), dtype: str = "float16") -> None:
+def convert_weights(llama_ckpt, *, output_npz: Path = Path("llama.npz"), dtype: str = "float32") -> None:
     dt = getattr(torch, dtype, None)
     if not isinstance(dt, torch.dtype):
         raise ValueError(f"{dtype} is not a valid dtype.")
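Both Python defaults move from float16 to float32, matching the f32 weights the Rust side now reads from llama-f32.npz. The body of convert_weights sits outside these hunks; the sketch below is only a hypothetical illustration of the convert-and-save step it performs, with the checkpoint path, helper name, and output default assumed rather than taken from the diff (the script's own default output is llama.npz).

import numpy as np
import torch

def convert_and_save(ckpt_path: str, output_npz: str = "llama-f32.npz") -> None:
    # Load the PyTorch checkpoint on CPU (ckpt_path is an assumed argument).
    state_dict = torch.load(ckpt_path, map_location="cpu")
    # Rename/transpose tensors into the layout the Rust model expects, in float32.
    converted = convert_state_dict(state_dict, dtype=torch.float32)
    # Write one .npz entry per tensor so Tensor::read_npz can look them up by name.
    arrays = {k: v.numpy() if isinstance(v, torch.Tensor) else np.asarray(v)
              for k, v in converted.items()}
    np.savez(output_npz, **arrays)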
@@ -426,10 +426,19 @@ fn main() -> Result<()> {
         .get_ids()
         .to_vec();
 
-    println!("loading weights");
-    let start_load = std::time::Instant::now();
-    let vb = VarBuilder::new::<f32>(); // TODO: load the weights from llama.npz
-    println!("loaded weights in {:?}", start_load.elapsed());
+    let weight_path = std::path::Path::new("llama-f32.npz");
+    let weights = if weight_path.exists() {
+        println!("loading weights from {weight_path:?}");
+        let start_load = std::time::Instant::now();
+        let tensors = Tensor::read_npz(weight_path)?;
+        println!("loaded weights in {:?}", start_load.elapsed());
+        let tensors: std::collections::HashMap<String, Tensor> = tensors.into_iter().collect();
+        Some(tensors)
+    } else {
+        println!("cannot find {weight_path:?}, using zero weights");
+        None
+    };
+    let vb = VarBuilder::new::<f32>(weights);
 
     println!("building the model");
     let config = Config::config_7b();
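When the archive is present it is read with Tensor::read_npz and collected into a name-to-tensor map for the VarBuilder; otherwise the model falls back to zero weights. A minimal sketch for sanity-checking the converted archive, reusing the same read_npz call (the file name and the printing are illustrative, not part of the commit):

use candle::Tensor;

fn main() -> candle::Result<()> {
    // Read every (name, tensor) pair stored in the converted archive.
    let tensors = Tensor::read_npz("llama-f32.npz")?;
    for (name, tensor) in tensors.iter() {
        // Print the tensor names and shapes the VarBuilder will look up.
        println!("{name}: {:?}", tensor.shape());
    }
    Ok(())
}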
@@ -1,4 +1,6 @@
 use candle::{DType, Device, Result, Shape, Tensor, WithDType};
+use std::collections::HashMap;
+use std::sync::Arc;
 
 #[allow(dead_code)]
 #[derive(Clone)]
@@ -13,6 +15,7 @@ pub struct VarBuilder {
     path: Vec<String>,
     vars: std::rc::Rc<std::cell::RefCell<Vec<NamedVar>>>,
     default_dtype: DType,
+    tensors: Arc<Option<HashMap<String, Tensor>>>,
 }
 
 #[allow(dead_code)]
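Wrapping the optional tensor map in an Arc appears to be what keeps the builder cheap to clone: the sub-builders produced by the Div impl further down share the same map via the Arc rather than copying the loaded weights.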
@@ -21,12 +24,13 @@ pub struct VarStore {
 }
 
 impl VarBuilder {
-    pub fn new<B: WithDType>() -> Self {
+    pub fn new<B: WithDType>(tensors: Option<HashMap<String, Tensor>>) -> Self {
         let vars = std::rc::Rc::new(std::cell::RefCell::new(vec![]));
         Self {
             path: vec![],
             vars,
             default_dtype: B::DTYPE,
+            tensors: Arc::new(tensors),
         }
     }
 
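Construction now takes the optional tensor map directly, mirroring the call in main above. A short usage sketch (variable names illustrative):

// Zero-initialised parameters, as before:
let vb = VarBuilder::new::<f32>(None);
// Parameters served from a pre-loaded map, e.g. the llama-f32.npz contents:
let vb = VarBuilder::new::<f32>(Some(tensors));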
@@ -38,13 +42,19 @@ impl VarBuilder {
         let shape = shape.into();
         let path = format!("{}.{s}", self.path.join("."));
         let mut vars = self.vars.borrow_mut();
-        let parameter = Tensor::zeros(&shape, self.default_dtype, &Device::Cpu);
+        let parameter = match self.tensors.as_ref() {
+            None => Tensor::zeros(&shape, self.default_dtype, &Device::Cpu)?,
+            Some(tensors) => match tensors.get(&path) {
+                Some(tensor) => tensor.clone(),
+                None => panic!("cannot find tensor for {path}"),
+            },
+        };
         vars.push(NamedVar {
             path,
             dtype: self.default_dtype,
             shape,
         });
-        parameter
+        Ok(parameter)
     }
 
     pub fn into_store(self) -> VarStore {
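Because the zero-tensor fallback now uses `?` on Tensor::zeros, the lookup returns Result<Tensor> and ends with Ok(parameter), so call sites need a `?` (or explicit error handling) where they previously got a bare Tensor. A name missing from the loaded map still panics rather than returning an error.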
@@ -65,6 +75,7 @@ impl<S: ToString> std::ops::Div<S> for &VarBuilder {
             path,
             vars: self.vars.clone(),
             default_dtype: self.default_dtype,
+            tensors: self.tensors.clone(),
         }
     }
 }
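The sub-builder produced by the `/` operator now carries the shared tensor map as well, so extending the path prefix (for example when descending into a transformer block) only bumps the Arc's reference count instead of copying the loaded weights.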