diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs
index e81fe184..dee57b37 100644
--- a/candle-core/src/safetensors.rs
+++ b/candle-core/src/safetensors.rs
@@ -1,7 +1,6 @@
 use crate::{DType, Device, Error, Result, Tensor, WithDType};
-use safetensors::slice::SliceIterator;
 use safetensors::tensor as st;
-use safetensors::tensor::{Dtype, SafeTensors};
+use safetensors::tensor::SafeTensors;
 use std::borrow::Cow;
 
 impl From<DType> for st::Dtype {
@@ -118,26 +117,24 @@ impl<'a> Load for st::TensorView<'a> {
 }
 
 impl Tensor {
-    pub fn from_safetensors_slice(
-        iterator: SliceIterator,
-        dtype: Dtype,
+    pub fn from_raw_buffer(
+        data: &[u8],
+        dtype: DType,
         shape: &[usize],
         device: &Device,
     ) -> Result<Self> {
-        let data: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
         match dtype {
-            st::Dtype::U8 => convert_slice::<u8>(&data, shape, device),
-            st::Dtype::U32 => convert_slice::<u32>(&data, shape, device),
-            st::Dtype::BF16 => convert_slice::<half::bf16>(&data, shape, device),
-            st::Dtype::F16 => convert_slice::<half::f16>(&data, shape, device),
-            st::Dtype::F32 => convert_slice::<f32>(&data, shape, device),
-            st::Dtype::F64 => convert_slice::<f64>(&data, shape, device),
-            dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
+            DType::U8 => convert_slice::<u8>(data, shape, device),
+            DType::U32 => convert_slice::<u32>(data, shape, device),
+            DType::BF16 => convert_slice::<half::bf16>(data, shape, device),
+            DType::F16 => convert_slice::<half::f16>(data, shape, device),
+            DType::F32 => convert_slice::<f32>(data, shape, device),
+            DType::F64 => convert_slice::<f64>(data, shape, device),
         }
     }
 }
 
-pub fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
+fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
     match view.dtype() {
         st::Dtype::U8 => convert_::<u8>(view, device),
         st::Dtype::U32 => convert_::<u32>(view, device),
@@ -149,7 +146,7 @@ pub fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
     }
 }
 
-pub fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
+fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
     // TODO: This makes an unnecessary copy when the tensor is on the cpu.
     let tensor = tensor.flatten_all()?;
     match tensor.dtype() {
diff --git a/candle-examples/examples/llama_multiprocess/model.rs b/candle-examples/examples/llama_multiprocess/model.rs
index e902734f..becaa879 100644
--- a/candle-examples/examples/llama_multiprocess/model.rs
+++ b/candle-examples/examples/llama_multiprocess/model.rs
@@ -4,7 +4,6 @@ use candle_nn::{Embedding, Linear, VarBuilder};
 use cudarc::nccl::safe::{Comm, ReduceOp};
 use half::f16;
 use std::collections::HashMap;
-use std::rc::Rc;
 use std::sync::{Arc, Mutex};
 
 use super::MAX_SEQ_LEN;
@@ -24,11 +23,11 @@ impl TensorParallelColumnLinear {
 
 struct TensorParallelRowLinear {
     linear: Linear,
-    comm: Rc<Comm>,
+    comm: Arc<Comm>,
 }
 
 struct AllReduce {
-    comm: Rc<Comm>,
+    comm: Arc<Comm>,
 }
 
 impl CustomOp1 for AllReduce {
@@ -61,12 +60,12 @@ impl CustomOp1 for AllReduce {
     }
 }
 
-fn all_reduce_sum(x: &Tensor, comm: &Rc<Comm>) -> Result<Tensor> {
+fn all_reduce_sum(x: &Tensor, comm: &Arc<Comm>) -> Result<Tensor> {
     x.custom_op1(AllReduce { comm: comm.clone() })
 }
 
 impl TensorParallelRowLinear {
-    fn new(linear: Linear, comm: Rc<Comm>) -> Self {
+    fn new(linear: Linear, comm: Arc<Comm>) -> Self {
         Self { linear, comm }
     }
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
@@ -76,14 +75,14 @@ impl TensorParallelRowLinear {
 }
 
 impl TensorParallelColumnLinear {
-    fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, comm: Arc<Comm>) -> Result<Self> {
         let rank = comm.rank();
         let size = comm.world_size();
         let weight = vb.get_sharded("weight", 0, rank, size)?;
         Ok(Self::new(Linear::new(weight, None)))
     }
 
-    fn load_multi(vb: VarBuilder, prefixes: &[&str], comm: Rc<Comm>) -> Result<Self> {
+    fn load_multi(vb: VarBuilder, prefixes: &[&str], comm: Arc<Comm>) -> Result<Self> {
         let rank = comm.rank();
         let size = comm.world_size();
         let weights: Vec<_> = prefixes
@@ -96,7 +95,7 @@ impl TensorParallelColumnLinear {
 }
 
 impl TensorParallelRowLinear {
-    fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, comm: Arc<Comm>) -> Result<Self> {
         let rank = comm.rank();
         let size = comm.world_size();
         let weight = vb.get_sharded("weight", 1, rank, size)?;
@@ -339,7 +338,7 @@ impl CausalSelfAttention {
         }
     }
 
-    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Arc<Comm>) -> Result<Self> {
         let qkv_proj = TensorParallelColumnLinear::load_multi(
             vb.clone(),
             &["q_proj", "k_proj", "v_proj"],
@@ -388,7 +387,7 @@ impl Mlp {
         self.c_proj.forward(&x)
     }
 
-    fn load(vb: VarBuilder, _cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, _cfg: &Config, comm: Arc<Comm>) -> Result<Self> {
         let c_fc1 = TensorParallelColumnLinear::load(vb.pp("gate_proj"), comm.clone())?;
         let c_fc2 = TensorParallelColumnLinear::load(vb.pp("up_proj"), comm.clone())?;
         let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm.clone())?;
@@ -422,7 +421,7 @@ impl Block {
         Ok(x)
     }
 
-    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Arc<Comm>) -> Result<Self> {
         let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg, comm.clone())?;
         let mlp = Mlp::load(vb.pp("mlp"), cfg, comm.clone())?;
         let input_layernorm = RmsNorm::load(cfg.hidden_size, vb.pp("input_layernorm"))?;
@@ -466,7 +465,7 @@ impl Llama {
         logits.to_dtype(DType::F32)
     }
 
-    pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
+    pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Arc<Comm>) -> Result<Self> {
         let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
         let norm = RmsNorm::load(cfg.hidden_size, vb.pp("model.norm"))?;
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index b02d216b..1466f6d0 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -1,6 +1,5 @@
 use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor};
-use safetensors::slice::IndexOp;
-use safetensors::tensor::SafeTensors;
+use safetensors::{slice::IndexOp, tensor::SafeTensors};
 use std::collections::HashMap;
 use std::sync::Arc;
 
@@ -70,7 +69,7 @@ impl<'a> TensorData<'a> {
 #[derive(Clone)]
 pub struct VarBuilder<'a> {
     data: Arc<TensorData<'a>>,
-    pub path: Vec<String>,
+    path: Vec<String>,
 }
 
 impl<'a> VarBuilder<'a> {
@@ -179,7 +178,10 @@ impl<'a> VarBuilder<'a> {
 
                 shape[dim] = block_size;
 
-                Tensor::from_safetensors_slice(iterator, dtype, &shape, &data.device)?
+                let dtype: DType = dtype.try_into()?;
+
+                let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
+                Tensor::from_raw_buffer(&raw, dtype, &shape, &data.device)?
             }
             _ => unimplemented!(),
         };
diff --git a/candle-wasm-examples/whisper/Cargo.toml b/candle-wasm-examples/whisper/Cargo.toml
index 4ebb2788..b51d4052 100644
--- a/candle-wasm-examples/whisper/Cargo.toml
+++ b/candle-wasm-examples/whisper/Cargo.toml
@@ -15,7 +15,6 @@ candle = { path = "../../candle-core" }
 candle-nn = { path = "../../candle-nn" }
 num-traits = { workspace = true }
 tokenizers = { workspace = true, features = ["unstable_wasm"] }
-safetensors = { workspace = true }
 
 # App crates.
 anyhow = { workspace = true }
@@ -24,6 +23,7 @@ rand = { workspace = true }
 serde = { workspace = true }
 serde_json = { workspace = true }
 wav = { workspace = true }
+safetensors = { workspace = true }
 
 # Wasm specific crates.
 getrandom = { version = "0.2", features = ["js"] }
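For reference, a minimal sketch (outside the patch itself) of how the new `Tensor::from_raw_buffer` entry point is meant to be called: the caller passes a little-endian byte buffer plus an explicit candle `DType`, which is what `VarBuilder::get_sharded` now does after flattening the safetensors slice iterator into a `Vec<u8>` and converting the safetensors `Dtype` via `try_into()`. The buffer contents and shape below are illustrative only, not taken from the diff.

```rust
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    // Encode four f32 values as little-endian bytes; in get_sharded this
    // buffer instead comes from flattening the safetensors slice iterator.
    let raw: Vec<u8> = [1.0f32, 2.0, 3.0, 4.0]
        .iter()
        .flat_map(|v| v.to_le_bytes())
        .collect();

    // The dtype argument is now a candle DType rather than a safetensors Dtype.
    let t = Tensor::from_raw_buffer(&raw, DType::F32, &[2, 2], &Device::Cpu)?;
    assert_eq!(t.dims(), &[2, 2]);
    Ok(())
}
```

Taking `&[u8]` plus a `DType` keeps candle-core independent of safetensors' slicing types, which is why the `SliceIterator` import and the `UnsupportedSafeTensorDtype` fallback arm can be dropped from `candle-core/src/safetensors.rs`.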