mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
[Proposal] Remove SafeTensor wrapper (allows finer control for users).
This commit is contained in:
@ -1,5 +1,6 @@
|
||||
use crate::{DType, Device, Error, Result, Tensor, WithDType};
|
||||
use safetensors::tensor as st;
|
||||
pub use safetensors::tensor::SafeTensors;
|
||||
use std::borrow::Cow;
|
||||
|
||||
impl From<DType> for st::Dtype {
|
||||
@ -62,7 +63,7 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_<T: WithDType>(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
|
||||
fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
|
||||
let v = view.data();
|
||||
let size_in_bytes = T::DTYPE.size_in_bytes();
|
||||
let elem_count = v.len() / size_in_bytes;
|
||||
@ -101,7 +102,17 @@ fn convert_back_<T: WithDType>(mut vs: Vec<T>) -> Vec<u8> {
|
||||
unsafe { Vec::from_raw_parts(ptr, length, capacity) }
|
||||
}
|
||||
|
||||
pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
|
||||
pub trait Load {
|
||||
fn load(&self, device: &Device) -> Result<Tensor>;
|
||||
}
|
||||
|
||||
impl<'a> Load for st::TensorView<'a> {
|
||||
fn load(&self, device: &Device) -> Result<Tensor> {
|
||||
convert(self, device)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
|
||||
match view.dtype() {
|
||||
st::Dtype::U8 => convert_::<u8>(view, device),
|
||||
st::Dtype::U32 => convert_::<u8>(view, device),
|
||||
@ -126,13 +137,6 @@ pub fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
|
||||
}
|
||||
}
|
||||
|
||||
// If Rust allowed for self-referential struct, we could store both the Mmap buffer and the
|
||||
// SafeTensor bits in the same struct and avoid having the final users calling two methods.
|
||||
// We could try using the ouroboros crate or equivalent for this at some point.
|
||||
// Wrap the SafeTensors main module so as to provide accessors with the candle types for errors,
|
||||
// dtypes, etc
|
||||
pub struct SafeTensors<'a>(st::SafeTensors<'a>);
|
||||
|
||||
pub struct MmapedFile(memmap2::Mmap);
|
||||
|
||||
impl MmapedFile {
|
||||
@ -150,33 +154,7 @@ impl MmapedFile {
|
||||
|
||||
pub fn deserialize(&self) -> Result<SafeTensors<'_>> {
|
||||
let st = safetensors::SafeTensors::deserialize(&self.0)?;
|
||||
Ok(SafeTensors(st))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> SafeTensors<'a> {
|
||||
pub fn from_buffer(buffer: &'a [u8]) -> Result<Self> {
|
||||
let st = safetensors::SafeTensors::deserialize(buffer)?;
|
||||
Ok(SafeTensors(st))
|
||||
}
|
||||
|
||||
pub fn tensor(&self, name: &str, device: &Device) -> Result<Tensor> {
|
||||
convert(self.0.tensor(name)?, device)
|
||||
}
|
||||
|
||||
pub fn tensors(&self, device: &Device) -> Result<Vec<(String, Tensor)>> {
|
||||
self.0
|
||||
.tensors()
|
||||
.into_iter()
|
||||
.map(|(name, tensor_view)| {
|
||||
let tensor = convert(tensor_view, device)?;
|
||||
Ok((name, tensor))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn names(&self) -> Vec<&String> {
|
||||
self.0.names()
|
||||
Ok(st)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -10,7 +10,7 @@
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle::{safetensors::Load, DType, Device, Tensor};
|
||||
use candle_hub::{api::sync::Api, Repo, RepoType};
|
||||
use candle_nn::VarBuilder;
|
||||
use clap::Parser;
|
||||
@ -311,7 +311,7 @@ fn main() -> Result<()> {
|
||||
|
||||
let mel_filters = unsafe { candle::safetensors::MmapedFile::new(args.filters)? };
|
||||
let mel_filters = mel_filters.deserialize()?;
|
||||
let mel_filters = mel_filters.tensor("mel_80", &device)?;
|
||||
let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?;
|
||||
println!("loaded mel filters {:?}", mel_filters.shape());
|
||||
let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
|
||||
|
||||
|
@ -1,4 +1,7 @@
|
||||
use candle::{safetensors::SafeTensors, DType, Device, Error, Result, Shape, Tensor};
|
||||
use candle::{
|
||||
safetensors::{Load, SafeTensors},
|
||||
DType, Device, Error, Result, Shape, Tensor,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
@ -170,7 +173,8 @@ impl<'a> VarBuilder<'a> {
|
||||
.bt()
|
||||
})?;
|
||||
safetensors[*index]
|
||||
.tensor(&path, &data.device)?
|
||||
.tensor(&path)?
|
||||
.load(&data.device)?
|
||||
.to_dtype(data.dtype)?
|
||||
}
|
||||
};
|
||||
|
@ -1,6 +1,6 @@
|
||||
use crate::model::{Config, Whisper};
|
||||
use anyhow::Error as E;
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle::{safetensors::Load, DType, Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use rand::{distributions::Distribution, rngs::StdRng, SeedableRng};
|
||||
use serde::{Deserialize, Serialize};
|
||||
@ -236,11 +236,11 @@ impl Decoder {
|
||||
let device = Device::Cpu;
|
||||
let tokenizer = Tokenizer::from_bytes(&md.tokenizer).map_err(anyhow::Error::msg)?;
|
||||
|
||||
let mel_filters = candle::safetensors::SafeTensors::from_buffer(&md.mel_filters)?;
|
||||
let mel_filters = mel_filters.tensor("mel_80", &device)?;
|
||||
let mel_filters = candle::safetensors::SafeTensors::deserialize(&md.mel_filters)?;
|
||||
let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?;
|
||||
console_log!("loaded mel filters {:?}", mel_filters.shape());
|
||||
let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
|
||||
let weights = candle::safetensors::SafeTensors::from_buffer(&md.weights)?;
|
||||
let weights = candle::safetensors::SafeTensors::deserialize(&md.weights)?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
|
||||
let config = Config::tiny_en();
|
||||
let whisper = Whisper::load(&vb, config)?;
|
||||
|
Reference in New Issue
Block a user