mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
feat: add pth varbuilder (#1108)
This commit is contained in:
@ -191,6 +191,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct Zeros;
|
struct Zeros;
|
||||||
|
|
||||||
impl SimpleBackend for Zeros {
|
impl SimpleBackend for Zeros {
|
||||||
fn get(&self, s: Shape, _: &str, _: crate::Init, dtype: DType, dev: &Device) -> Result<Tensor> {
|
fn get(&self, s: Shape, _: &str, _: crate::Init, dtype: DType, dev: &Device) -> Result<Tensor> {
|
||||||
Tensor::zeros(s, dtype, dev)
|
Tensor::zeros(s, dtype, dev)
|
||||||
@ -325,6 +326,39 @@ impl SimpleBackend for candle::npy::NpzTensors {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl SimpleBackend for candle::pickle::PthTensors {
|
||||||
|
fn get(
|
||||||
|
&self,
|
||||||
|
s: Shape,
|
||||||
|
path: &str,
|
||||||
|
_: crate::Init,
|
||||||
|
dtype: DType,
|
||||||
|
dev: &Device,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let tensor = match self.get(path)? {
|
||||||
|
None => Err(Error::CannotFindTensor {
|
||||||
|
path: path.to_string(),
|
||||||
|
}
|
||||||
|
.bt())?,
|
||||||
|
Some(tensor) => tensor,
|
||||||
|
};
|
||||||
|
let tensor = tensor.to_device(dev)?.to_dtype(dtype)?;
|
||||||
|
if tensor.shape() != &s {
|
||||||
|
Err(candle::Error::UnexpectedShape {
|
||||||
|
msg: format!("shape mismatch for {path}"),
|
||||||
|
expected: s,
|
||||||
|
got: tensor.shape().clone(),
|
||||||
|
}
|
||||||
|
.bt())?
|
||||||
|
}
|
||||||
|
Ok(tensor)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn contains_tensor(&self, name: &str) -> bool {
|
||||||
|
self.get(name).map_or(false, |v| v.is_some())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl SimpleBackend for candle::safetensors::MmapedSafetensors {
|
impl SimpleBackend for candle::safetensors::MmapedSafetensors {
|
||||||
fn get(
|
fn get(
|
||||||
&self,
|
&self,
|
||||||
@ -438,9 +472,16 @@ impl<'a> VarBuilder<'a> {
|
|||||||
let npz = candle::npy::NpzTensors::new(p)?;
|
let npz = candle::npy::NpzTensors::new(p)?;
|
||||||
Ok(Self::new(Box::new(npz), dtype, dev.clone()))
|
Ok(Self::new(Box::new(npz), dtype, dev.clone()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Initializes a `VarBuilder` that retrieves tensors stored in a pytorch pth file.
|
||||||
|
pub fn from_pth<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
|
||||||
|
let pth = candle::pickle::PthTensors::new(p)?;
|
||||||
|
Ok(Self::new(Box::new(pth), dtype, dev.clone()))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct ShardedSafeTensors(candle::safetensors::MmapedSafetensors);
|
pub struct ShardedSafeTensors(candle::safetensors::MmapedSafetensors);
|
||||||
|
|
||||||
pub type ShardedVarBuilder<'a> = VarBuilderArgs<'a, ShardedSafeTensors>;
|
pub type ShardedVarBuilder<'a> = VarBuilderArgs<'a, ShardedSafeTensors>;
|
||||||
|
|
||||||
impl ShardedSafeTensors {
|
impl ShardedSafeTensors {
|
||||||
|
Reference in New Issue
Block a user