Resurrect the llama npy support. (#140)

Laurent Mazare
2023-07-11 19:32:10 +01:00
committed by GitHub
parent 760f1d7055
commit 37cad85869
6 changed files with 264 additions and 90 deletions


@@ -1,15 +1,20 @@
-use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
+use candle::{safetensors::SafeTensors, DType, Device, Error, Shape, Tensor};
 use std::collections::HashMap;
 use std::sync::Arc;
 
-struct SafeTensorWithRouting<'a> {
-    routing: HashMap<String, usize>,
-    safetensors: Vec<SafeTensors<'a>>,
+// TODO: Maybe we would want the storage to be generic, e.g. with Box<dyn> to avoid too many
+// generics.
+enum Tensors<'a> {
+    SafeTensorWithRouting {
+        routing: HashMap<String, usize>,
+        safetensors: Vec<SafeTensors<'a>>,
+    },
+    TensorMap(HashMap<String, Tensor>),
+    Zeros,
 }
 
 struct TensorData<'a> {
-    // TODO: Make this part generic, probably via some Box<dyn> to avoid too much generics.
-    safetensors: Option<SafeTensorWithRouting<'a>>,
+    tensors: Tensors<'a>,
     pub dtype: DType,
     pub device: Device,
 }
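For orientation: the new TensorMap variant is the hook for the resurrected npy support, since weights loaded from an .npz archive can be flattened into a plain HashMap and used to back a builder. A minimal sketch of that wiring, assuming candle's Tensor::read_npz helper (the path argument is illustrative, and this function is not part of the commit):

    use std::collections::HashMap;
    use candle::{Device, Tensor};

    fn npz_to_map(path: &str, device: &Device) -> candle::Result<HashMap<String, Tensor>> {
        // read_npz yields every (name, tensor) pair stored in the archive.
        let mut map = HashMap::new();
        for (name, tensor) in Tensor::read_npz(path)? {
            map.insert(name, tensor.to_device(device)?);
        }
        Ok(map)
    }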
@@ -22,12 +27,12 @@ impl<'a> TensorData<'a> {
                 routing.insert(k.to_string(), index);
             }
         }
-        let safetensors = SafeTensorWithRouting {
+        let tensors = Tensors::SafeTensorWithRouting {
             routing,
             safetensors,
         };
         Self {
-            safetensors: Some(safetensors),
+            tensors,
             device: device.clone(),
             dtype,
         }
@ -35,7 +40,15 @@ impl<'a> TensorData<'a> {
fn zeros(dtype: DType, device: &Device) -> Self {
Self {
safetensors: None,
tensors: Tensors::Zeros,
device: device.clone(),
dtype,
}
}
fn from_tensors(tensors: HashMap<String, Tensor>, dtype: DType, device: &Device) -> Self {
Self {
tensors: Tensors::TensorMap(tensors),
device: device.clone(),
dtype,
}
@@ -67,6 +80,14 @@ impl<'a> VarBuilder<'a> {
         }
     }
 
+    pub fn from_tensors(ts: HashMap<String, Tensor>, dtype: DType, device: &Device) -> Self {
+        let data = TensorData::from_tensors(ts, dtype, device);
+        Self {
+            data: Arc::new(data),
+            path: vec![],
+        }
+    }
+
     pub fn push_prefix(&self, s: &str) -> Self {
         let mut path = self.path.clone();
         path.push(s.to_string());
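A hedged usage sketch of the new constructor together with the push_prefix/get API defined in this same file (tensor names and shapes are made up for illustration, and VarBuilder is assumed to be in scope):

    use std::collections::HashMap;
    use candle::{DType, Device, Tensor};

    fn demo() -> candle::Result<()> {
        let device = Device::Cpu;
        let mut ts = HashMap::new();
        // Keys must match the dotted path that `get` reconstructs from the prefix stack.
        ts.insert("layer1.weight".to_string(), Tensor::zeros((4, 4), DType::F32, &device)?);
        let vb = VarBuilder::from_tensors(ts, DType::F32, &device);
        let vb = vb.push_prefix("layer1");
        // Resolves to "layer1.weight" in the map, then checks the shape.
        let w = vb.get((4, 4), "weight")?;
        assert_eq!(w.dims(), &[4, 4]);
        Ok(())
    }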
@@ -94,31 +115,37 @@ impl<'a> VarBuilder<'a> {
     pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> candle::Result<Tensor> {
         let data = self.data.as_ref();
         let s: Shape = s.into();
-        match &self.data.safetensors {
-            None => Tensor::zeros(s, data.dtype, &data.device),
-            Some(SafeTensorWithRouting {
+        let path = if self.path.is_empty() {
+            tensor_name.to_string()
+        } else {
+            [&self.path.join("."), tensor_name].join(".")
+        };
+        let tensor = match &self.data.tensors {
+            Tensors::Zeros => Tensor::zeros(&s, data.dtype, &data.device)?.contiguous()?,
+            Tensors::TensorMap(ts) => ts
+                .get(&path)
+                .ok_or_else(|| Error::CannotFindTensor {
+                    path: path.to_string(),
+                })?
+                .clone(),
+            Tensors::SafeTensorWithRouting {
                 routing,
                 safetensors,
-            }) => {
-                let path = if self.path.is_empty() {
-                    tensor_name.to_string()
-                } else {
-                    [&self.path.join("."), tensor_name].join(".")
-                };
+            } => {
                 // Fall back to index 0 so that the tensor lookup below reports the proper error.
                 let index = routing.get(&path).unwrap_or(&0);
-                let tensor = safetensors[*index]
+                safetensors[*index]
                     .tensor(&path, &data.device)?
-                    .to_dtype(data.dtype)?;
-                if *tensor.shape() != s {
-                    Err(candle::Error::UnexpectedShape {
-                        msg: format!("shape mismatch for {path}"),
-                        expected: s,
-                        got: tensor.shape().clone(),
-                    })?
-                }
-                Ok(tensor)
+                    .to_dtype(data.dtype)?
             }
-        }
+        };
+        if tensor.shape() != &s {
+            Err(candle::Error::UnexpectedShape {
+                msg: format!("shape mismatch for {path}"),
+                expected: s,
+                got: tensor.shape().clone(),
+            })?
+        }
+        Ok(tensor)
     }
 }
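Note that the refactor hoists both the path construction and the shape check out of the safetensors arm, so all three variants are resolved and validated uniformly. A small sketch of the resulting behavior, built on the from_tensors path added above (names are illustrative):

    use std::collections::HashMap;
    use candle::{DType, Device, Tensor};

    fn shape_check() -> candle::Result<()> {
        let device = Device::Cpu;
        let mut ts = HashMap::new();
        ts.insert("w".to_string(), Tensor::zeros((2, 3), DType::F32, &device)?);
        let vb = VarBuilder::from_tensors(ts, DType::F32, &device);
        // Asking for the wrong shape now surfaces Error::UnexpectedShape...
        assert!(vb.get((3, 2), "w").is_err());
        // ...and an unknown name surfaces Error::CannotFindTensor.
        assert!(vb.get((2, 3), "missing").is_err());
        Ok(())
    }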