Mirror of https://github.com/huggingface/candle.git
VarBuilder cleanup (#627)
* VarBuilder cleanup.
* Implement the basic varbuilders.
* Add the sharded code.
* Proper support for tensor sharding.
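In short: the separate `get_sharded` and `get_or_init` calls are folded into a single `get_with_hints`, whose hint type depends on the builder (a `Shard` for the new `ShardedVarBuilder`, an initializer such as `candle_nn::init::ZERO` otherwise); the multiprocess llama example now loads through `ShardedSafeTensors::var_builder`, and the hard-coded `Config::config_7b` is dropped.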
@@ -13,7 +13,6 @@ use anyhow::{bail, Error as E, Result};
 use clap::Parser;
 
 use candle::{DType, Device, Tensor};
-use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
 use cudarc::driver::safe::CudaDevice;
 use cudarc::nccl::safe::{Comm, Id};

@@ -211,7 +210,7 @@ fn main() -> Result<()> {
         .map(|h| Ok(h.deserialize()?))
         .collect::<Result<Vec<_>>>()?;
 
-    let vb = VarBuilder::from_safetensors(tensors, dtype, &device);
+    let vb = candle_nn::var_builder::ShardedSafeTensors::var_builder(tensors, dtype, &device);
     let llama = Llama::load(vb, &cache, &config, comm)?;
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
 

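Presumably the point of switching to the sharded builder is that each rank maps the same safetensors but materializes only its own slice of every weight, according to the `Shard` hints requested in the model code below, instead of loading full tensors and splitting them afterwards.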
@@ -1,6 +1,6 @@
 use candle::backend::BackendStorage;
 use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
-use candle_nn::{rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
+use candle_nn::{Embedding, Linear, Module, RmsNorm};
 use cudarc::nccl::safe::{Comm, ReduceOp};
 use half::f16;
 use serde::Deserialize;

@@ -9,6 +9,8 @@ use std::sync::{Arc, Mutex};
 
 use super::MAX_SEQ_LEN;
 
+use candle_nn::var_builder::ShardedVarBuilder as VarBuilder;
+
 struct TensorParallelColumnLinear {
     linear: Linear,
 }

@@ -82,11 +84,19 @@ impl TensorParallelRowLinear {
     }
 }
 
+fn shard(dim: usize, rank: usize, world_size: usize) -> candle_nn::var_builder::Shard {
+    candle_nn::var_builder::Shard {
+        dim,
+        rank,
+        world_size,
+    }
+}
+
 impl TensorParallelColumnLinear {
     fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
         let rank = comm.rank();
         let size = comm.world_size();
-        let weight = vb.get_sharded("weight", 0, rank, size)?;
+        let weight = vb.get_with_hints((), "weight", shard(0, rank, size))?;
         Ok(Self::new(Linear::new(weight, None)))
     }
 

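For orientation, here is a minimal sketch of what a `Shard { dim, rank, world_size }` hint selects. It assumes the shard is an even, contiguous split along `dim` and emulates it with `Tensor::narrow` on a toy tensor; the shapes and world size are illustrative, not taken from this diff.

use candle::{Device, Tensor};

fn main() -> candle::Result<()> {
    // A toy 4x2 "weight": values 0..8 laid out in 4 rows of 2 columns.
    let full = Tensor::arange(0f32, 8f32, &Device::Cpu)?.reshape((4, 2))?;
    let world_size: usize = 2;
    let block = full.dim(0)? / world_size;
    for rank in 0..world_size {
        // shard(0, rank, world_size): an even split along dim 0,
        // i.e. this rank takes rows [rank * block, (rank + 1) * block).
        let piece = full.narrow(0, rank * block, block)?;
        println!("rank {rank}: {piece}");
    }
    Ok(())
}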
@@ -95,8 +105,8 @@ impl TensorParallelColumnLinear {
         let size = comm.world_size();
         let weights: Vec<_> = prefixes
             .iter()
-            .map(|p| vb.pp(p).get_sharded("weight", 0, rank, size).unwrap())
-            .collect();
+            .map(|p| vb.pp(p).get_with_hints((), "weight", shard(0, rank, size)))
+            .collect::<Result<Vec<_>>>()?;
         let weight = Tensor::cat(&weights, 0)?;
         Ok(Self::new(Linear::new(weight, None)))
     }

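Besides the rename, the multi-prefix column-parallel loader now propagates failures through `collect::<Result<Vec<_>>>()?` instead of unwrapping each shard.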
@@ -106,7 +116,7 @@ impl TensorParallelRowLinear {
     fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
         let rank = comm.rank();
         let size = comm.world_size();
-        let weight = vb.get_sharded("weight", 1, rank, size)?;
+        let weight = vb.get_with_hints((), "weight", shard(1, rank, size))?;
         Ok(Self::new(Linear::new(weight, None), comm))
     }
 }

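The split dimensions follow from how `Linear` applies its `(out, in)` weight: a column-parallel layer shards the output rows (dim 0) and only needs to concatenate the per-rank results, while a row-parallel layer shards the input columns (dim 1) and has to sum partial products, which is what the all-reduce over `Comm` does in this file. A CPU-only sketch of that identity, with illustrative shapes and a world size of 2:

use candle::{Device, Tensor};

fn main() -> candle::Result<()> {
    let dev = Device::Cpu;
    let x = Tensor::randn(0f32, 1f32, (1, 4), &dev)?; // one activation row
    let w = Tensor::randn(0f32, 1f32, (6, 4), &dev)?; // Linear stores (out, in)
    let full = x.matmul(&w.t()?)?; // reference forward pass, shape (1, 6)

    // Column-parallel: shard dim 0 of the weight. The per-rank outputs are
    // disjoint slices of the result, so concatenating them is enough.
    let y_col = Tensor::cat(
        &[
            x.matmul(&w.narrow(0, 0, 3)?.t()?)?,
            x.matmul(&w.narrow(0, 3, 3)?.t()?)?,
        ],
        1,
    )?;

    // Row-parallel: shard dim 1 of the weight (and of the activation). Each
    // rank produces a full-shaped partial result that must be summed, which
    // is the all-reduce in the NCCL version.
    let y_row = (x.narrow(1, 0, 2)?.matmul(&w.narrow(1, 0, 2)?.t()?)?
        + x.narrow(1, 2, 2)?.matmul(&w.narrow(1, 2, 2)?.t()?)?)?;

    // Both reconstructions match the unsharded result (up to float rounding).
    println!("col diff: {}", full.sub(&y_col)?.abs()?.sum_all()?);
    println!("row diff: {}", full.sub(&y_row)?.abs()?.sum_all()?);
    Ok(())
}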
@@ -128,21 +138,6 @@ fn default_rope() -> f32 {
     10_000.0
 }
 
-impl Config {
-    pub fn config_7b() -> Self {
-        Self {
-            intermediate_size: 11008,
-            vocab_size: 32000,
-            num_hidden_layers: 32,
-            num_attention_heads: 32,
-            hidden_size: 4096,
-            num_key_value_heads: 32,
-            rms_norm_eps: 1e-5,
-            rope_theta: 10_000.0,
-        }
-    }
-}
-
 #[derive(Clone)]
 pub struct Cache {
     #[allow(clippy::type_complexity)]

@@ -352,6 +347,11 @@ struct Block {
     mlp: Mlp,
 }
 
+fn rms_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<RmsNorm> {
+    let weight = vb.get_with_hints(size, "weight", shard(0, 0, 1))?;
+    Ok(RmsNorm::new(weight, eps))
+}
+
 impl Block {
     fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
         Self {

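This local `rms_norm` replaces the `candle_nn::rms_norm` helper dropped from the imports above; the `shard(0, 0, 1)` hint is presumably the trivial shard (rank 0 in a world of size 1), so the norm scale is loaded in full on every rank rather than being split.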
@@ -14,8 +14,8 @@ const IMAGE_DIM: usize = 784;
 const LABELS: usize = 10;
 
 fn linear_z(in_dim: usize, out_dim: usize, vs: VarBuilder) -> Result<Linear> {
-    let ws = vs.get_or_init((out_dim, in_dim), "weight", candle_nn::init::ZERO)?;
-    let bs = vs.get_or_init(out_dim, "bias", candle_nn::init::ZERO)?;
+    let ws = vs.get_with_hints((out_dim, in_dim), "weight", candle_nn::init::ZERO)?;
+    let bs = vs.get_with_hints(out_dim, "bias", candle_nn::init::ZERO)?;
     Ok(Linear::new(ws, Some(bs)))
 }
 

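Outside the NCCL example the hint is an initializer rather than a `Shard`. Below is a small, self-contained sketch of the renamed call against a `VarMap`-backed trainable builder; the `linear_zeros` name and the 784/10 shapes just mirror the example above and are not part of this change.

use candle::{DType, Device, Result};
use candle_nn::{Linear, VarBuilder, VarMap};

fn linear_zeros(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
    // get_with_hints creates the variable with the given initializer if it
    // does not exist yet (here: all zeros) and reuses it otherwise.
    let ws = vb.get_with_hints((out_dim, in_dim), "weight", candle_nn::init::ZERO)?;
    let bs = vb.get_with_hints(out_dim, "bias", candle_nn::init::ZERO)?;
    Ok(Linear::new(ws, Some(bs)))
}

fn main() -> Result<()> {
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
    let _layer = linear_zeros(784, 10, vb.pp("linear"))?;
    Ok(())
}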
@@ -368,7 +368,7 @@ impl<'a> Layer<'a> {
         self.cnt += 1;
     }
 
-    fn next(&mut self) -> VarBuilder<'a> {
+    fn next(&mut self) -> VarBuilder {
         let vb = self.vb.pp(&self.cnt.to_string());
         self.cnt += 1;
         vb