Removing inner dependency on safetensors.

Nicolas Patry
2023-07-26 11:16:04 +02:00
parent 1553b58fe5
commit 7c7e6ba201
4 changed files with 30 additions and 32 deletions

@@ -4,7 +4,6 @@ use candle_nn::{Embedding, Linear, VarBuilder};
 use cudarc::nccl::safe::{Comm, ReduceOp};
 use half::f16;
 use std::collections::HashMap;
-use std::rc::Rc;
 use std::sync::{Arc, Mutex};
 
 use super::MAX_SEQ_LEN;
@@ -24,11 +23,11 @@ impl TensorParallelColumnLinear {
 struct TensorParallelRowLinear {
     linear: Linear,
-    comm: Rc<Comm>,
+    comm: Arc<Comm>,
 }
 
 struct AllReduce {
-    comm: Rc<Comm>,
+    comm: Arc<Comm>,
 }
 
 impl CustomOp1 for AllReduce {
@@ -61,12 +60,12 @@ impl CustomOp1 for AllReduce {
     }
 }
 
-fn all_reduce_sum(x: &Tensor, comm: &Rc<Comm>) -> Result<Tensor> {
+fn all_reduce_sum(x: &Tensor, comm: &Arc<Comm>) -> Result<Tensor> {
     x.custom_op1(AllReduce { comm: comm.clone() })
 }
 
 impl TensorParallelRowLinear {
-    fn new(linear: Linear, comm: Rc<Comm>) -> Self {
+    fn new(linear: Linear, comm: Arc<Comm>) -> Self {
         Self { linear, comm }
     }
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
@@ -76,14 +75,14 @@ impl TensorParallelRowLinear {
 }
 
 impl TensorParallelColumnLinear {
-    fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, comm: Arc<Comm>) -> Result<Self> {
         let rank = comm.rank();
         let size = comm.world_size();
         let weight = vb.get_sharded("weight", 0, rank, size)?;
         Ok(Self::new(Linear::new(weight, None)))
     }
 
-    fn load_multi(vb: VarBuilder, prefixes: &[&str], comm: Rc<Comm>) -> Result<Self> {
+    fn load_multi(vb: VarBuilder, prefixes: &[&str], comm: Arc<Comm>) -> Result<Self> {
         let rank = comm.rank();
         let size = comm.world_size();
         let weights: Vec<_> = prefixes
@@ -96,7 +95,7 @@ impl TensorParallelColumnLinear {
 }
 
 impl TensorParallelRowLinear {
-    fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, comm: Arc<Comm>) -> Result<Self> {
         let rank = comm.rank();
         let size = comm.world_size();
         let weight = vb.get_sharded("weight", 1, rank, size)?;
@@ -339,7 +338,7 @@ impl CausalSelfAttention {
         }
     }
 
-    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Arc<Comm>) -> Result<Self> {
         let qkv_proj = TensorParallelColumnLinear::load_multi(
             vb.clone(),
             &["q_proj", "k_proj", "v_proj"],
@@ -388,7 +387,7 @@ impl Mlp {
         self.c_proj.forward(&x)
     }
 
-    fn load(vb: VarBuilder, _cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, _cfg: &Config, comm: Arc<Comm>) -> Result<Self> {
         let c_fc1 = TensorParallelColumnLinear::load(vb.pp("gate_proj"), comm.clone())?;
         let c_fc2 = TensorParallelColumnLinear::load(vb.pp("up_proj"), comm.clone())?;
         let c_proj = TensorParallelRowLinear::load(vb.pp("down_proj"), comm.clone())?;
@@ -422,7 +421,7 @@ impl Block {
         Ok(x)
     }
 
-    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
+    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Arc<Comm>) -> Result<Self> {
         let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg, comm.clone())?;
         let mlp = Mlp::load(vb.pp("mlp"), cfg, comm.clone())?;
         let input_layernorm = RmsNorm::load(cfg.hidden_size, vb.pp("input_layernorm"))?;
@@ -466,7 +465,7 @@ impl Llama {
         logits.to_dtype(DType::F32)
     }
 
-    pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Rc<Comm>) -> Result<Self> {
+    pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config, comm: Arc<Comm>) -> Result<Self> {
         let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
         let norm = RmsNorm::load(cfg.hidden_size, vb.pp("model.norm"))?;
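
For readers unfamiliar with the Rc -> Arc substitution that runs through these hunks: Rc<T> uses non-atomic reference counting and is neither Send nor Sync, while Arc<T> uses atomic counting and is Send + Sync whenever T is, so an Arc-held handle can be shared across threads. The sketch below only illustrates that general Rust distinction; the Comm struct and the requires_send_sync bound are hypothetical stand-ins, not the real cudarc::nccl::safe::Comm or candle API.

use std::sync::Arc;

// Hypothetical stand-in for a communicator handle; not the real nccl Comm.
struct Comm;

// The kind of bound a thread-safe component places on anything it captures.
fn requires_send_sync<T: Send + Sync>(_handle: T) {}

fn main() {
    // std::rc::Rc is neither Send nor Sync, so this would not compile:
    // requires_send_sync(std::rc::Rc::new(Comm));

    // Arc<T> is Send + Sync when T is, so the shared handle can cross threads.
    requires_send_sync(Arc::new(Comm));
}

In the diff above, the same substitution is applied everywhere the communicator handle is stored or passed: the TensorParallelRowLinear and AllReduce fields, all_reduce_sum, and the various load constructors.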