From dfd624dbd3f1ba85f37e59307f5ca7cbd16fe903 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 19 Jul 2023 16:25:44 +0200
Subject: [PATCH] [Proposal] Remove SafeTensor wrapper (allows finer control for users).

---
 candle-core/src/safetensors.rs           | 50 +++++++-----------------
 candle-examples/examples/whisper/main.rs |  4 +-
 candle-nn/src/var_builder.rs             |  8 +++-
 candle-wasm-example/src/worker.rs        |  8 ++--
 4 files changed, 26 insertions(+), 44 deletions(-)

diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs
index 6ef709ce..3bb069a9 100644
--- a/candle-core/src/safetensors.rs
+++ b/candle-core/src/safetensors.rs
@@ -1,5 +1,6 @@
 use crate::{DType, Device, Error, Result, Tensor, WithDType};
 use safetensors::tensor as st;
+pub use safetensors::tensor::SafeTensors;
 use std::borrow::Cow;
 
 impl From<DType> for st::Dtype {
@@ -62,7 +63,7 @@ impl Tensor {
     }
 }
 
-fn convert_<T: WithDType>(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
+fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
     let v = view.data();
     let size_in_bytes = T::DTYPE.size_in_bytes();
     let elem_count = v.len() / size_in_bytes;
@@ -101,7 +102,17 @@ fn convert_back_<T: WithDType>(mut vs: Vec<T>) -> Vec<u8> {
     unsafe { Vec::from_raw_parts(ptr, length, capacity) }
 }
 
-pub fn convert(view: st::TensorView<'_>, device: &Device) -> Result<Tensor> {
+pub trait Load {
+    fn load(&self, device: &Device) -> Result<Tensor>;
+}
+
+impl<'a> Load for st::TensorView<'a> {
+    fn load(&self, device: &Device) -> Result<Tensor> {
+        convert(self, device)
+    }
+}
+
+pub fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
     match view.dtype() {
         st::Dtype::U8 => convert_::<u8>(view, device),
         st::Dtype::U32 => convert_::<u32>(view, device),
@@ -126,13 +137,6 @@ pub fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
     }
 }
 
-// If Rust allowed for self-referential struct, we could store both the Mmap buffer and the
-// SafeTensor bits in the same struct and avoid having the final users calling two methods.
-// We could try using the ouroboros crate or equivalent for this at some point.
-// Wrap the SafeTensors main module so as to provide accessors with the candle types for errors,
-// dtypes, etc
-pub struct SafeTensors<'a>(st::SafeTensors<'a>);
-
 pub struct MmapedFile(memmap2::Mmap);
 
 impl MmapedFile {
@@ -150,33 +154,7 @@ impl MmapedFile {
 
     pub fn deserialize(&self) -> Result<SafeTensors<'_>> {
         let st = safetensors::SafeTensors::deserialize(&self.0)?;
-        Ok(SafeTensors(st))
-    }
-}
-
-impl<'a> SafeTensors<'a> {
-    pub fn from_buffer(buffer: &'a [u8]) -> Result<SafeTensors<'a>> {
-        let st = safetensors::SafeTensors::deserialize(buffer)?;
-        Ok(SafeTensors(st))
-    }
-
-    pub fn tensor(&self, name: &str, device: &Device) -> Result<Tensor> {
-        convert(self.0.tensor(name)?, device)
-    }
-
-    pub fn tensors(&self, device: &Device) -> Result<Vec<(String, Tensor)>> {
-        self.0
-            .tensors()
-            .into_iter()
-            .map(|(name, tensor_view)| {
-                let tensor = convert(tensor_view, device)?;
-                Ok((name, tensor))
-            })
-            .collect()
-    }
-
-    pub fn names(&self) -> Vec<&String> {
-        self.0.names()
+        Ok(st)
     }
 }
 
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index d01fb605..c8e42c72 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -10,7 +10,7 @@
 extern crate intel_mkl_src;
 
 use anyhow::{Error as E, Result};
-use candle::{DType, Device, Tensor};
+use candle::{safetensors::Load, DType, Device, Tensor};
 use candle_hub::{api::sync::Api, Repo, RepoType};
 use candle_nn::VarBuilder;
 use clap::Parser;
@@ -311,7 +311,7 @@ fn main() -> Result<()> {
 
     let mel_filters = unsafe { candle::safetensors::MmapedFile::new(args.filters)? };
     let mel_filters = mel_filters.deserialize()?;
-    let mel_filters = mel_filters.tensor("mel_80", &device)?;
+    let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?;
     println!("loaded mel filters {:?}", mel_filters.shape());
     let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
 
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index a6cb53e5..87dd2a7f 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -1,4 +1,7 @@
-use candle::{safetensors::SafeTensors, DType, Device, Error, Result, Shape, Tensor};
+use candle::{
+    safetensors::{Load, SafeTensors},
+    DType, Device, Error, Result, Shape, Tensor,
+};
 use std::collections::HashMap;
 use std::sync::Arc;
 
@@ -170,7 +173,8 @@ impl<'a> VarBuilder<'a> {
                     .bt()
                 })?;
                 safetensors[*index]
-                    .tensor(&path, &data.device)?
+                    .tensor(&path)?
+                    .load(&data.device)?
                     .to_dtype(data.dtype)?
             }
         };
diff --git a/candle-wasm-example/src/worker.rs b/candle-wasm-example/src/worker.rs
index 7b9ffbec..5001e7e4 100644
--- a/candle-wasm-example/src/worker.rs
+++ b/candle-wasm-example/src/worker.rs
@@ -1,6 +1,6 @@
 use crate::model::{Config, Whisper};
 use anyhow::Error as E;
-use candle::{DType, Device, Tensor};
+use candle::{safetensors::Load, DType, Device, Tensor};
 use candle_nn::VarBuilder;
 use rand::{distributions::Distribution, rngs::StdRng, SeedableRng};
 use serde::{Deserialize, Serialize};
@@ -236,11 +236,11 @@ impl Decoder {
         let device = Device::Cpu;
         let tokenizer = Tokenizer::from_bytes(&md.tokenizer).map_err(anyhow::Error::msg)?;
 
-        let mel_filters = candle::safetensors::SafeTensors::from_buffer(&md.mel_filters)?;
-        let mel_filters = mel_filters.tensor("mel_80", &device)?;
+        let mel_filters = candle::safetensors::SafeTensors::deserialize(&md.mel_filters)?;
+        let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?;
         console_log!("loaded mel filters {:?}", mel_filters.shape());
         let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
-        let weights = candle::safetensors::SafeTensors::from_buffer(&md.weights)?;
+        let weights = candle::safetensors::SafeTensors::deserialize(&md.weights)?;
         let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
         let config = Config::tiny_en();
         let whisper = Whisper::load(&vb, config)?;
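
Usage note for reviewers: below is a minimal sketch of the call pattern this patch leaves downstream users with, mirroring the whisper example above. The file path "model.safetensors", the tensor name "weight", and the function name load_single_tensor are placeholders, not part of the patch; the sketch also assumes MmapedFile::new accepts a path-like argument (as in the whisper example) and that candle's Error converts from the safetensors error type, as the var_builder change relies on.

    use candle::safetensors::Load; // trait added by this patch
    use candle::{Device, Result};

    fn load_single_tensor() -> Result<()> {
        let device = Device::Cpu;
        // Memory-map the weights file; `new` is unsafe because the mapping is only
        // valid as long as the underlying file is not modified concurrently.
        let file = unsafe { candle::safetensors::MmapedFile::new("model.safetensors")? };
        // `deserialize` now returns the plain safetensors::SafeTensors, so the full
        // upstream API (names, views, metadata) is available without a wrapper.
        let st = file.deserialize()?;
        // Pick a single view and only then convert it into a candle Tensor on the
        // chosen device via the new `Load` trait.
        let tensor = st.tensor("weight")?.load(&device)?;
        println!("{:?}", tensor.shape());
        Ok(())
    }

The finer control comes from the split: users can inspect names or raw views through the safetensors crate directly and decide per tensor whether, and onto which device, to materialize it.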