Merge pull request #57 from LaurentMazare/safetensor-module2

Move more safetensors bits to the shared module.
This commit is contained in:
Laurent Mazare
2023-07-03 10:19:57 +01:00
committed by GitHub
4 changed files with 89 additions and 45 deletions

View File

@ -11,25 +11,25 @@ license = "MIT/Apache-2.0"
readme = "README.md"
[dependencies]
safetensors = "0.3.1"
thiserror = "1"
cudarc = { version = "0.9.9", optional = true, features = ["f16"] }
candle-kernels = { path = "../candle-kernels", optional = true }
gemm = "0.15.4"
zip = { version = "0.6.6", default-features=false }
byteorder = "1.4.3"
candle-kernels = { path = "../candle-kernels", optional = true }
cudarc = { version = "0.9.9", optional = true, features = ["f16"] }
gemm = "0.15.4"
half = { version = "2.3.1", features = ["num-traits"] }
memmap2 = "0.7.1"
num-traits = "0.2.15"
num_cpus = "1.15.0"
safetensors = "0.3.1"
thiserror = "1"
zip = { version = "0.6.6", default-features=false }
[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
candle-hub = { path = "../candle-hub" }
clap = { version = "4.2.4", features = ["derive"] }
rand = "0.8.5"
tokenizers = { version = "0.13.3", default-features=false, features=["onig"] }
tokio = { version = "1.28.2", features = ["macros", "rt-multi-thread"] }
candle-hub = { path = "../candle-hub" }
memmap2 = "0.7.1"
[features]
default = ["cuda"]

View File

@ -1,8 +1,5 @@
use super::*;
use candle::{Device, Result, Tensor};
use memmap2::MmapOptions;
use safetensors::SafeTensors;
use std::fs::File;
use candle::{safetensors::SafeTensors, Device, Result, Tensor};
use std::path::PathBuf;
pub struct VarBuilder<'a> {
@ -30,8 +27,9 @@ impl<'a> VarBuilder<'a> {
pub fn get(&self, tensor_name: &str) -> Result<Tensor> {
// Unwrap or 0 just to let the proper error flow.
let index = self.routing.get(tensor_name).unwrap_or(&0);
let view = self.safetensors[*index].tensor(tensor_name).unwrap();
candle::safetensors::convert(view, &self.device)?.to_dtype(DTYPE)
self.safetensors[*index]
.tensor(tensor_name, &self.device)?
.to_dtype(DTYPE)
}
}
@ -107,18 +105,12 @@ impl Llama {
) -> Result<Self> {
let handles: Vec<_> = filenames
.iter()
.map(|f| {
let file = File::open(f).unwrap();
unsafe { MmapOptions::new().map(&file).unwrap() }
})
.collect();
.map(candle::safetensors::MmapedFile::new)
.collect::<Result<Vec<_>>>()?;
let tensors: Vec<_> = handles
.iter()
.map(|h| {
let tensors = SafeTensors::deserialize(h).unwrap();
tensors
})
.collect();
.map(|h| h.deserialize())
.collect::<Result<Vec<_>>>()?;
let vb = VarBuilder::new(tensors, device.clone());

View File

@ -114,6 +114,9 @@ pub enum Error {
#[error(transparent)]
SafeTensor(#[from] safetensors::SafeTensorError),
#[error("unsupported safetensor dtype {0:?}")]
UnsupportedSafeTensorDtype(safetensors::Dtype),
#[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },
}

View File

@ -1,27 +1,76 @@
use crate::{Device, Result, Tensor};
use half::f16;
use crate::{Device, Error, Result, Tensor, WithDType};
use safetensors::tensor as st;
/// Build a `Tensor` on `device` from the raw bytes of a safetensors view,
/// reinterpreting them as elements of type `T`.
///
/// NOTE(review): both paths reinterpret the on-disk bytes as native-endian
/// `T`; safetensors data is little-endian, so this assumes a little-endian
/// host — confirm if big-endian targets ever matter.
fn convert_<T: WithDType>(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
    let v = view.data();
    let size_in_bytes = T::DTYPE.size_in_bytes();
    let elem_count = v.len() / size_in_bytes;
    if (v.as_ptr() as usize) % size_in_bytes == 0 {
        // SAFETY: the pointer is aligned for `T` (checked just above) and the
        // buffer holds at least `elem_count * size_in_bytes` readable bytes.
        let data: &[T] = unsafe { std::slice::from_raw_parts(v.as_ptr() as *const T, elem_count) };
        Tensor::from_slice(data, view.shape(), device)
    } else {
        // Unaligned source: copy the bytes into a freshly allocated, properly
        // aligned Vec<T> before building the tensor.
        let mut c: Vec<T> = Vec::with_capacity(elem_count);
        // SAFETY: `c` has capacity for exactly `elem_count` elements, i.e.
        // `elem_count * size_in_bytes` bytes. We copy exactly that many bytes
        // — NOT `v.len()`, which would overrun the allocation by up to
        // `size_in_bytes - 1` bytes whenever the data length is not an exact
        // multiple of the element size — then mark the elements initialized.
        unsafe {
            let byte_count = elem_count * size_in_bytes;
            std::ptr::copy_nonoverlapping(v.as_ptr(), c.as_mut_ptr() as *mut u8, byte_count);
            c.set_len(elem_count)
        }
        Tensor::from_slice(&c, view.shape(), device)
    }
}
pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
match view.dtype() {
st::Dtype::F16 => {
let v = view.data();
if (v.as_ptr() as usize) % 2 == 0 {
// SAFETY This is safe because we just checked that this
// was correctly aligned.
let data: &[f16] =
unsafe { std::slice::from_raw_parts(v.as_ptr() as *const f16, v.len() / 2) };
Tensor::from_slice(data, view.shape(), device)
} else {
let mut c = Vec::with_capacity(v.len() / 2);
let mut i = 0;
while i < v.len() {
c.push(f16::from_le_bytes([v[i], v[i + 1]]));
i += 2;
}
Tensor::from_slice(&c, view.shape(), device)
}
}
dt => todo!("Unhandled dtype {dt:?}"),
st::Dtype::U8 => convert_::<u8>(view, device),
st::Dtype::U32 => convert_::<u8>(view, device),
st::Dtype::BF16 => convert_::<half::bf16>(view, device),
st::Dtype::F16 => convert_::<half::f16>(view, device),
st::Dtype::F32 => convert_::<f32>(view, device),
st::Dtype::F64 => convert_::<f64>(view, device),
dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
}
}
// If Rust allowed self-referential structs, we could store both the Mmap buffer and the
// SafeTensors view in a single struct and spare users the two-step `MmapedFile::new` then
// `deserialize` dance. The ouroboros crate (or equivalent) could enable this at some point.

// Thin wrapper around the safetensors crate's deserializer so that accessors return candle
// types (`Tensor`, candle `Result`/errors, candle dtypes) instead of the raw safetensors ones.
pub struct SafeTensors<'a>(st::SafeTensors<'a>);

// A read-only memory-mapped file whose bytes can be parsed as a safetensors archive.
pub struct MmapedFile(memmap2::Mmap);
impl MmapedFile {
    /// Memory-map the file at `p` read-only.
    ///
    /// NOTE(review): per the memmap2 docs, the mapping is only sound while the
    /// underlying file is not truncated or modified by another process —
    /// confirm callers uphold this.
    pub fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
        let file = std::fs::File::open(p)?;
        let mapping = unsafe { memmap2::MmapOptions::new().map(&file)? };
        Ok(Self(mapping))
    }

    /// Parse the mapped bytes as a safetensors archive; the returned view
    /// borrows from `self`.
    pub fn deserialize(&self) -> Result<SafeTensors<'_>> {
        Ok(SafeTensors(safetensors::SafeTensors::deserialize(&self.0)?))
    }
}
impl<'a> SafeTensors<'a> {
    /// Load the tensor called `name` onto `device`, converting it to a candle
    /// `Tensor` with candle error types.
    pub fn tensor(&self, name: &str, device: &Device) -> Result<Tensor> {
        let view = self.0.tensor(name)?;
        convert(view, device)
    }

    /// Load every tensor in the archive onto `device`, returned as
    /// `(name, tensor)` pairs. Fails on the first conversion error.
    pub fn tensors(&self, device: &Device) -> Result<Vec<(String, Tensor)>> {
        let mut out = Vec::new();
        for (name, view) in self.0.tensors() {
            let tensor = convert(view, device)?;
            out.push((name, tensor));
        }
        Ok(out)
    }

    /// Names of all tensors stored in the archive.
    pub fn names(&self) -> Vec<&String> {
        self.0.names()
    }
}