VarBuilder cleanup (#627)

* VarBuilder cleanup.

* Implement the basic varbuilders.

* Add the sharded code.

* Proper support for tensor sharding.
This commit is contained in:
Laurent Mazare
2023-08-27 18:03:26 +01:00
committed by GitHub
parent be471d50ab
commit 4c338b0cd9
12 changed files with 409 additions and 291 deletions

View File

@ -13,7 +13,6 @@ use anyhow::{bail, Error as E, Result};
use clap::Parser;
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use cudarc::driver::safe::CudaDevice;
use cudarc::nccl::safe::{Comm, Id};
@ -211,7 +210,7 @@ fn main() -> Result<()> {
.map(|h| Ok(h.deserialize()?))
.collect::<Result<Vec<_>>>()?;
let vb = VarBuilder::from_safetensors(tensors, dtype, &device);
let vb = candle_nn::var_builder::ShardedSafeTensors::var_builder(tensors, dtype, &device);
let llama = Llama::load(vb, &cache, &config, comm)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

View File

@ -1,6 +1,6 @@
use candle::backend::BackendStorage;
use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
use candle_nn::{rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
use candle_nn::{Embedding, Linear, Module, RmsNorm};
use cudarc::nccl::safe::{Comm, ReduceOp};
use half::f16;
use serde::Deserialize;
@ -9,6 +9,8 @@ use std::sync::{Arc, Mutex};
use super::MAX_SEQ_LEN;
use candle_nn::var_builder::ShardedVarBuilder as VarBuilder;
struct TensorParallelColumnLinear {
linear: Linear,
}
@ -82,11 +84,19 @@ impl TensorParallelRowLinear {
}
}
fn shard(dim: usize, rank: usize, world_size: usize) -> candle_nn::var_builder::Shard {
candle_nn::var_builder::Shard {
dim,
rank,
world_size,
}
}
impl TensorParallelColumnLinear {
fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
let rank = comm.rank();
let size = comm.world_size();
let weight = vb.get_sharded("weight", 0, rank, size)?;
let weight = vb.get_with_hints((), "weight", shard(0, rank, size))?;
Ok(Self::new(Linear::new(weight, None)))
}
@ -95,8 +105,8 @@ impl TensorParallelColumnLinear {
let size = comm.world_size();
let weights: Vec<_> = prefixes
.iter()
.map(|p| vb.pp(p).get_sharded("weight", 0, rank, size).unwrap())
.collect();
.map(|p| vb.pp(p).get_with_hints((), "weight", shard(0, rank, size)))
.collect::<Result<Vec<_>>>()?;
let weight = Tensor::cat(&weights, 0)?;
Ok(Self::new(Linear::new(weight, None)))
}
@ -106,7 +116,7 @@ impl TensorParallelRowLinear {
fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
let rank = comm.rank();
let size = comm.world_size();
let weight = vb.get_sharded("weight", 1, rank, size)?;
let weight = vb.get_with_hints((), "weight", shard(1, rank, size))?;
Ok(Self::new(Linear::new(weight, None), comm))
}
}
@ -128,21 +138,6 @@ fn default_rope() -> f32 {
10_000.0
}
impl Config {
pub fn config_7b() -> Self {
Self {
intermediate_size: 11008,
vocab_size: 32000,
num_hidden_layers: 32,
num_attention_heads: 32,
hidden_size: 4096,
num_key_value_heads: 32,
rms_norm_eps: 1e-5,
rope_theta: 10_000.0,
}
}
}
#[derive(Clone)]
pub struct Cache {
#[allow(clippy::type_complexity)]
@ -352,6 +347,11 @@ struct Block {
mlp: Mlp,
}
fn rms_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<RmsNorm> {
let weight = vb.get_with_hints(size, "weight", shard(0, 0, 1))?;
Ok(RmsNorm::new(weight, eps))
}
impl Block {
fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
Self {

View File

@ -14,8 +14,8 @@ const IMAGE_DIM: usize = 784;
const LABELS: usize = 10;
fn linear_z(in_dim: usize, out_dim: usize, vs: VarBuilder) -> Result<Linear> {
let ws = vs.get_or_init((out_dim, in_dim), "weight", candle_nn::init::ZERO)?;
let bs = vs.get_or_init(out_dim, "bias", candle_nn::init::ZERO)?;
let ws = vs.get_with_hints((out_dim, in_dim), "weight", candle_nn::init::ZERO)?;
let bs = vs.get_with_hints(out_dim, "bias", candle_nn::init::ZERO)?;
Ok(Linear::new(ws, Some(bs)))
}

View File

@ -368,7 +368,7 @@ impl<'a> Layer<'a> {
self.cnt += 1;
}
fn next(&mut self) -> VarBuilder<'a> {
fn next(&mut self) -> VarBuilder {
let vb = self.vb.pp(&self.cnt.to_string());
self.cnt += 1;
vb