From 122da875806f274a7aa9048f76d7a676b473e56f Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Mon, 16 Oct 2023 17:20:36 +0200 Subject: [PATCH] feat: add pth varbuilder (#1108) --- candle-nn/src/var_builder.rs | 41 ++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index 24832bc7..cbd238dd 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -191,6 +191,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> { } struct Zeros; + impl SimpleBackend for Zeros { fn get(&self, s: Shape, _: &str, _: crate::Init, dtype: DType, dev: &Device) -> Result { Tensor::zeros(s, dtype, dev) @@ -325,6 +326,39 @@ impl SimpleBackend for candle::npy::NpzTensors { } } +impl SimpleBackend for candle::pickle::PthTensors { + fn get( + &self, + s: Shape, + path: &str, + _: crate::Init, + dtype: DType, + dev: &Device, + ) -> Result { + let tensor = match self.get(path)? { + None => Err(Error::CannotFindTensor { + path: path.to_string(), + } + .bt())?, + Some(tensor) => tensor, + }; + let tensor = tensor.to_device(dev)?.to_dtype(dtype)?; + if tensor.shape() != &s { + Err(candle::Error::UnexpectedShape { + msg: format!("shape mismatch for {path}"), + expected: s, + got: tensor.shape().clone(), + } + .bt())? + } + Ok(tensor) + } + + fn contains_tensor(&self, name: &str) -> bool { + self.get(name).map_or(false, |v| v.is_some()) + } +} + impl SimpleBackend for candle::safetensors::MmapedSafetensors { fn get( &self, @@ -438,9 +472,16 @@ impl<'a> VarBuilder<'a> { let npz = candle::npy::NpzTensors::new(p)?; Ok(Self::new(Box::new(npz), dtype, dev.clone())) } + + /// Initializes a `VarBuilder` that retrieves tensors stored in a pytorch pth file. + pub fn from_pth>(p: P, dtype: DType, dev: &Device) -> Result { + let pth = candle::pickle::PthTensors::new(p)?; + Ok(Self::new(Box::new(pth), dtype, dev.clone())) + } } pub struct ShardedSafeTensors(candle::safetensors::MmapedSafetensors); + pub type ShardedVarBuilder<'a> = VarBuilderArgs<'a, ShardedSafeTensors>; impl ShardedSafeTensors {