From 783b7054eeaec0ecd6684b537510193813b81d77 Mon Sep 17 00:00:00 2001
From: laurent
Date: Mon, 3 Jul 2023 09:34:08 +0100
Subject: [PATCH 1/2] Move more safetensors bits to the shared module.
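
Rather than having every consumer open, mmap, and deserialize by hand, the
llama example now goes through a small wrapper API in candle::safetensors.
Loading stays a two-step affair (mmap the file, then deserialize) so that
the mmap buffer outlives the borrowed SafeTensors view; see the module
comment about self-referential structs. Usage looks roughly like the
sketch below; the file path and tensor name are placeholders:

    use candle::{Device, Result};

    fn load_example() -> Result<()> {
        let weights = candle::safetensors::MmapedFile::new("model.safetensors")?;
        let tensors = weights.deserialize()?;
        let t = tensors.tensor("some.weight", &Device::Cpu)?;
        println!("{:?}", t.shape());
        Ok(())
    }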

---
 candle-core/Cargo.toml                | 16 +++++-----
 candle-core/examples/llama/weights.rs | 24 +++++----------
 candle-core/src/safetensors.rs        | 43 +++++++++++++++++++++++++++
 3 files changed, 59 insertions(+), 24 deletions(-)

diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index 1b7ef4c4..7076e4e4 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -11,25 +11,25 @@ license = "MIT/Apache-2.0"
 readme = "README.md"
 
 [dependencies]
-safetensors = "0.3.1"
-thiserror = "1"
-cudarc = { version = "0.9.9", optional = true, features = ["f16"] }
-candle-kernels = { path = "../candle-kernels", optional = true }
-gemm = "0.15.4"
-zip = { version = "0.6.6", default-features=false }
 byteorder = "1.4.3"
+candle-kernels = { path = "../candle-kernels", optional = true }
+cudarc = { version = "0.9.9", optional = true, features = ["f16"] }
+gemm = "0.15.4"
 half = { version = "2.3.1", features = ["num-traits"] }
+memmap2 = "0.7.1"
 num-traits = "0.2.15"
 num_cpus = "1.15.0"
+safetensors = "0.3.1"
+thiserror = "1"
+zip = { version = "0.6.6", default-features=false }
 
 [dev-dependencies]
 anyhow = { version = "1", features = ["backtrace"] }
+candle-hub = { path = "../candle-hub" }
 clap = { version = "4.2.4", features = ["derive"] }
 rand = "0.8.5"
 tokenizers = { version = "0.13.3", default-features=false, features=["onig"] }
 tokio = { version = "1.28.2", features = ["macros", "rt-multi-thread"] }
-candle-hub = { path = "../candle-hub" }
-memmap2 = "0.7.1"
 
 [features]
 default = ["cuda"]
diff --git a/candle-core/examples/llama/weights.rs b/candle-core/examples/llama/weights.rs
index 4ad9b391..cc3fccd4 100644
--- a/candle-core/examples/llama/weights.rs
+++ b/candle-core/examples/llama/weights.rs
@@ -1,8 +1,5 @@
 use super::*;
-use candle::{Device, Result, Tensor};
-use memmap2::MmapOptions;
-use safetensors::SafeTensors;
-use std::fs::File;
+use candle::{safetensors::SafeTensors, Device, Result, Tensor};
 use std::path::PathBuf;
 
 pub struct VarBuilder<'a> {
@@ -30,8 +27,9 @@ impl<'a> VarBuilder<'a> {
     pub fn get(&self, tensor_name: &str) -> Result<Tensor> {
         // Unwrap or 0 just to let the proper error flow.
         let index = self.routing.get(tensor_name).unwrap_or(&0);
-        let view = self.safetensors[*index].tensor(tensor_name).unwrap();
-        candle::safetensors::convert(view, &self.device)?.to_dtype(DTYPE)
+        self.safetensors[*index]
+            .tensor(tensor_name, &self.device)?
+            .to_dtype(DTYPE)
     }
 }
 
@@ -107,18 +105,12 @@ impl Llama {
     ) -> Result<Self> {
         let handles: Vec<_> = filenames
             .iter()
-            .map(|f| {
-                let file = File::open(f).unwrap();
-                unsafe { MmapOptions::new().map(&file).unwrap() }
-            })
-            .collect();
+            .map(candle::safetensors::MmapedFile::new)
+            .collect::<Result<Vec<_>>>()?;
         let tensors: Vec<_> = handles
             .iter()
-            .map(|h| {
-                let tensors = SafeTensors::deserialize(h).unwrap();
-                tensors
-            })
-            .collect();
+            .map(|h| h.deserialize())
+            .collect::<Result<Vec<_>>>()?;
 
         let vb = VarBuilder::new(tensors, device.clone());
diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs
index 3ed36b64..b586b544 100644
--- a/candle-core/src/safetensors.rs
+++ b/candle-core/src/safetensors.rs
@@ -25,3 +25,46 @@ pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
         dt => todo!("Unhandled dtype {dt:?}"),
     }
 }
+
+// If Rust allowed for self-referential structs, we could store both the Mmap buffer and the
+// SafeTensors bits in the same struct and avoid having the final user call two methods.
+// We could try using the ouroboros crate or equivalent for this at some point.
+// Wrap the SafeTensors main module so as to provide accessors with the candle types for errors,
+// dtypes, etc.
+pub struct SafeTensors<'a>(st::SafeTensors<'a>);
+
+pub struct MmapedFile(memmap2::Mmap);
+
+impl MmapedFile {
+    pub fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
+        let file = std::fs::File::open(p)?;
+        let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? };
+        Ok(Self(mmap))
+    }
+
+    pub fn deserialize(&self) -> Result<SafeTensors<'_>> {
+        let st = safetensors::SafeTensors::deserialize(&self.0)?;
+        Ok(SafeTensors(st))
+    }
+}
+
+impl<'a> SafeTensors<'a> {
+    pub fn tensor(&self, name: &str, device: &Device) -> Result<Tensor> {
+        convert(self.0.tensor(name)?, device)
+    }
+
+    pub fn tensors(&self, device: &Device) -> Result<Vec<(String, Tensor)>> {
+        self.0
+            .tensors()
+            .into_iter()
+            .map(|(name, tensor_view)| {
+                let tensor = convert(tensor_view, device)?;
+                Ok((name, tensor))
+            })
+            .collect()
+    }
+
+    pub fn names(&self) -> Vec<&String> {
+        self.0.names()
+    }
+}

From 899c76de7567572f522f8711e10150de8e4e0d6f Mon Sep 17 00:00:00 2001
From: laurent
Date: Mon, 3 Jul 2023 10:09:46 +0100
Subject: [PATCH 2/2] Handle more types in safetensors.
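
Replace the F16-only conversion with a generic convert_<T: WithDType>
helper and wire up U8, U32, BF16, F16, F32 and F64. Unknown dtypes now
surface as Error::UnsupportedSafeTensorDtype instead of hitting a todo!()
panic. The helper keeps the previous fast path: when the byte buffer is
aligned for T it is reinterpreted in place, otherwise the bytes are copied
into a freshly allocated (and therefore aligned) Vec. A standalone sketch
of that pattern, specialized to f32 (bytes_to_f32 is an illustrative name,
not part of this patch; like the real code it assumes a little-endian
host, matching the safetensors on-disk layout):

    fn bytes_to_f32(v: &[u8]) -> Vec<f32> {
        let size = std::mem::size_of::<f32>();
        // Assumes v.len() is a multiple of `size`, as is the case for
        // safetensors data buffers.
        let elem_count = v.len() / size;
        if (v.as_ptr() as usize) % size == 0 {
            // Aligned: reinterpret the bytes in place and copy them out.
            let data =
                unsafe { std::slice::from_raw_parts(v.as_ptr() as *const f32, elem_count) };
            data.to_vec()
        } else {
            // Unaligned: byte-copy into a new, properly aligned Vec.
            let mut c: Vec<f32> = Vec::with_capacity(elem_count);
            unsafe {
                std::ptr::copy_nonoverlapping(v.as_ptr(), c.as_mut_ptr() as *mut u8, v.len());
                c.set_len(elem_count);
            }
            c
        }
    }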

---
 candle-core/src/error.rs       |  3 +++
 candle-core/src/safetensors.rs | 48 +++++++++++++++++++++++++++---------------------
 2 files changed, 30 insertions(+), 21 deletions(-)

diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index 71fd21de..d5de4296 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -114,6 +114,9 @@ pub enum Error {
     #[error(transparent)]
     SafeTensor(#[from] safetensors::SafeTensorError),
 
+    #[error("unsupported safetensor dtype {0:?}")]
+    UnsupportedSafeTensorDtype(safetensors::Dtype),
+
     #[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
     BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },
 }
diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs
index b586b544..b80a756a 100644
--- a/candle-core/src/safetensors.rs
+++ b/candle-core/src/safetensors.rs
@@ -1,28 +1,34 @@
-use crate::{Device, Result, Tensor};
-use half::f16;
+use crate::{Device, Error, Result, Tensor, WithDType};
 use safetensors::tensor as st;
 
+fn convert_<T: WithDType>(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
+    let v = view.data();
+    let size_in_bytes = T::DTYPE.size_in_bytes();
+    let elem_count = v.len() / size_in_bytes;
+    if (v.as_ptr() as usize) % size_in_bytes == 0 {
+        // SAFETY This is safe because we just checked that this
+        // was correctly aligned.
+        let data: &[T] = unsafe { std::slice::from_raw_parts(v.as_ptr() as *const T, elem_count) };
+        Tensor::from_slice(data, view.shape(), device)
+    } else {
+        let mut c: Vec<T> = Vec::with_capacity(elem_count);
+        unsafe {
+            std::ptr::copy_nonoverlapping(v.as_ptr(), c.as_mut_ptr() as *mut u8, v.len());
+            c.set_len(elem_count)
+        }
+        Tensor::from_slice(&c, view.shape(), device)
+    }
+}
+
 pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
     match view.dtype() {
-        st::Dtype::F16 => {
-            let v = view.data();
-            if (v.as_ptr() as usize) % 2 == 0 {
-                // SAFETY This is safe because we just checked that this
-                // was correctly aligned.
-                let data: &[f16] =
-                    unsafe { std::slice::from_raw_parts(v.as_ptr() as *const f16, v.len() / 2) };
-                Tensor::from_slice(data, view.shape(), device)
-            } else {
-                let mut c = Vec::with_capacity(v.len() / 2);
-                let mut i = 0;
-                while i < v.len() {
-                    c.push(f16::from_le_bytes([v[i], v[i + 1]]));
-                    i += 2;
-                }
-                Tensor::from_slice(&c, view.shape(), device)
-            }
-        }
-        dt => todo!("Unhandled dtype {dt:?}"),
+        st::Dtype::U8 => convert_::<u8>(view, device),
+        st::Dtype::U32 => convert_::<u32>(view, device),
+        st::Dtype::BF16 => convert_::<half::bf16>(view, device),
+        st::Dtype::F16 => convert_::<half::f16>(view, device),
+        st::Dtype::F32 => convert_::<f32>(view, device),
+        st::Dtype::F64 => convert_::<f64>(view, device),
+        dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
     }
 }