Mirror of https://github.com/huggingface/candle.git, synced 2025-06-18 11:37:11 +00:00

Early conversion for the llama weights.

@@ -1,7 +1,7 @@
 use super::*;
-use candle::{DType, Device, Result, Shape, Tensor, WithDType};
+use candle::{DType, Device, Result, Shape, Tensor};
 use std::collections::HashMap;
-use std::sync::Arc;
+use std::sync::{Arc, Mutex};
 
 #[allow(dead_code)]
 #[derive(Clone)]
@@ -14,51 +14,28 @@ struct NamedVar {
 #[derive(Clone)]
 pub struct VarBuilder {
     path: Vec<String>,
-    vars: std::rc::Rc<std::cell::RefCell<Vec<NamedVar>>>,
-    default_dtype: DType,
     default_device: Device,
-    tensors: Arc<Option<HashMap<String, Tensor>>>,
-}
-
-#[allow(dead_code)]
-pub struct VarStore {
-    vars: Vec<NamedVar>,
+    tensors: Arc<Mutex<HashMap<String, Tensor>>>,
 }
 
 impl VarBuilder {
-    pub fn new<B: WithDType>(device: &Device, tensors: Option<HashMap<String, Tensor>>) -> Self {
-        let vars = std::rc::Rc::new(std::cell::RefCell::new(vec![]));
+    pub fn new(device: &Device, tensors: HashMap<String, Tensor>) -> Self {
         Self {
             path: vec![],
-            vars,
-            default_dtype: B::DTYPE,
-            tensors: Arc::new(tensors),
+            tensors: Arc::new(Mutex::new(tensors)),
             default_device: device.clone(),
         }
     }
 
-    pub fn len(&self) -> usize {
-        self.vars.borrow().len()
-    }
-
-    pub fn var(&self, s: &str) -> Result<Tensor> {
+    pub fn get_and_remove(&self, s: &str) -> Result<Tensor> {
         let path = format!("{}.{s}", self.path.join("."));
-        let parameter = match self.tensors.as_ref() {
-            None => panic!("Cannot find tensors"),
-            Some(tensors) => match tensors.get(&path) {
-                Some(tensor) => tensor.to_device(&self.default_device)?,
-                None => panic!("cannot find tensor for {path}"),
-            },
+        let mut tensors = self.tensors.as_ref().lock().unwrap();
+        let parameter = match tensors.remove(&path) {
+            Some(tensor) => tensor.to_device(&self.default_device)?,
+            None => panic!("cannot find tensor for {path}"),
         };
         Ok(parameter)
     }
-
-    pub fn into_store(self) -> VarStore {
-        let vars = self.vars.borrow();
-        VarStore {
-            vars: vars.to_vec(),
-        }
-    }
 }
 
 impl<S: ToString> std::ops::Div<S> for &VarBuilder {
@@ -69,8 +46,6 @@ impl<S: ToString> std::ops::Div<S> for &VarBuilder {
         path.push(rhs.to_string());
         VarBuilder {
             path,
-            vars: self.vars.clone(),
-            default_dtype: self.default_dtype,
             default_device: self.default_device.clone(),
             tensors: self.tensors.clone(),
         }
@@ -87,21 +62,21 @@ impl<S: ToString> std::ops::Div<S> for VarBuilder {
 
 impl Embedding {
     fn load_npy(vb: VarBuilder) -> Result<Self> {
-        let embeddings = vb.var("weight")?;
+        let embeddings = vb.get_and_remove("weight")?.to_dtype(DTYPE)?;
         Ok(Self { embeddings })
     }
 }
 
 impl Linear {
     fn load_npy(vb: VarBuilder) -> Result<Self> {
-        let weight = vb.var("weight")?.t()?;
+        let weight = vb.get_and_remove("weight")?.to_dtype(DTYPE)?.t()?;
         Ok(Self { weight })
     }
 }
 
 impl RmsNorm {
     fn load_npy(vb: VarBuilder) -> Result<Self> {
-        let scale = vb.var("scale")?;
+        let scale = vb.get_and_remove("scale")?.to_dtype(DTYPE)?;
         Ok(Self::new(scale))
     }
 }
@@ -144,7 +119,7 @@ impl Llama {
         filename: &str,
         cache: &Cache,
         config: &Config,
-    ) -> Result<Self> {
+    ) -> anyhow::Result<Self> {
         let weight_path = std::path::Path::new(filename);
         let weights = if weight_path.exists() {
             println!("loading weights from {weight_path:?}");
@@ -152,12 +127,11 @@ impl Llama {
             let tensors = Tensor::read_npz(weight_path)?;
             println!("loaded weights in {:?}", start_load.elapsed());
             let tensors: std::collections::HashMap<String, Tensor> = tensors.into_iter().collect();
-            Some(tensors)
+            tensors
         } else {
-            println!("cannot find {weight_path:?}, using zero weights");
-            None
+            anyhow::bail!("cannot find {weight_path:?}")
         };
-        let vb = VarBuilder::new::<f32>(device, weights);
+        let vb = VarBuilder::new(device, weights);
 
         let wte = Embedding::load_npy(&vb / "transformer" / "wte")?;
         let lm_head = Linear::load_npy(&vb / "lm_head")?;
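As a side note, the reworked VarBuilder boils down to one pattern: the tensor map now sits behind an Arc<Mutex<..>>, every clone produced by the `/` operator shares that map, and get_and_remove takes each weight out exactly once. The standalone sketch below illustrates just that pattern; the Store type and the Vec<f32> stand-in for Tensor are illustrative assumptions, not part of the commit.

// Illustrative only: mirrors the Arc<Mutex<HashMap>> + remove() idea from the
// diff above, with a Vec<f32> stand-in instead of candle's Tensor.
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

#[derive(Clone)]
struct Store {
    tensors: Arc<Mutex<HashMap<String, Vec<f32>>>>,
}

impl Store {
    fn get_and_remove(&self, path: &str) -> Option<Vec<f32>> {
        // Lock the shared map and take the entry out, so each weight can only
        // be consumed once.
        self.tensors.lock().unwrap().remove(path)
    }
}

fn main() {
    let mut m = HashMap::new();
    m.insert("lm_head.weight".to_string(), vec![0.1, 0.2]);
    let store = Store {
        tensors: Arc::new(Mutex::new(m)),
    };
    // Clones share the same underlying map, like the VarBuilder returned by `/`.
    let scoped = store.clone();

    assert!(scoped.get_and_remove("lm_head.weight").is_some());
    // Already consumed, so the original handle no longer sees it.
    assert!(store.get_and_remove("lm_head.weight").is_none());
    // An empty map at the end means every tensor in the file was used somewhere.
    assert!(store.tensors.lock().unwrap().is_empty());
    println!("all weights consumed");
}

One consequence of removing entries as they are loaded is that any tensors still left in the map afterwards point at weights the model never consumed.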
@@ -18,7 +18,7 @@ fn convert(view: TensorView<'_>, device: &Device) -> Result<Tensor> {
                 // was correctly aligned.
                 let data: &[f16] =
                     unsafe { std::slice::from_raw_parts(v.as_ptr() as *const f16, v.len() / 2) };
-                Tensor::from_slice(data, view.shape(), device)
+                Tensor::from_slice(data, view.shape(), device)?.to_dtype(DTYPE)
             } else {
                 let mut c = Vec::with_capacity(v.len() / 2);
                 let mut i = 0;
@@ -26,7 +26,7 @@ fn convert(view: TensorView<'_>, device: &Device) -> Result<Tensor> {
                     c.push(f16::from_le_bytes([v[i], v[i + 1]]));
                     i += 2;
                 }
-                Tensor::from_slice(&c, view.shape(), device)
+                Tensor::from_slice(&c, view.shape(), device)?.to_dtype(DTYPE)
             }
         }
         dt => todo!("Unhandled dtype {dt:?}"),
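The unaligned branch of `convert` rebuilds each f16 from a pair of little-endian bytes before the final to_dtype(DTYPE) cast. Below is a small standalone sketch of that decoding step; it only assumes the half crate (which provides the f16 type used above), and the sample bytes are made up for the example.

// Standalone sketch of the byte-pair decoding used in the unaligned branch of
// `convert`. Only the `half` crate is assumed; the sample bytes are illustrative.
use half::f16;

fn bytes_to_f16(v: &[u8]) -> Vec<f16> {
    let mut c = Vec::with_capacity(v.len() / 2);
    let mut i = 0;
    while i + 1 < v.len() {
        // Each f16 is stored as two little-endian bytes.
        c.push(f16::from_le_bytes([v[i], v[i + 1]]));
        i += 2;
    }
    c
}

fn main() {
    // 0x3C00 encodes 1.0 and 0x4000 encodes 2.0 in IEEE half precision.
    let bytes = [0x00u8, 0x3C, 0x00, 0x40];
    let vals = bytes_to_f16(&bytes);
    assert_eq!(vals[0].to_f32(), 1.0);
    assert_eq!(vals[1].to_f32(), 2.0);
    println!("{vals:?}");
}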