mirror of https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00

Merge pull request #57 from LaurentMazare/safetensor-module2

Move more safetensors bits to the shared module.
@@ -11,25 +11,25 @@ license = "MIT/Apache-2.0"
 readme = "README.md"

 [dependencies]
-safetensors = "0.3.1"
-thiserror = "1"
-cudarc = { version = "0.9.9", optional = true, features = ["f16"] }
-candle-kernels = { path = "../candle-kernels", optional = true }
-gemm = "0.15.4"
-zip = { version = "0.6.6", default-features=false }
 byteorder = "1.4.3"
+candle-kernels = { path = "../candle-kernels", optional = true }
+cudarc = { version = "0.9.9", optional = true, features = ["f16"] }
+gemm = "0.15.4"
 half = { version = "2.3.1", features = ["num-traits"] }
+memmap2 = "0.7.1"
 num-traits = "0.2.15"
 num_cpus = "1.15.0"
+safetensors = "0.3.1"
+thiserror = "1"
+zip = { version = "0.6.6", default-features=false }

 [dev-dependencies]
 anyhow = { version = "1", features = ["backtrace"] }
+candle-hub = { path = "../candle-hub" }
 clap = { version = "4.2.4", features = ["derive"] }
 rand = "0.8.5"
 tokenizers = { version = "0.13.3", default-features=false, features=["onig"] }
 tokio = { version = "1.28.2", features = ["macros", "rt-multi-thread"] }
-candle-hub = { path = "../candle-hub" }
-memmap2 = "0.7.1"

 [features]
 default = ["cuda"]
@@ -1,8 +1,5 @@
 use super::*;
-use candle::{Device, Result, Tensor};
-use memmap2::MmapOptions;
-use safetensors::SafeTensors;
-use std::fs::File;
+use candle::{safetensors::SafeTensors, Device, Result, Tensor};
 use std::path::PathBuf;

 pub struct VarBuilder<'a> {
@@ -30,8 +27,9 @@ impl<'a> VarBuilder<'a> {
     pub fn get(&self, tensor_name: &str) -> Result<Tensor> {
         // Unwrap or 0 just to let the proper error flow.
         let index = self.routing.get(tensor_name).unwrap_or(&0);
-        let view = self.safetensors[*index].tensor(tensor_name).unwrap();
-        candle::safetensors::convert(view, &self.device)?.to_dtype(DTYPE)
+        self.safetensors[*index]
+            .tensor(tensor_name, &self.device)?
+            .to_dtype(DTYPE)
     }
 }

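The hunks above appear to touch the llama example's VarBuilder: tensor lookup and dtype handling now go through the shared candle::safetensors module, so the builder gets a candle Tensor back directly instead of converting a raw safetensors view itself. A minimal sketch of the resulting pattern, assuming a reduced, hypothetical builder with a routing map from tensor names to the shard that holds them (names and fields outside the diff are assumptions):

use candle::{safetensors::SafeTensors, DType, Device, Result, Tensor};
use std::collections::HashMap;

// Hypothetical, trimmed-down version of the example's VarBuilder: one deserialized
// SafeTensors handle per checkpoint shard plus a name -> shard-index map.
struct ShardedWeights<'a> {
    safetensors: Vec<SafeTensors<'a>>,
    routing: HashMap<String, usize>,
    device: Device,
}

impl<'a> ShardedWeights<'a> {
    fn get(&self, name: &str) -> Result<Tensor> {
        // Default to shard 0 so an unknown name surfaces as a lookup error from the
        // shared module rather than a panic here (same trick as in the diff).
        let index = self.routing.get(name).copied().unwrap_or(0);
        self.safetensors[index]
            .tensor(name, &self.device)? // mmap-backed bytes -> candle Tensor on `device`
            .to_dtype(DType::F16) // the example converts to its working dtype (DTYPE)
    }
}
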
@@ -107,18 +105,12 @@ impl Llama {
     ) -> Result<Self> {
         let handles: Vec<_> = filenames
             .iter()
-            .map(|f| {
-                let file = File::open(f).unwrap();
-                unsafe { MmapOptions::new().map(&file).unwrap() }
-            })
-            .collect();
+            .map(candle::safetensors::MmapedFile::new)
+            .collect::<Result<Vec<_>>>()?;
         let tensors: Vec<_> = handles
             .iter()
-            .map(|h| {
-                let tensors = SafeTensors::deserialize(h).unwrap();
-                tensors
-            })
-            .collect();
+            .map(|h| h.deserialize())
+            .collect::<Result<Vec<_>>>()?;

         let vb = VarBuilder::new(tensors, device.clone());

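With this change the example's loading path goes through the two shared wrappers: every checkpoint file is memory-mapped once via MmapedFile, then deserialized into a SafeTensors handle that borrows from the mapping. A hedged end-to-end sketch of that flow (the weight name is a placeholder):

use candle::{safetensors::MmapedFile, Device, Result};
use std::path::PathBuf;

fn load_one_weight(filenames: &[PathBuf], device: &Device) -> Result<()> {
    // Step 1: mmap every shard; MmapedFile owns the mapping.
    let handles = filenames
        .iter()
        .map(MmapedFile::new)
        .collect::<Result<Vec<_>>>()?;
    // Step 2: deserialize views that borrow from the mappings above; this is why
    // the two steps cannot be fused into one call (see the module comment further down).
    let tensors = handles
        .iter()
        .map(|h| h.deserialize())
        .collect::<Result<Vec<_>>>()?;
    // Weights are only materialized on lookup, one tensor at a time.
    let w = tensors[0].tensor("tok_embeddings.weight", device)?;
    println!("{:?}", w.shape());
    Ok(())
}
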
@@ -114,6 +114,9 @@ pub enum Error {
     #[error(transparent)]
     SafeTensor(#[from] safetensors::SafeTensorError),

+    #[error("unsupported safetensor dtype {0:?}")]
+    UnsupportedSafeTensorDtype(safetensors::Dtype),
+
     #[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
     BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },
 }

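The new variant pairs with the convert rewrite further down, which replaces a todo!() with this error for dtypes that have no conversion yet, so callers can match on the failure instead of aborting. A small sketch, assuming the error type is re-exported at the crate root as candle::Error (the file path and tensor name are placeholders):

use candle::{safetensors::MmapedFile, Device, Error};

fn main() {
    let device = Device::Cpu;
    let result = MmapedFile::new("weights.safetensors").and_then(|file| {
        // `deserialize` borrows from `file`, so the lookup happens before `file` is dropped.
        let st = file.deserialize()?;
        st.tensor("some_tensor", &device)
    });
    match result {
        Ok(t) => println!("loaded {:?}", t.shape()),
        Err(Error::UnsupportedSafeTensorDtype(dtype)) => {
            eprintln!("dtype {dtype:?} has no conversion yet")
        }
        Err(e) => eprintln!("failed to load: {e}"),
    }
}
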
@@ -1,27 +1,76 @@
-use crate::{Device, Result, Tensor};
-use half::f16;
+use crate::{Device, Error, Result, Tensor, WithDType};
 use safetensors::tensor as st;

+fn convert_<T: WithDType>(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
+    let v = view.data();
+    let size_in_bytes = T::DTYPE.size_in_bytes();
+    let elem_count = v.len() / size_in_bytes;
+    if (v.as_ptr() as usize) % size_in_bytes == 0 {
+        // SAFETY This is safe because we just checked that this
+        // was correctly aligned.
+        let data: &[T] = unsafe { std::slice::from_raw_parts(v.as_ptr() as *const T, elem_count) };
+        Tensor::from_slice(data, view.shape(), device)
+    } else {
+        let mut c = Vec::with_capacity(elem_count);
+        unsafe {
+            std::ptr::copy_nonoverlapping(v.as_ptr(), c.as_mut_ptr() as *mut u8, v.len());
+            c.set_len(elem_count)
+        }
+        Tensor::from_slice(&c, view.shape(), device)
+    }
+}
+
 pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
     match view.dtype() {
-        st::Dtype::F16 => {
-            let v = view.data();
-            if (v.as_ptr() as usize) % 2 == 0 {
-                // SAFETY This is safe because we just checked that this
-                // was correctly aligned.
-                let data: &[f16] =
-                    unsafe { std::slice::from_raw_parts(v.as_ptr() as *const f16, v.len() / 2) };
-                Tensor::from_slice(data, view.shape(), device)
-            } else {
-                let mut c = Vec::with_capacity(v.len() / 2);
-                let mut i = 0;
-                while i < v.len() {
-                    c.push(f16::from_le_bytes([v[i], v[i + 1]]));
-                    i += 2;
-                }
-                Tensor::from_slice(&c, view.shape(), device)
-            }
-        }
-        dt => todo!("Unhandled dtype {dt:?}"),
+        st::Dtype::U8 => convert_::<u8>(view, device),
+        st::Dtype::U32 => convert_::<u8>(view, device),
+        st::Dtype::BF16 => convert_::<half::bf16>(view, device),
+        st::Dtype::F16 => convert_::<half::f16>(view, device),
+        st::Dtype::F32 => convert_::<f32>(view, device),
+        st::Dtype::F64 => convert_::<f64>(view, device),
+        dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
     }
 }

+// If Rust allowed for self-referential struct, we could store both the Mmap buffer and the
+// SafeTensor bits in the same struct and avoid having the final users calling two methods.
+// We could try using the ouroboros crate or equivalent for this at some point.
+// Wrap the SafeTensors main module so as to provide accessors with the candle types for errors,
+// dtypes, etc
+pub struct SafeTensors<'a>(st::SafeTensors<'a>);
+
+pub struct MmapedFile(memmap2::Mmap);
+
+impl MmapedFile {
+    pub fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
+        let file = std::fs::File::open(p)?;
+        let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? };
+        Ok(Self(mmap))
+    }
+
+    pub fn deserialize(&self) -> Result<SafeTensors<'_>> {
+        let st = safetensors::SafeTensors::deserialize(&self.0)?;
+        Ok(SafeTensors(st))
+    }
+}
+
+impl<'a> SafeTensors<'a> {
+    pub fn tensor(&self, name: &str, device: &Device) -> Result<Tensor> {
+        convert(self.0.tensor(name)?, device)
+    }
+
+    pub fn tensors(&self, device: &Device) -> Result<Vec<(String, Tensor)>> {
+        self.0
+            .tensors()
+            .into_iter()
+            .map(|(name, tensor_view)| {
+                let tensor = convert(tensor_view, device)?;
+                Ok((name, tensor))
+            })
+            .collect()
+    }
+
+    pub fn names(&self) -> Vec<&String> {
+        self.0.names()
+    }
+}
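The generic convert_ above exists because a memory-mapped byte buffer carries no alignment guarantee for the element type: when the data pointer happens to be aligned, the bytes are reinterpreted in place; otherwise they are copied once into an owned, properly aligned buffer. A standalone sketch of the same idea for f32, independent of candle (safetensors stores little-endian data, and like the code above the in-place path assumes a little-endian host):

/// Reinterpret little-endian bytes as f32 values, copying only when the buffer is
/// not suitably aligned (building an unaligned &[f32] would be undefined behavior).
fn bytes_to_f32(v: &[u8]) -> Vec<f32> {
    let size = std::mem::size_of::<f32>();
    let elem_count = v.len() / size;
    if (v.as_ptr() as usize) % std::mem::align_of::<f32>() == 0 {
        // SAFETY: alignment was just checked and elem_count * size <= v.len().
        let data = unsafe { std::slice::from_raw_parts(v.as_ptr() as *const f32, elem_count) };
        data.to_vec()
    } else {
        // Misaligned: decode element by element instead.
        v.chunks_exact(size)
            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
            .collect()
    }
}

fn main() {
    let bytes: Vec<u8> = [1.0f32, 2.5, -3.0].iter().flat_map(|f| f.to_le_bytes()).collect();
    assert_eq!(bytes_to_f32(&bytes), vec![1.0, 2.5, -3.0]);
}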