Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 10:38:54 +00:00
Removing inner dependency on safetensors.
@@ -1,7 +1,6 @@
 use crate::{DType, Device, Error, Result, Tensor, WithDType};
-use safetensors::slice::SliceIterator;
 use safetensors::tensor as st;
-use safetensors::tensor::{Dtype, SafeTensors};
+use safetensors::tensor::SafeTensors;
 use std::borrow::Cow;
 
 impl From<DType> for st::Dtype {
@@ -118,26 +117,24 @@ impl<'a> Load for st::TensorView<'a> {
 }
 
 impl Tensor {
-    pub fn from_safetensors_slice(
-        iterator: SliceIterator,
-        dtype: Dtype,
+    pub fn from_raw_buffer(
+        data: &[u8],
+        dtype: DType,
         shape: &[usize],
         device: &Device,
     ) -> Result<Self> {
-        let data: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
         match dtype {
-            st::Dtype::U8 => convert_slice::<u8>(&data, shape, device),
-            st::Dtype::U32 => convert_slice::<u8>(&data, shape, device),
-            st::Dtype::BF16 => convert_slice::<half::bf16>(&data, shape, device),
-            st::Dtype::F16 => convert_slice::<half::f16>(&data, shape, device),
-            st::Dtype::F32 => convert_slice::<f32>(&data, shape, device),
-            st::Dtype::F64 => convert_slice::<f64>(&data, shape, device),
-            dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
+            DType::U8 => convert_slice::<u8>(data, shape, device),
+            DType::U32 => convert_slice::<u32>(data, shape, device),
+            DType::BF16 => convert_slice::<half::bf16>(data, shape, device),
+            DType::F16 => convert_slice::<half::f16>(data, shape, device),
+            DType::F32 => convert_slice::<f32>(data, shape, device),
+            DType::F64 => convert_slice::<f64>(data, shape, device),
         }
     }
 }
 
-pub fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
+fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
     match view.dtype() {
         st::Dtype::U8 => convert_::<u8>(view, device),
         st::Dtype::U32 => convert_::<u8>(view, device),
@@ -149,7 +146,7 @@ pub fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
     }
 }
 
-pub fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
+fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
     // TODO: This makes an unnecessary copy when the tensor is on the cpu.
     let tensor = tensor.flatten_all()?;
     match tensor.dtype() {
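For reference, a minimal sketch (not part of the commit) of how the new Tensor::from_raw_buffer entry point can be called. The helper name and the example values are illustrative only, assuming the crate is consumed as `candle` as in the examples below:

use candle::{DType, Device, Result, Tensor};

// Illustrative only: build a 2x2 f32 tensor straight from a little-endian
// byte buffer, without going through any safetensors types.
fn tensor_from_f32_bytes(device: &Device) -> Result<Tensor> {
    let values: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
    let bytes: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
    Tensor::from_raw_buffer(&bytes, DType::F32, &[2, 2], device)
}
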
@@ -4,7 +4,6 @@ use candle_nn::{Embedding, Linear, VarBuilder};
 use cudarc::nccl::safe::{Comm, ReduceOp};
 use half::f16;
 use std::collections::HashMap;
-use std::rc::Rc;
 use std::sync::{Arc, Mutex};
 
 use super::MAX_SEQ_LEN;
@@ -24,11 +23,11 @@ impl TensorParallelColumnLinear {
 
 struct TensorParallelRowLinear {
     linear: Linear,
-    comm: Rc<Comm>,
+    comm: Arc<Comm>,
 }
 
 struct AllReduce {
-    comm: Rc<Comm>,
+    comm: Arc<Comm>,
 }
 
 impl CustomOp1 for AllReduce {
@@ -61,12 +60,12 @@ impl CustomOp1 for AllReduce {
     }
 }
 
-fn all_reduce_sum(x: &Tensor, comm: &Rc<Comm>) -> Result<Tensor> {
+fn all_reduce_sum(x: &Tensor, comm: &Arc<Comm>) -> Result<Tensor> {
     x.custom_op1(AllReduce { comm: comm.clone() })
 }
 
 impl TensorParallelRowLinear {
-    fn new(linear: Linear, comm: Rc<Comm>) -> Self {
+    fn new(linear: Linear, comm: Arc<Comm>) -> Self {
         Self { linear, comm }
     }
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
@@ -76,14 +75,14 @@ impl TensorParallelRowLinear {
 }
 
 impl TensorParallelColumnLinear {
-    fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, comm: Arc<Comm>) -> Result<Self> {
         let rank = comm.rank();
         let size = comm.world_size();
         let weight = vb.get_sharded("weight", 0, rank, size)?;
         Ok(Self::new(Linear::new(weight, None)))
     }
 
-    fn load_multi(vb: VarBuilder, prefixes: &[&str], comm: Rc<Comm>) -> Result<Self> {
+    fn load_multi(vb: VarBuilder, prefixes: &[&str], comm: Arc<Comm>) -> Result<Self> {
         let rank = comm.rank();
         let size = comm.world_size();
         let weights: Vec<_> = prefixes
@@ -96,7 +95,7 @@ impl TensorParallelColumnLinear {
 }
 
 impl TensorParallelRowLinear {
-    fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, comm: Arc<Comm>) -> Result<Self> {
         let rank = comm.rank();
         let size = comm.world_size();
         let weight = vb.get_sharded("weight", 1, rank, size)?;
@@ -339,7 +338,7 @@ impl CausalSelfAttention {
         }
     }
 
-    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Arc<Comm>) -> Result<Self> {
         let qkv_proj = TensorParallelColumnLinear::load_multi(
             vb.clone(),
             &["q_proj", "k_proj", "v_proj"],
@@ -388,7 +387,7 @@ impl Mlp {
         self.c_proj.forward(&x)
     }
 
-    fn load(vb: VarBuilder, _cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, _cfg: &Config, comm: Arc<Comm>) -> Result<Self> {
         let c_fc1 = TensorParallelColumnLinear::load(vb.pp("gate_proj"), comm.clone())?;
         let c_fc2 = TensorParallelColumnLinear::load(vb.pp("up_proj"), comm.clone())?;
         let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm.clone())?;
@@ -422,7 +421,7 @@ impl Block {
         Ok(x)
     }
 
-    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Arc<Comm>) -> Result<Self> {
         let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg, comm.clone())?;
         let mlp = Mlp::load(vb.pp("mlp"), cfg, comm.clone())?;
         let input_layernorm = RmsNorm::load(cfg.hidden_size, vb.pp("input_layernorm"))?;
@@ -466,7 +465,7 @@ impl Llama {
         logits.to_dtype(DType::F32)
     }
 
-    pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
+    pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Arc<Comm>) -> Result<Self> {
         let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
         let norm = RmsNorm::load(cfg.hidden_size, vb.pp("model.norm"))?;
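A side note on the Rc -> Arc switch in the diff above: Arc is the atomically reference-counted, thread-safe counterpart of Rc. A minimal, self-contained sketch of that distinction, using a stand-in struct rather than cudarc's Comm (illustrative only):

use std::sync::Arc;
use std::thread;

// Stand-in for a communicator handle; not the real cudarc type.
struct FakeComm {
    rank: usize,
}

fn main() {
    let comm = Arc::new(FakeComm { rank: 0 });
    let cloned = Arc::clone(&comm);
    // An Arc clone of a Send + Sync value can move into another thread;
    // the equivalent Rc<FakeComm> clone would not compile here.
    let handle = thread::spawn(move || cloned.rank);
    assert_eq!(handle.join().unwrap(), 0);
}
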
@@ -1,6 +1,5 @@
 use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor};
-use safetensors::slice::IndexOp;
-use safetensors::tensor::SafeTensors;
+use safetensors::{slice::IndexOp, tensor::SafeTensors};
 use std::collections::HashMap;
 use std::sync::Arc;
 
@@ -70,7 +69,7 @@ impl<'a> TensorData<'a> {
 #[derive(Clone)]
 pub struct VarBuilder<'a> {
     data: Arc<TensorData<'a>>,
-    pub path: Vec<String>,
+    path: Vec<String>,
 }
 
 impl<'a> VarBuilder<'a> {
@@ -179,7 +178,10 @@ impl<'a> VarBuilder<'a> {
 
                 shape[dim] = block_size;
 
-                Tensor::from_safetensors_slice(iterator, dtype, &shape, &data.device)?
+                let dtype: DType = dtype.try_into()?;
+
+                let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
+                Tensor::from_raw_buffer(&raw, dtype, &shape, &data.device)?
             }
             _ => unimplemented!(),
         };
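The VarBuilder hunk above now flattens the safetensors slice iterator into one contiguous byte buffer before handing it to Tensor::from_raw_buffer. A small illustrative sketch of that pattern, with a plain iterator of byte chunks standing in for the SliceIterator (the helper name is hypothetical):

use candle::{DType, Device, Result, Tensor};

// Illustrative only: `chunks` stands in for the safetensors SliceIterator.
fn shard_to_tensor(
    chunks: impl IntoIterator<Item = Vec<u8>>,
    dtype: DType,
    shape: &[usize],
    device: &Device,
) -> Result<Tensor> {
    // Concatenate the sliced chunks into one contiguous buffer, then build the tensor.
    let raw: Vec<u8> = chunks.into_iter().flatten().collect();
    Tensor::from_raw_buffer(&raw, dtype, shape, device)
}
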
@@ -15,7 +15,6 @@ candle = { path = "../../candle-core" }
 candle-nn = { path = "../../candle-nn" }
 num-traits = { workspace = true }
 tokenizers = { workspace = true, features = ["unstable_wasm"] }
-safetensors = { workspace = true }
 
 # App crates.
 anyhow = { workspace = true }
@@ -24,6 +23,7 @@ rand = { workspace = true }
 serde = { workspace = true }
 serde_json = { workspace = true }
 wav = { workspace = true }
+safetensors = { workspace = true }
 
 # Wasm specific crates.
 getrandom = { version = "0.2", features = ["js"] }