Self-contained safetensor wrappers (#946)

* Self-contained safetensor wrappers.

* Use the new safetensor container in varbuilders.
This commit is contained in:
Laurent Mazare
2023-09-23 20:39:52 +01:00
committed by GitHub
parent 5dbe46b389
commit 890d069092
3 changed files with 61 additions and 30 deletions

View File

@ -325,6 +325,32 @@ impl SimpleBackend for candle::npy::NpzTensors {
}
}
impl SimpleBackend for candle::safetensors::MmapedSafetensors {
fn get(
&self,
s: Shape,
name: &str,
_: crate::Init,
dtype: DType,
dev: &Device,
) -> Result<Tensor> {
let tensor = self.load(name, dev)?.to_dtype(dtype)?;
if tensor.shape() != &s {
Err(candle::Error::UnexpectedShape {
msg: format!("shape mismatch for {name}"),
expected: s,
got: tensor.shape().clone(),
}
.bt())?
}
Ok(tensor)
}
fn contains_tensor(&self, name: &str) -> bool {
self.get(name).is_ok()
}
}
impl<'a> VarBuilder<'a> {
fn new(backend: Box<dyn SimpleBackend + 'a>, dtype: DType, device: Device) -> Self {
let data = TensorData {
@ -361,7 +387,7 @@ impl<'a> VarBuilder<'a> {
}
/// Initializes a `VarBuilder` that retrieves tensors stored in a collection of safetensors
/// files.
/// data.
pub fn from_safetensors(safetensors: Vec<SafeTensors<'a>>, dtype: DType, dev: &Device) -> Self {
let mut routing = HashMap::new();
for (index, sf) in safetensors.iter().enumerate() {
@ -376,6 +402,21 @@ impl<'a> VarBuilder<'a> {
Self::new(Box::new(tensors), dtype, dev.clone())
}
/// Creates a `VarBuilder` whose tensors are read from the safetensors files
/// located at `paths`.
///
/// The files are opened through `candle::safetensors::MmapedSafetensors::multi`
/// and the resulting container is used as the backend of the builder, together
/// with the requested `dtype` and a clone of `dev`.
///
/// # Errors
///
/// Returns the error reported by `MmapedSafetensors::multi` if the files cannot
/// be opened.
///
/// # Safety
///
/// The unsafe is inherited from [`memmap2::MmapOptions`].
pub unsafe fn from_mmaped_safetensors<P: AsRef<std::path::Path>>(
    paths: &[P],
    dtype: DType,
    dev: &Device,
) -> Result<Self> {
    let mmaped = candle::safetensors::MmapedSafetensors::multi(paths)?;
    let backend = Box::new(mmaped);
    Ok(Self::new(backend, dtype, dev.clone()))
}
/// Initializes a `VarBuilder` that retrieves tensors stored in a numpy npz file.
pub fn from_npz<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
let npz = candle::npy::NpzTensors::new(p)?;