[Proposal] Remove SafeTensor wrapper (allows finer control for users).

This commit is contained in:
Nicolas Patry
2023-07-19 16:25:44 +02:00
parent 67e20c3792
commit dfd624dbd3
4 changed files with 26 additions and 44 deletions

View File

@ -1,5 +1,6 @@
use crate::{DType, Device, Error, Result, Tensor, WithDType};
use safetensors::tensor as st;
pub use safetensors::tensor::SafeTensors;
use std::borrow::Cow;
impl From<DType> for st::Dtype {
@ -62,7 +63,7 @@ impl Tensor {
}
}
fn convert_<T: WithDType>(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
let v = view.data();
let size_in_bytes = T::DTYPE.size_in_bytes();
let elem_count = v.len() / size_in_bytes;
@ -101,7 +102,17 @@ fn convert_back_<T: WithDType>(mut vs: Vec<T>) -> Vec<u8> {
unsafe { Vec::from_raw_parts(ptr, length, capacity) }
}
pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
pub trait Load {
fn load(&self, device: &Device) -> Result<Tensor>;
}
impl<'a> Load for st::TensorView<'a> {
fn load(&self, device: &Device) -> Result<Tensor> {
convert(self, device)
}
}
pub fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
match view.dtype() {
st::Dtype::U8 => convert_::<u8>(view, device),
st::Dtype::U32 => convert_::<u8>(view, device),
@ -126,13 +137,6 @@ pub fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
}
}
// If Rust allowed for self-referential struct, we could store both the Mmap buffer and the
// SafeTensor bits in the same struct and avoid having the final users calling two methods.
// We could try using the ouroboros crate or equivalent for this at some point.
// Wrap the SafeTensors main module so as to provide accessors with the candle types for errors,
// dtypes, etc
pub struct SafeTensors<'a>(st::SafeTensors<'a>);
pub struct MmapedFile(memmap2::Mmap);
impl MmapedFile {
@ -150,33 +154,7 @@ impl MmapedFile {
pub fn deserialize(&self) -> Result<SafeTensors<'_>> {
let st = safetensors::SafeTensors::deserialize(&self.0)?;
Ok(SafeTensors(st))
}
}
impl<'a> SafeTensors<'a> {
pub fn from_buffer(buffer: &'a [u8]) -> Result<Self> {
let st = safetensors::SafeTensors::deserialize(buffer)?;
Ok(SafeTensors(st))
}
pub fn tensor(&self, name: &str, device: &Device) -> Result<Tensor> {
convert(self.0.tensor(name)?, device)
}
pub fn tensors(&self, device: &Device) -> Result<Vec<(String, Tensor)>> {
self.0
.tensors()
.into_iter()
.map(|(name, tensor_view)| {
let tensor = convert(tensor_view, device)?;
Ok((name, tensor))
})
.collect()
}
pub fn names(&self) -> Vec<&String> {
self.0.names()
Ok(st)
}
}

View File

@ -10,7 +10,7 @@
extern crate intel_mkl_src;
use anyhow::{Error as E, Result};
use candle::{DType, Device, Tensor};
use candle::{safetensors::Load, DType, Device, Tensor};
use candle_hub::{api::sync::Api, Repo, RepoType};
use candle_nn::VarBuilder;
use clap::Parser;
@ -311,7 +311,7 @@ fn main() -> Result<()> {
let mel_filters = unsafe { candle::safetensors::MmapedFile::new(args.filters)? };
let mel_filters = mel_filters.deserialize()?;
let mel_filters = mel_filters.tensor("mel_80", &device)?;
let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?;
println!("loaded mel filters {:?}", mel_filters.shape());
let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;

View File

@ -1,4 +1,7 @@
use candle::{safetensors::SafeTensors, DType, Device, Error, Result, Shape, Tensor};
use candle::{
safetensors::{Load, SafeTensors},
DType, Device, Error, Result, Shape, Tensor,
};
use std::collections::HashMap;
use std::sync::Arc;
@ -170,7 +173,8 @@ impl<'a> VarBuilder<'a> {
.bt()
})?;
safetensors[*index]
.tensor(&path, &data.device)?
.tensor(&path)?
.load(&data.device)?
.to_dtype(data.dtype)?
}
};

View File

@ -1,6 +1,6 @@
use crate::model::{Config, Whisper};
use anyhow::Error as E;
use candle::{DType, Device, Tensor};
use candle::{safetensors::Load, DType, Device, Tensor};
use candle_nn::VarBuilder;
use rand::{distributions::Distribution, rngs::StdRng, SeedableRng};
use serde::{Deserialize, Serialize};
@ -236,11 +236,11 @@ impl Decoder {
let device = Device::Cpu;
let tokenizer = Tokenizer::from_bytes(&md.tokenizer).map_err(anyhow::Error::msg)?;
let mel_filters = candle::safetensors::SafeTensors::from_buffer(&md.mel_filters)?;
let mel_filters = mel_filters.tensor("mel_80", &device)?;
let mel_filters = candle::safetensors::SafeTensors::deserialize(&md.mel_filters)?;
let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?;
console_log!("loaded mel filters {:?}", mel_filters.shape());
let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
let weights = candle::safetensors::SafeTensors::from_buffer(&md.weights)?;
let weights = candle::safetensors::SafeTensors::deserialize(&md.weights)?;
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
let config = Config::tiny_en();
let whisper = Whisper::load(&vb, config)?;