mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Self-contained safetensor wrappers (#946)
* Self-contained safetensor wrappers. * Use the new safetensor container in varbuilders.
This commit is contained in:
@ -321,6 +321,18 @@ impl MmapedSafetensors {
|
||||
}
|
||||
|
||||
pub fn load(&self, name: &str, dev: &Device) -> Result<Tensor> {
|
||||
self.get(name)?.load(dev)
|
||||
}
|
||||
|
||||
pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
|
||||
let mut tensors = vec![];
|
||||
for safetensors in self.safetensors.iter() {
|
||||
tensors.push(safetensors.get().0.tensors())
|
||||
}
|
||||
tensors.into_iter().flatten().collect()
|
||||
}
|
||||
|
||||
pub fn get(&self, name: &str) -> Result<st::TensorView<'_>> {
|
||||
let index = match &self.routing {
|
||||
None => 0,
|
||||
Some(routing) => {
|
||||
@ -333,15 +345,7 @@ impl MmapedSafetensors {
|
||||
*index
|
||||
}
|
||||
};
|
||||
self.safetensors[index].get().0.tensor(name)?.load(dev)
|
||||
}
|
||||
|
||||
pub fn tensors(&self) -> Vec<(String, st::TensorView<'_>)> {
|
||||
let mut tensors = vec![];
|
||||
for safetensors in self.safetensors.iter() {
|
||||
tensors.push(safetensors.get().0.tensors())
|
||||
}
|
||||
tensors.into_iter().flatten().collect()
|
||||
Ok(self.safetensors[index].get().0.tensor(name)?)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -122,30 +122,16 @@ impl T5ModelBuilder {
|
||||
}
|
||||
|
||||
pub fn build_encoder(&self) -> Result<t5::T5EncoderModel> {
|
||||
let weights = self
|
||||
.weights_filename
|
||||
.iter()
|
||||
.map(|f| unsafe { candle::safetensors::MmapedFile::new(f) })
|
||||
.collect::<candle::Result<Vec<_>>>()?;
|
||||
let weights = weights
|
||||
.iter()
|
||||
.map(|w| w.deserialize())
|
||||
.collect::<candle::Result<Vec<_>>>()?;
|
||||
let vb = VarBuilder::from_safetensors(weights, DTYPE, &self.device);
|
||||
let vb = unsafe {
|
||||
VarBuilder::from_mmaped_safetensors(&self.weights_filename, DTYPE, &self.device)?
|
||||
};
|
||||
Ok(t5::T5EncoderModel::load(vb, &self.config)?)
|
||||
}
|
||||
|
||||
pub fn build_conditional_generation(&self) -> Result<t5::T5ForConditionalGeneration> {
|
||||
let weights = self
|
||||
.weights_filename
|
||||
.iter()
|
||||
.map(|f| unsafe { candle::safetensors::MmapedFile::new(f) })
|
||||
.collect::<candle::Result<Vec<_>>>()?;
|
||||
let weights = weights
|
||||
.iter()
|
||||
.map(|w| w.deserialize())
|
||||
.collect::<candle::Result<Vec<_>>>()?;
|
||||
let vb = VarBuilder::from_safetensors(weights, DTYPE, &self.device);
|
||||
let vb = unsafe {
|
||||
VarBuilder::from_mmaped_safetensors(&self.weights_filename, DTYPE, &self.device)?
|
||||
};
|
||||
Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
|
||||
}
|
||||
}
|
||||
|
@ -325,6 +325,32 @@ impl SimpleBackend for candle::npy::NpzTensors {
|
||||
}
|
||||
}
|
||||
|
||||
impl SimpleBackend for candle::safetensors::MmapedSafetensors {
|
||||
fn get(
|
||||
&self,
|
||||
s: Shape,
|
||||
name: &str,
|
||||
_: crate::Init,
|
||||
dtype: DType,
|
||||
dev: &Device,
|
||||
) -> Result<Tensor> {
|
||||
let tensor = self.load(name, dev)?.to_dtype(dtype)?;
|
||||
if tensor.shape() != &s {
|
||||
Err(candle::Error::UnexpectedShape {
|
||||
msg: format!("shape mismatch for {name}"),
|
||||
expected: s,
|
||||
got: tensor.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
Ok(tensor)
|
||||
}
|
||||
|
||||
fn contains_tensor(&self, name: &str) -> bool {
|
||||
self.get(name).is_ok()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> VarBuilder<'a> {
|
||||
fn new(backend: Box<dyn SimpleBackend + 'a>, dtype: DType, device: Device) -> Self {
|
||||
let data = TensorData {
|
||||
@ -361,7 +387,7 @@ impl<'a> VarBuilder<'a> {
|
||||
}
|
||||
|
||||
/// Initializes a `VarBuilder` that retrieves tensors stored in a collection of safetensors
|
||||
/// files.
|
||||
/// data.
|
||||
pub fn from_safetensors(safetensors: Vec<SafeTensors<'a>>, dtype: DType, dev: &Device) -> Self {
|
||||
let mut routing = HashMap::new();
|
||||
for (index, sf) in safetensors.iter().enumerate() {
|
||||
@ -376,6 +402,21 @@ impl<'a> VarBuilder<'a> {
|
||||
Self::new(Box::new(tensors), dtype, dev.clone())
|
||||
}
|
||||
|
||||
/// Initializes a `VarBuilder` that retrieves tensors stored in a collection of safetensors
|
||||
/// files.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// The unsafe is inherited from [`memmap2::MmapOptions`].
|
||||
pub unsafe fn from_mmaped_safetensors<P: AsRef<std::path::Path>>(
|
||||
paths: &[P],
|
||||
dtype: DType,
|
||||
dev: &Device,
|
||||
) -> Result<Self> {
|
||||
let tensors = candle::safetensors::MmapedSafetensors::multi(paths)?;
|
||||
Ok(Self::new(Box::new(tensors), dtype, dev.clone()))
|
||||
}
|
||||
|
||||
/// Initializes a `VarBuilder` that retrieves tensors stored in a numpy npz file.
|
||||
pub fn from_npz<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
|
||||
let npz = candle::npy::NpzTensors::new(p)?;
|
||||
|
Reference in New Issue
Block a user