mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Add a convenient way to rename tensors accessed through a varbuilder. (#2052)
This commit is contained in:
@ -498,6 +498,53 @@ impl<'a> VarBuilder<'a> {
|
||||
let pth = candle::pickle::PthTensors::new(p, None)?;
|
||||
Ok(Self::from_backend(Box::new(pth), dtype, dev.clone()))
|
||||
}
|
||||
|
||||
/// Gets a VarBuilder that applies some renaming function on tensor it gets queried for before
|
||||
/// passing the new names to the inner VarBuilder.
|
||||
///
|
||||
/// ```rust
|
||||
/// use candle::{Tensor, DType, Device};
|
||||
///
|
||||
/// let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
|
||||
/// let tensors: std::collections::HashMap<_, _> = [
|
||||
/// ("foo".to_string(), a),
|
||||
/// ]
|
||||
/// .into_iter()
|
||||
/// .collect();
|
||||
/// let vb = candle_nn::VarBuilder::from_tensors(tensors, DType::F32, &Device::Cpu);
|
||||
/// assert!(vb.contains_tensor("foo"));
|
||||
/// assert!(vb.get((2, 3), "foo").is_ok());
|
||||
/// assert!(!vb.contains_tensor("bar"));
|
||||
/// let vb = vb.rename_f(|f: &str| if f == "bar" { "foo".to_string() } else { f.to_string() });
|
||||
/// assert!(vb.contains_tensor("bar"));
|
||||
/// assert!(vb.contains_tensor("foo"));
|
||||
/// assert!(vb.get((2, 3), "bar").is_ok());
|
||||
/// assert!(vb.get((2, 3), "foo").is_ok());
|
||||
/// assert!(!vb.contains_tensor("baz"));
|
||||
/// # Ok::<(), candle::Error>(())
|
||||
/// ```
|
||||
pub fn rename_f<F: Fn(&str) -> String + Sync + Send + 'static>(self, f: F) -> Self {
|
||||
let f: Box<dyn Fn(&str) -> String + Sync + Send + 'static> = Box::new(f);
|
||||
self.rename(f)
|
||||
}
|
||||
|
||||
pub fn rename<R: Renamer + Send + Sync + 'a>(self, renamer: R) -> Self {
|
||||
let dtype = self.dtype();
|
||||
let device = self.device().clone();
|
||||
let path = self.path.clone();
|
||||
let backend = Rename::new(self, renamer);
|
||||
let backend: Box<dyn SimpleBackend + 'a> = Box::new(backend);
|
||||
let data = TensorData {
|
||||
backend,
|
||||
dtype,
|
||||
device,
|
||||
};
|
||||
Self {
|
||||
data: Arc::new(data),
|
||||
path,
|
||||
_phantom: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ShardedSafeTensors(candle::safetensors::MmapedSafetensors);
|
||||
@ -618,3 +665,49 @@ impl Backend for ShardedSafeTensors {
|
||||
self.0.get(name).is_ok()
|
||||
}
|
||||
}
|
||||
|
||||
/// This traits specifies a way to rename the queried names into names that are stored in an inner
|
||||
/// VarBuilder.
|
||||
pub trait Renamer {
|
||||
/// This is applied to the name obtained by a name call and the resulting name is passed to the
|
||||
/// inner VarBuilder.
|
||||
fn rename(&self, v: &str) -> std::borrow::Cow<'_, str>;
|
||||
}
|
||||
|
||||
pub struct Rename<'a, R: Renamer> {
|
||||
inner: VarBuilder<'a>,
|
||||
renamer: R,
|
||||
}
|
||||
|
||||
impl<'a, R: Renamer + Sync + Send> SimpleBackend for Rename<'a, R> {
|
||||
fn get(
|
||||
&self,
|
||||
s: Shape,
|
||||
name: &str,
|
||||
h: crate::Init,
|
||||
dtype: DType,
|
||||
dev: &Device,
|
||||
) -> Result<Tensor> {
|
||||
let name = self.renamer.rename(name);
|
||||
self.inner
|
||||
.get_with_hints_dtype(s, &name, h, dtype)?
|
||||
.to_device(dev)
|
||||
}
|
||||
|
||||
fn contains_tensor(&self, name: &str) -> bool {
|
||||
let name = self.renamer.rename(name);
|
||||
self.inner.contains_tensor(&name)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, R: Renamer> Rename<'a, R> {
|
||||
pub fn new(inner: VarBuilder<'a>, renamer: R) -> Self {
|
||||
Self { inner, renamer }
|
||||
}
|
||||
}
|
||||
|
||||
impl Renamer for Box<dyn Fn(&str) -> String + Sync + Send> {
|
||||
fn rename(&self, v: &str) -> std::borrow::Cow<'_, str> {
|
||||
std::borrow::Cow::Owned(self(v))
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user