Mirror of https://github.com/huggingface/candle.git
VarBuilder cleanup (#627)
* VarBuilder cleanup.
* Implement the basic varbuilders.
* Add the sharded code.
* Proper support for tensor sharding.
@@ -1,6 +1,6 @@
 use candle::backend::BackendStorage;
 use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
-use candle_nn::{rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
+use candle_nn::{Embedding, Linear, Module, RmsNorm};
 use cudarc::nccl::safe::{Comm, ReduceOp};
 use half::f16;
 use serde::Deserialize;
@@ -9,6 +9,8 @@ use std::sync::{Arc, Mutex};
 
 use super::MAX_SEQ_LEN;
 
+use candle_nn::var_builder::ShardedVarBuilder as VarBuilder;
+
 struct TensorParallelColumnLinear {
     linear: Linear,
 }
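For context on the alias above: the ShardedVarBuilder is built from safetensors files on the caller's side before being handed to these load functions. A hedged sketch of that construction based on the current candle API — the constructor name, its unsafe marker (it mmaps the files), and the file names are assumptions, not part of this diff:

    use candle::{DType, Device, Result};
    use candle_nn::var_builder::ShardedSafeTensors;

    fn main() -> Result<()> {
        // Hypothetical checkpoint paths; memory-mapping them is why the
        // constructor is marked unsafe.
        let filenames = ["model-00001-of-00002.safetensors"];
        let vb = unsafe { ShardedSafeTensors::var_builder(&filenames, DType::F16, &Device::Cpu)? };
        let _vb = vb; // handed off to the model's load functions
        Ok(())
    }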
@@ -82,11 +84,19 @@ impl TensorParallelRowLinear {
     }
 }
 
+fn shard(dim: usize, rank: usize, world_size: usize) -> candle_nn::var_builder::Shard {
+    candle_nn::var_builder::Shard {
+        dim,
+        rank,
+        world_size,
+    }
+}
+
 impl TensorParallelColumnLinear {
     fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
         let rank = comm.rank();
         let size = comm.world_size();
-        let weight = vb.get_sharded("weight", 0, rank, size)?;
+        let weight = vb.get_with_hints((), "weight", shard(0, rank, size))?;
         Ok(Self::new(Linear::new(weight, None)))
     }
 
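What the new shard(dim, rank, world_size) hint selects can be spelled out with plain tensor ops: rank r of n takes the r-th of n equal slices along dim. A standalone sketch (shard_slice is an illustrative helper, not candle API):

    use candle::{Device, Result, Tensor};

    // The slice a Shard { dim, rank, world_size } hint describes, assuming
    // the dimension divides evenly across ranks.
    fn shard_slice(full: &Tensor, dim: usize, rank: usize, world_size: usize) -> Result<Tensor> {
        let chunk = full.dim(dim)? / world_size;
        full.narrow(dim, rank * chunk, chunk)
    }

    fn main() -> Result<()> {
        let w = Tensor::arange(0f32, 12., &Device::Cpu)?.reshape((4, 3))?;
        let part = shard_slice(&w, 0, 1, 2)?; // rank 1 of 2 on dim 0 -> rows 2..4
        assert_eq!(part.dims(), &[2, 3]);
        Ok(())
    }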
@@ -95,8 +105,8 @@ impl TensorParallelColumnLinear {
         let size = comm.world_size();
         let weights: Vec<_> = prefixes
             .iter()
-            .map(|p| vb.pp(p).get_sharded("weight", 0, rank, size).unwrap())
-            .collect();
+            .map(|p| vb.pp(p).get_with_hints((), "weight", shard(0, rank, size)))
+            .collect::<Result<Vec<_>>>()?;
         let weight = Tensor::cat(&weights, 0)?;
         Ok(Self::new(Linear::new(weight, None)))
     }
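Besides switching to the hinted getter, this hunk drops the .unwrap(): collecting an iterator of Results into Result<Vec<_>> propagates the first error instead of panicking. The idiom in isolation, independent of candle:

    fn parse_all(inputs: &[&str]) -> Result<Vec<i64>, std::num::ParseIntError> {
        // One bad element turns the whole collect into an Err.
        inputs.iter().map(|s| s.parse::<i64>()).collect()
    }

    fn main() {
        assert_eq!(parse_all(&["1", "2"]), Ok(vec![1, 2]));
        assert!(parse_all(&["1", "x"]).is_err());
    }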
@@ -106,7 +116,7 @@ impl TensorParallelRowLinear {
     fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
         let rank = comm.rank();
         let size = comm.world_size();
-        let weight = vb.get_sharded("weight", 1, rank, size)?;
+        let weight = vb.get_with_hints((), "weight", shard(1, rank, size))?;
         Ok(Self::new(Linear::new(weight, None), comm))
     }
 }
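Note the dim in the two loaders: column-parallel shards the weight on dim 0 (output features), so per-rank outputs simply concatenate, while row-parallel shards dim 1 (input features), so per-rank outputs are partial sums that must be all-reduced — which is why TensorParallelRowLinear keeps the comm handle. A CPU-only sketch of that identity, with the NCCL all-reduce simulated by a plain add:

    use candle::{Device, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let x = Tensor::arange(0f32, 4., &dev)?.reshape((1, 4))?;
        let w = Tensor::arange(0f32, 8., &dev)?.reshape((2, 4))?; // out = 2, in = 4
        let full = x.matmul(&w.t()?)?;
        // Each "rank" holds half of the input dimension (dim 1 of the weight).
        let y0 = x.narrow(1, 0, 2)?.matmul(&w.narrow(1, 0, 2)?.t()?)?;
        let y1 = x.narrow(1, 2, 2)?.matmul(&w.narrow(1, 2, 2)?.t()?)?;
        let summed = (y0 + y1)?; // stands in for ReduceOp::Sum across ranks
        assert_eq!(full.to_vec2::<f32>()?, summed.to_vec2::<f32>()?);
        Ok(())
    }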
@@ -128,21 +138,6 @@ fn default_rope() -> f32 {
     10_000.0
 }
 
-impl Config {
-    pub fn config_7b() -> Self {
-        Self {
-            intermediate_size: 11008,
-            vocab_size: 32000,
-            num_hidden_layers: 32,
-            num_attention_heads: 32,
-            hidden_size: 4096,
-            num_key_value_heads: 32,
-            rms_norm_eps: 1e-5,
-            rope_theta: 10_000.0,
-        }
-    }
-}
-
 #[derive(Clone)]
 pub struct Cache {
     #[allow(clippy::type_complexity)]
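With the hardcoded config_7b() constructor removed, the remaining path for building a Config is deserialization — the file already imports serde::Deserialize and keeps default_rope as a serde default. A hedged reconstruction of that path; the struct below mirrors the removed constructor's fields, but the real definition lives elsewhere in this file and may differ:

    use serde::Deserialize;

    #[derive(Debug, Deserialize)]
    struct Config {
        intermediate_size: usize,
        vocab_size: usize,
        num_hidden_layers: usize,
        num_attention_heads: usize,
        hidden_size: usize,
        num_key_value_heads: usize,
        rms_norm_eps: f64,
        #[serde(default = "default_rope")]
        rope_theta: f32,
    }

    fn default_rope() -> f32 {
        10_000.0
    }

    fn main() -> serde_json::Result<()> {
        // The same values config_7b() used to hardcode, now read from JSON.
        let json = r#"{
            "intermediate_size": 11008, "vocab_size": 32000,
            "num_hidden_layers": 32, "num_attention_heads": 32,
            "hidden_size": 4096, "num_key_value_heads": 32,
            "rms_norm_eps": 1e-5
        }"#;
        let config: Config = serde_json::from_str(json)?;
        println!("{config:?}"); // rope_theta falls back to 10_000.0
        Ok(())
    }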
@@ -352,6 +347,11 @@ struct Block {
     mlp: Mlp,
 }
 
+fn rms_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<RmsNorm> {
+    let weight = vb.get_with_hints(size, "weight", shard(0, 0, 1))?;
+    Ok(RmsNorm::new(weight, eps))
+}
+
 impl Block {
     fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
         Self {
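Two notes on the new helper above: shard(0, 0, 1) means rank 0 of a world of size 1 — the full, unsharded tensor, since norm weights are small and replicated on every rank rather than split. And for reference, what the resulting RmsNorm computes, written out with plain ops (a sketch, not the candle_nn implementation):

    use candle::{Result, Tensor, D};

    // Scale x by the reciprocal RMS over its last dimension, then by the
    // learned per-channel weight.
    fn rms_norm_ref(x: &Tensor, weight: &Tensor, eps: f64) -> Result<Tensor> {
        let mean_sq = x.sqr()?.mean_keepdim(D::Minus1)?;
        let normed = x.broadcast_div(&(mean_sq + eps)?.sqrt()?)?;
        normed.broadcast_mul(weight)
    }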