diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml
index 1b7ef4c4..7076e4e4 100644
--- a/candle-core/Cargo.toml
+++ b/candle-core/Cargo.toml
@@ -11,25 +11,25 @@ license = "MIT/Apache-2.0"
 readme = "README.md"
 
 [dependencies]
-safetensors = "0.3.1"
-thiserror = "1"
-cudarc = { version = "0.9.9", optional = true, features = ["f16"] }
-candle-kernels = { path = "../candle-kernels", optional = true }
-gemm = "0.15.4"
-zip = { version = "0.6.6", default-features=false }
 byteorder = "1.4.3"
+candle-kernels = { path = "../candle-kernels", optional = true }
+cudarc = { version = "0.9.9", optional = true, features = ["f16"] }
+gemm = "0.15.4"
 half = { version = "2.3.1", features = ["num-traits"] }
+memmap2 = "0.7.1"
 num-traits = "0.2.15"
 num_cpus = "1.15.0"
+safetensors = "0.3.1"
+thiserror = "1"
+zip = { version = "0.6.6", default-features=false }
 
 [dev-dependencies]
 anyhow = { version = "1", features = ["backtrace"] }
+candle-hub = { path = "../candle-hub" }
 clap = { version = "4.2.4", features = ["derive"] }
 rand = "0.8.5"
 tokenizers = { version = "0.13.3", default-features=false, features=["onig"] }
 tokio = { version = "1.28.2", features = ["macros", "rt-multi-thread"] }
-candle-hub = { path = "../candle-hub" }
-memmap2 = "0.7.1"
 
 [features]
 default = ["cuda"]
diff --git a/candle-core/examples/llama/weights.rs b/candle-core/examples/llama/weights.rs
index 4ad9b391..cc3fccd4 100644
--- a/candle-core/examples/llama/weights.rs
+++ b/candle-core/examples/llama/weights.rs
@@ -1,8 +1,5 @@
 use super::*;
-use candle::{Device, Result, Tensor};
-use memmap2::MmapOptions;
-use safetensors::SafeTensors;
-use std::fs::File;
+use candle::{safetensors::SafeTensors, Device, Result, Tensor};
 use std::path::PathBuf;
 
 pub struct VarBuilder<'a> {
@@ -30,8 +27,9 @@ impl<'a> VarBuilder<'a> {
     pub fn get(&self, tensor_name: &str) -> Result<Tensor> {
         // Unwrap or 0 just to let the proper error flow.
         let index = self.routing.get(tensor_name).unwrap_or(&0);
-        let view = self.safetensors[*index].tensor(tensor_name).unwrap();
-        candle::safetensors::convert(view, &self.device)?.to_dtype(DTYPE)
+        self.safetensors[*index]
+            .tensor(tensor_name, &self.device)?
+            .to_dtype(DTYPE)
     }
 }
 
@@ -107,18 +105,12 @@ impl Llama {
     ) -> Result<Self> {
         let handles: Vec<_> = filenames
             .iter()
-            .map(|f| {
-                let file = File::open(f).unwrap();
-                unsafe { MmapOptions::new().map(&file).unwrap() }
-            })
-            .collect();
+            .map(candle::safetensors::MmapedFile::new)
+            .collect::<Result<Vec<_>>>()?;
         let tensors: Vec<_> = handles
             .iter()
-            .map(|h| {
-                let tensors = SafeTensors::deserialize(h).unwrap();
-                tensors
-            })
-            .collect();
+            .map(|h| h.deserialize())
+            .collect::<Result<Vec<_>>>()?;
 
         let vb = VarBuilder::new(tensors, device.clone());
 
diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs
index 3ed36b64..b586b544 100644
--- a/candle-core/src/safetensors.rs
+++ b/candle-core/src/safetensors.rs
@@ -25,3 +25,46 @@ pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
         dt => todo!("Unhandled dtype {dt:?}"),
     }
 }
+
+// If Rust allowed for self-referential struct, we could store both the Mmap buffer and the
+// SafeTensor bits in the same struct and avoid having the final users calling two methods.
+// We could try using the ouroboros crate or equivalent for this at some point.
+// Wrap the SafeTensors main module so as to provide accessors with the candle types for errors,
+// dtypes, etc
+pub struct SafeTensors<'a>(st::SafeTensors<'a>);
+
+pub struct MmapedFile(memmap2::Mmap);
+
+impl MmapedFile {
+    pub fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
+        let file = std::fs::File::open(p)?;
+        let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? };
+        Ok(Self(mmap))
+    }
+
+    pub fn deserialize(&self) -> Result<SafeTensors<'_>> {
+        let st = safetensors::SafeTensors::deserialize(&self.0)?;
+        Ok(SafeTensors(st))
+    }
+}
+
+impl<'a> SafeTensors<'a> {
+    pub fn tensor(&self, name: &str, device: &Device) -> Result<Tensor> {
+        convert(self.0.tensor(name)?, device)
+    }
+
+    pub fn tensors(&self, device: &Device) -> Result<Vec<(String, Tensor)>> {
+        self.0
+            .tensors()
+            .into_iter()
+            .map(|(name, tensor_view)| {
+                let tensor = convert(tensor_view, device)?;
+                Ok((name, tensor))
+            })
+            .collect()
+    }
+
+    pub fn names(&self) -> Vec<&String> {
+        self.0.names()
+    }
+}