mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Move more safetensors bits to the shared module.
This commit is contained in:
@ -11,25 +11,25 @@ license = "MIT/Apache-2.0"
|
||||
readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
safetensors = "0.3.1"
|
||||
thiserror = "1"
|
||||
cudarc = { version = "0.9.9", optional = true, features = ["f16"] }
|
||||
candle-kernels = { path = "../candle-kernels", optional = true }
|
||||
gemm = "0.15.4"
|
||||
zip = { version = "0.6.6", default-features=false }
|
||||
byteorder = "1.4.3"
|
||||
candle-kernels = { path = "../candle-kernels", optional = true }
|
||||
cudarc = { version = "0.9.9", optional = true, features = ["f16"] }
|
||||
gemm = "0.15.4"
|
||||
half = { version = "2.3.1", features = ["num-traits"] }
|
||||
memmap2 = "0.7.1"
|
||||
num-traits = "0.2.15"
|
||||
num_cpus = "1.15.0"
|
||||
safetensors = "0.3.1"
|
||||
thiserror = "1"
|
||||
zip = { version = "0.6.6", default-features=false }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { version = "1", features = ["backtrace"] }
|
||||
candle-hub = { path = "../candle-hub" }
|
||||
clap = { version = "4.2.4", features = ["derive"] }
|
||||
rand = "0.8.5"
|
||||
tokenizers = { version = "0.13.3", default-features=false, features=["onig"] }
|
||||
tokio = { version = "1.28.2", features = ["macros", "rt-multi-thread"] }
|
||||
candle-hub = { path = "../candle-hub" }
|
||||
memmap2 = "0.7.1"
|
||||
|
||||
[features]
|
||||
default = ["cuda"]
|
||||
|
@ -1,8 +1,5 @@
|
||||
use super::*;
|
||||
use candle::{Device, Result, Tensor};
|
||||
use memmap2::MmapOptions;
|
||||
use safetensors::SafeTensors;
|
||||
use std::fs::File;
|
||||
use candle::{safetensors::SafeTensors, Device, Result, Tensor};
|
||||
use std::path::PathBuf;
|
||||
|
||||
pub struct VarBuilder<'a> {
|
||||
@ -30,8 +27,9 @@ impl<'a> VarBuilder<'a> {
|
||||
pub fn get(&self, tensor_name: &str) -> Result<Tensor> {
|
||||
// Unwrap or 0 just to let the proper error flow.
|
||||
let index = self.routing.get(tensor_name).unwrap_or(&0);
|
||||
let view = self.safetensors[*index].tensor(tensor_name).unwrap();
|
||||
candle::safetensors::convert(view, &self.device)?.to_dtype(DTYPE)
|
||||
self.safetensors[*index]
|
||||
.tensor(tensor_name, &self.device)?
|
||||
.to_dtype(DTYPE)
|
||||
}
|
||||
}
|
||||
|
||||
@ -107,18 +105,12 @@ impl Llama {
|
||||
) -> Result<Self> {
|
||||
let handles: Vec<_> = filenames
|
||||
.iter()
|
||||
.map(|f| {
|
||||
let file = File::open(f).unwrap();
|
||||
unsafe { MmapOptions::new().map(&file).unwrap() }
|
||||
})
|
||||
.collect();
|
||||
.map(candle::safetensors::MmapedFile::new)
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let tensors: Vec<_> = handles
|
||||
.iter()
|
||||
.map(|h| {
|
||||
let tensors = SafeTensors::deserialize(h).unwrap();
|
||||
tensors
|
||||
})
|
||||
.collect();
|
||||
.map(|h| h.deserialize())
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
let vb = VarBuilder::new(tensors, device.clone());
|
||||
|
||||
|
@ -25,3 +25,46 @@ pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
|
||||
dt => todo!("Unhandled dtype {dt:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
// If Rust allowed for self-referential struct, we could store both the Mmap buffer and the
|
||||
// SafeTensor bits in the same struct and avoid having the final users calling two methods.
|
||||
// We could try using the ouroboros crate or equivalent for this at some point.
|
||||
// Wrap the SafeTensors main module so as to provide accessors with the candle types for errors,
|
||||
// dtypes, etc
|
||||
pub struct SafeTensors<'a>(st::SafeTensors<'a>);
|
||||
|
||||
pub struct MmapedFile(memmap2::Mmap);
|
||||
|
||||
impl MmapedFile {
|
||||
pub fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
|
||||
let file = std::fs::File::open(p)?;
|
||||
let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? };
|
||||
Ok(Self(mmap))
|
||||
}
|
||||
|
||||
pub fn deserialize(&self) -> Result<SafeTensors<'_>> {
|
||||
let st = safetensors::SafeTensors::deserialize(&self.0)?;
|
||||
Ok(SafeTensors(st))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> SafeTensors<'a> {
|
||||
pub fn tensor(&self, name: &str, device: &Device) -> Result<Tensor> {
|
||||
convert(self.0.tensor(name)?, device)
|
||||
}
|
||||
|
||||
pub fn tensors(&self, device: &Device) -> Result<Vec<(String, Tensor)>> {
|
||||
self.0
|
||||
.tensors()
|
||||
.into_iter()
|
||||
.map(|(name, tensor_view)| {
|
||||
let tensor = convert(tensor_view, device)?;
|
||||
Ok((name, tensor))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn names(&self) -> Vec<&String> {
|
||||
self.0.names()
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user