Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 18:48:51 +00:00)
[Proposal] Remove SafeTensor wrapper (allows finer control for users).
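In short: the candle-side SafeTensors wrapper is removed, the safetensors::tensor::SafeTensors type is re-exported directly, and a small Load trait is added so a raw st::TensorView can be materialized into a candle Tensor on a chosen device. A minimal sketch of the resulting call-site change, assuming a function returning candle's Result (the helper name and buffer handling are illustrative, not part of the diff):

    use candle::{safetensors::Load, Device, Result, Tensor};

    // Hypothetical helper: load a single named tensor from a safetensors buffer.
    fn load_one(buffer: &[u8], name: &str, device: &Device) -> Result<Tensor> {
        // `SafeTensors` is now the re-exported safetensors type, not a candle wrapper.
        let st = candle::safetensors::SafeTensors::deserialize(buffer)?;
        // Old API (removed): st.tensor(name, device)?
        // New API: fetch the raw TensorView, then convert it via the Load trait.
        st.tensor(name)?.load(device)
    }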
@@ -1,5 +1,6 @@
 use crate::{DType, Device, Error, Result, Tensor, WithDType};
 use safetensors::tensor as st;
+pub use safetensors::tensor::SafeTensors;
 use std::borrow::Cow;
 
 impl From<DType> for st::Dtype {
@@ -62,7 +63,7 @@ impl Tensor {
     }
 }
 
-fn convert_<T: WithDType>(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
+fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
     let v = view.data();
     let size_in_bytes = T::DTYPE.size_in_bytes();
     let elem_count = v.len() / size_in_bytes;
@@ -101,7 +102,17 @@ fn convert_back_<T: WithDType>(mut vs: Vec<T>) -> Vec<u8> {
     unsafe { Vec::from_raw_parts(ptr, length, capacity) }
 }
 
-pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
+pub trait Load {
+    fn load(&self, device: &Device) -> Result<Tensor>;
+}
+
+impl<'a> Load for st::TensorView<'a> {
+    fn load(&self, device: &Device) -> Result<Tensor> {
+        convert(self, device)
+    }
+}
+
+pub fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
     match view.dtype() {
         st::Dtype::U8 => convert_::<u8>(view, device),
         st::Dtype::U32 => convert_::<u8>(view, device),
@@ -126,13 +137,6 @@ pub fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
     }
 }
 
-// If Rust allowed for self-referential struct, we could store both the Mmap buffer and the
-// SafeTensor bits in the same struct and avoid having the final users calling two methods.
-// We could try using the ouroboros crate or equivalent for this at some point.
-// Wrap the SafeTensors main module so as to provide accessors with the candle types for errors,
-// dtypes, etc
-pub struct SafeTensors<'a>(st::SafeTensors<'a>);
-
 pub struct MmapedFile(memmap2::Mmap);
 
 impl MmapedFile {
@@ -150,33 +154,7 @@ impl MmapedFile {
 
     pub fn deserialize(&self) -> Result<SafeTensors<'_>> {
         let st = safetensors::SafeTensors::deserialize(&self.0)?;
-        Ok(SafeTensors(st))
-    }
-}
-
-impl<'a> SafeTensors<'a> {
-    pub fn from_buffer(buffer: &'a [u8]) -> Result<Self> {
-        let st = safetensors::SafeTensors::deserialize(buffer)?;
-        Ok(SafeTensors(st))
-    }
-
-    pub fn tensor(&self, name: &str, device: &Device) -> Result<Tensor> {
-        convert(self.0.tensor(name)?, device)
-    }
-
-    pub fn tensors(&self, device: &Device) -> Result<Vec<(String, Tensor)>> {
-        self.0
-            .tensors()
-            .into_iter()
-            .map(|(name, tensor_view)| {
-                let tensor = convert(tensor_view, device)?;
-                Ok((name, tensor))
-            })
-            .collect()
-    }
-
-    pub fn names(&self) -> Vec<&String> {
-        self.0.names()
-    }
+        Ok(st)
     }
 }
 
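With the wrapper gone, callers work with the safetensors types directly, which is what enables the finer control mentioned in the title: names and views can be inspected first, and only the tensors that are actually needed have to be materialized on a device. A hedged sketch of that kind of selective loading (the helper name and prefix filter are made up for illustration, not part of the diff):

    use candle::{safetensors::Load, Device, Result, Tensor};

    // Hypothetical: materialize only the tensors whose names start with a prefix.
    fn load_with_prefix(buffer: &[u8], prefix: &str, device: &Device) -> Result<Vec<(String, Tensor)>> {
        let st = candle::safetensors::SafeTensors::deserialize(buffer)?;
        st.tensors()
            .into_iter()
            .filter(|(name, _)| name.starts_with(prefix))
            // Conversion to a candle Tensor only happens for the kept views.
            .map(|(name, view)| Ok((name, view.load(device)?)))
            .collect()
    }

The remaining hunks update the call sites (the whisper example, VarBuilder, and the wasm worker) to the two-step lookup: .tensor(name)? followed by .load(&device)?.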

@@ -10,7 +10,7 @@
 extern crate intel_mkl_src;
 
 use anyhow::{Error as E, Result};
-use candle::{DType, Device, Tensor};
+use candle::{safetensors::Load, DType, Device, Tensor};
 use candle_hub::{api::sync::Api, Repo, RepoType};
 use candle_nn::VarBuilder;
 use clap::Parser;

@@ -311,7 +311,7 @@ fn main() -> Result<()> {
 
     let mel_filters = unsafe { candle::safetensors::MmapedFile::new(args.filters)? };
     let mel_filters = mel_filters.deserialize()?;
-    let mel_filters = mel_filters.tensor("mel_80", &device)?;
+    let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?;
     println!("loaded mel filters {:?}", mel_filters.shape());
     let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
 

@@ -1,4 +1,7 @@
-use candle::{safetensors::SafeTensors, DType, Device, Error, Result, Shape, Tensor};
+use candle::{
+    safetensors::{Load, SafeTensors},
+    DType, Device, Error, Result, Shape, Tensor,
+};
 use std::collections::HashMap;
 use std::sync::Arc;
 

@@ -170,7 +173,8 @@ impl<'a> VarBuilder<'a> {
                     .bt()
                 })?;
                 safetensors[*index]
-                    .tensor(&path, &data.device)?
+                    .tensor(&path)?
+                    .load(&data.device)?
                     .to_dtype(data.dtype)?
             }
         };

@@ -1,6 +1,6 @@
 use crate::model::{Config, Whisper};
 use anyhow::Error as E;
-use candle::{DType, Device, Tensor};
+use candle::{safetensors::Load, DType, Device, Tensor};
 use candle_nn::VarBuilder;
 use rand::{distributions::Distribution, rngs::StdRng, SeedableRng};
 use serde::{Deserialize, Serialize};

@@ -236,11 +236,11 @@ impl Decoder {
         let device = Device::Cpu;
         let tokenizer = Tokenizer::from_bytes(&md.tokenizer).map_err(anyhow::Error::msg)?;
 
-        let mel_filters = candle::safetensors::SafeTensors::from_buffer(&md.mel_filters)?;
-        let mel_filters = mel_filters.tensor("mel_80", &device)?;
+        let mel_filters = candle::safetensors::SafeTensors::deserialize(&md.mel_filters)?;
+        let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?;
         console_log!("loaded mel filters {:?}", mel_filters.shape());
         let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
-        let weights = candle::safetensors::SafeTensors::from_buffer(&md.weights)?;
+        let weights = candle::safetensors::SafeTensors::deserialize(&md.weights)?;
         let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
         let config = Config::tiny_en();
         let whisper = Whisper::load(&vb, config)?;