From 4c338b0cd91e1dbdb828dad861896faf1a719139 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 27 Aug 2023 18:03:26 +0100 Subject: [PATCH] VarBuilder cleanup (#627) * VarBuilder cleanup. * Implement the basic varbuilders. * Add the sharded code. * Proper support for tensor sharding. --- .../examples/llama_multiprocess/main.rs | 3 +- .../examples/llama_multiprocess/model.rs | 40 +- .../examples/mnist-training/main.rs | 4 +- .../examples/musicgen/encodec_model.rs | 2 +- candle-nn/src/batch_norm.rs | 8 +- candle-nn/src/conv.rs | 10 +- candle-nn/src/embedding.rs | 2 +- candle-nn/src/group_norm.rs | 4 +- candle-nn/src/init.rs | 6 + candle-nn/src/layer_norm.rs | 4 +- candle-nn/src/linear.rs | 6 +- candle-nn/src/var_builder.rs | 611 +++++++++++------- 12 files changed, 409 insertions(+), 291 deletions(-) diff --git a/candle-examples/examples/llama_multiprocess/main.rs b/candle-examples/examples/llama_multiprocess/main.rs index db315e46..17dc90e2 100644 --- a/candle-examples/examples/llama_multiprocess/main.rs +++ b/candle-examples/examples/llama_multiprocess/main.rs @@ -13,7 +13,6 @@ use anyhow::{bail, Error as E, Result}; use clap::Parser; use candle::{DType, Device, Tensor}; -use candle_nn::VarBuilder; use candle_transformers::generation::LogitsProcessor; use cudarc::driver::safe::CudaDevice; use cudarc::nccl::safe::{Comm, Id}; @@ -211,7 +210,7 @@ fn main() -> Result<()> { .map(|h| Ok(h.deserialize()?)) .collect::>>()?; - let vb = VarBuilder::from_safetensors(tensors, dtype, &device); + let vb = candle_nn::var_builder::ShardedSafeTensors::var_builder(tensors, dtype, &device); let llama = Llama::load(vb, &cache, &config, comm)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; diff --git a/candle-examples/examples/llama_multiprocess/model.rs b/candle-examples/examples/llama_multiprocess/model.rs index b146b42d..bb5c2368 100644 --- a/candle-examples/examples/llama_multiprocess/model.rs +++ b/candle-examples/examples/llama_multiprocess/model.rs @@ -1,6 +1,6 @@ use candle::backend::BackendStorage; use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D}; -use candle_nn::{rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder}; +use candle_nn::{Embedding, Linear, Module, RmsNorm}; use cudarc::nccl::safe::{Comm, ReduceOp}; use half::f16; use serde::Deserialize; @@ -9,6 +9,8 @@ use std::sync::{Arc, Mutex}; use super::MAX_SEQ_LEN; +use candle_nn::var_builder::ShardedVarBuilder as VarBuilder; + struct TensorParallelColumnLinear { linear: Linear, } @@ -82,11 +84,19 @@ impl TensorParallelRowLinear { } } +fn shard(dim: usize, rank: usize, world_size: usize) -> candle_nn::var_builder::Shard { + candle_nn::var_builder::Shard { + dim, + rank, + world_size, + } +} + impl TensorParallelColumnLinear { fn load(vb: VarBuilder, comm: Rc) -> Result { let rank = comm.rank(); let size = comm.world_size(); - let weight = vb.get_sharded("weight", 0, rank, size)?; + let weight = vb.get_with_hints((), "weight", shard(0, rank, size))?; Ok(Self::new(Linear::new(weight, None))) } @@ -95,8 +105,8 @@ impl TensorParallelColumnLinear { let size = comm.world_size(); let weights: Vec<_> = prefixes .iter() - .map(|p| vb.pp(p).get_sharded("weight", 0, rank, size).unwrap()) - .collect(); + .map(|p| vb.pp(p).get_with_hints((), "weight", shard(0, rank, size))) + .collect::>>()?; let weight = Tensor::cat(&weights, 0)?; Ok(Self::new(Linear::new(weight, None))) } @@ -106,7 +116,7 @@ impl TensorParallelRowLinear { fn load(vb: VarBuilder, comm: Rc) -> Result { let rank 
= comm.rank(); let size = comm.world_size(); - let weight = vb.get_sharded("weight", 1, rank, size)?; + let weight = vb.get_with_hints((), "weight", shard(1, rank, size))?; Ok(Self::new(Linear::new(weight, None), comm)) } } @@ -128,21 +138,6 @@ fn default_rope() -> f32 { 10_000.0 } -impl Config { - pub fn config_7b() -> Self { - Self { - intermediate_size: 11008, - vocab_size: 32000, - num_hidden_layers: 32, - num_attention_heads: 32, - hidden_size: 4096, - num_key_value_heads: 32, - rms_norm_eps: 1e-5, - rope_theta: 10_000.0, - } - } -} - #[derive(Clone)] pub struct Cache { #[allow(clippy::type_complexity)] @@ -352,6 +347,11 @@ struct Block { mlp: Mlp, } +fn rms_norm(size: usize, eps: f64, vb: VarBuilder) -> Result { + let weight = vb.get_with_hints(size, "weight", shard(0, 0, 1))?; + Ok(RmsNorm::new(weight, eps)) +} + impl Block { fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self { Self { diff --git a/candle-examples/examples/mnist-training/main.rs b/candle-examples/examples/mnist-training/main.rs index 9f985147..bcf8677d 100644 --- a/candle-examples/examples/mnist-training/main.rs +++ b/candle-examples/examples/mnist-training/main.rs @@ -14,8 +14,8 @@ const IMAGE_DIM: usize = 784; const LABELS: usize = 10; fn linear_z(in_dim: usize, out_dim: usize, vs: VarBuilder) -> Result { - let ws = vs.get_or_init((out_dim, in_dim), "weight", candle_nn::init::ZERO)?; - let bs = vs.get_or_init(out_dim, "bias", candle_nn::init::ZERO)?; + let ws = vs.get_with_hints((out_dim, in_dim), "weight", candle_nn::init::ZERO)?; + let bs = vs.get_with_hints(out_dim, "bias", candle_nn::init::ZERO)?; Ok(Linear::new(ws, Some(bs))) } diff --git a/candle-examples/examples/musicgen/encodec_model.rs b/candle-examples/examples/musicgen/encodec_model.rs index e7712bf3..86e3b6e9 100644 --- a/candle-examples/examples/musicgen/encodec_model.rs +++ b/candle-examples/examples/musicgen/encodec_model.rs @@ -368,7 +368,7 @@ impl<'a> Layer<'a> { self.cnt += 1; } - fn next(&mut self) -> VarBuilder<'a> { + fn next(&mut self) -> VarBuilder { let vb = self.vb.pp(&self.cnt.to_string()); self.cnt += 1; vb diff --git a/candle-nn/src/batch_norm.rs b/candle-nn/src/batch_norm.rs index dca3f60b..2dac0758 100644 --- a/candle-nn/src/batch_norm.rs +++ b/candle-nn/src/batch_norm.rs @@ -179,11 +179,11 @@ pub fn batch_norm>( if config.eps < 0. 
{ candle::bail!("batch-norm eps cannot be negative {}", config.eps) } - let running_mean = vb.get_or_init(num_features, "running_mean", crate::Init::Const(0.))?; - let running_var = vb.get_or_init(num_features, "running_var", crate::Init::Const(1.))?; + let running_mean = vb.get_with_hints(num_features, "running_mean", crate::Init::Const(0.))?; + let running_var = vb.get_with_hints(num_features, "running_var", crate::Init::Const(1.))?; let weight_and_bias = if config.affine { - let weight = vb.get_or_init(num_features, "weight", crate::Init::Const(1.))?; - let bias = vb.get_or_init(num_features, "bias", crate::Init::Const(0.))?; + let weight = vb.get_with_hints(num_features, "weight", crate::Init::Const(1.))?; + let bias = vb.get_with_hints(num_features, "bias", crate::Init::Const(0.))?; Some((weight, bias)) } else { None diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs index 5c53c8da..e43de8ef 100644 --- a/candle-nn/src/conv.rs +++ b/candle-nn/src/conv.rs @@ -124,7 +124,7 @@ pub fn conv1d( vs: crate::VarBuilder, ) -> Result { let init_ws = crate::init::DEFAULT_KAIMING_NORMAL; - let ws = vs.get_or_init( + let ws = vs.get_with_hints( (out_channels, in_channels / cfg.groups, kernel_size), "weight", init_ws, @@ -134,7 +134,7 @@ pub fn conv1d( lo: -bound, up: bound, }; - let bs = vs.get_or_init(out_channels, "bias", init_bs)?; + let bs = vs.get_with_hints(out_channels, "bias", init_bs)?; Ok(Conv1d::new(ws, Some(bs), cfg)) } @@ -146,7 +146,7 @@ pub fn conv2d( vs: crate::VarBuilder, ) -> Result { let init_ws = crate::init::DEFAULT_KAIMING_NORMAL; - let ws = vs.get_or_init( + let ws = vs.get_with_hints( ( out_channels, in_channels / cfg.groups, @@ -161,7 +161,7 @@ pub fn conv2d( lo: -bound, up: bound, }; - let bs = vs.get_or_init(out_channels, "bias", init_bs)?; + let bs = vs.get_with_hints(out_channels, "bias", init_bs)?; Ok(Conv2d::new(ws, Some(bs), cfg)) } @@ -173,7 +173,7 @@ pub fn conv2d_no_bias( vs: crate::VarBuilder, ) -> Result { let init_ws = crate::init::DEFAULT_KAIMING_NORMAL; - let ws = vs.get_or_init( + let ws = vs.get_with_hints( ( out_channels, in_channels / cfg.groups, diff --git a/candle-nn/src/embedding.rs b/candle-nn/src/embedding.rs index 918c1805..d84f9f53 100644 --- a/candle-nn/src/embedding.rs +++ b/candle-nn/src/embedding.rs @@ -32,7 +32,7 @@ impl crate::Module for Embedding { } pub fn embedding(in_size: usize, out_size: usize, vb: crate::VarBuilder) -> Result { - let embeddings = vb.get_or_init( + let embeddings = vb.get_with_hints( (in_size, out_size), "weight", crate::Init::Randn { diff --git a/candle-nn/src/group_norm.rs b/candle-nn/src/group_norm.rs index e85c4379..eb1b889f 100644 --- a/candle-nn/src/group_norm.rs +++ b/candle-nn/src/group_norm.rs @@ -79,7 +79,7 @@ pub fn group_norm( eps: f64, vb: crate::VarBuilder, ) -> Result { - let weight = vb.get_or_init(num_channels, "weight", crate::Init::Const(1.))?; - let bias = vb.get_or_init(num_channels, "bias", crate::Init::Const(0.))?; + let weight = vb.get_with_hints(num_channels, "weight", crate::Init::Const(1.))?; + let bias = vb.get_with_hints(num_channels, "bias", crate::Init::Const(0.))?; GroupNorm::new(weight, bias, num_channels, num_groups, eps) } diff --git a/candle-nn/src/init.rs b/candle-nn/src/init.rs index 25702d52..5b9c4fda 100644 --- a/candle-nn/src/init.rs +++ b/candle-nn/src/init.rs @@ -139,3 +139,9 @@ impl Init { } } } + +impl Default for Init { + fn default() -> Self { + Self::Const(0.) 
+ } +} diff --git a/candle-nn/src/layer_norm.rs b/candle-nn/src/layer_norm.rs index 61fbe2d2..e4f556ab 100644 --- a/candle-nn/src/layer_norm.rs +++ b/candle-nn/src/layer_norm.rs @@ -128,9 +128,9 @@ pub fn layer_norm>( vb: crate::VarBuilder, ) -> Result { let config = config.into(); - let weight = vb.get_or_init(size, "weight", crate::Init::Const(1.))?; + let weight = vb.get_with_hints(size, "weight", crate::Init::Const(1.))?; let bias = if config.affine { - Some(vb.get_or_init(size, "bias", crate::Init::Const(0.))?) + Some(vb.get_with_hints(size, "bias", crate::Init::Const(0.))?) } else { None }; diff --git a/candle-nn/src/linear.rs b/candle-nn/src/linear.rs index a7bd1028..14250ed2 100644 --- a/candle-nn/src/linear.rs +++ b/candle-nn/src/linear.rs @@ -50,18 +50,18 @@ impl super::Module for Linear { /// This uses some default names for weight and biases, namely `"weight"` and `"bias"`. pub fn linear(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result { let init_ws = crate::init::DEFAULT_KAIMING_NORMAL; - let ws = vs.get_or_init((out_dim, in_dim), "weight", init_ws)?; + let ws = vs.get_with_hints((out_dim, in_dim), "weight", init_ws)?; let bound = 1. / (in_dim as f64).sqrt(); let init_bs = crate::Init::Uniform { lo: -bound, up: bound, }; - let bs = vs.get_or_init(out_dim, "bias", init_bs)?; + let bs = vs.get_with_hints(out_dim, "bias", init_bs)?; Ok(Linear::new(ws, Some(bs))) } pub fn linear_no_bias(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result { let init_ws = crate::init::DEFAULT_KAIMING_NORMAL; - let ws = vs.get_or_init((out_dim, in_dim), "weight", init_ws)?; + let ws = vs.get_with_hints((out_dim, in_dim), "weight", init_ws)?; Ok(Linear::new(ws, None)) } diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index ef5b6fd1..c593960b 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -2,139 +2,105 @@ use crate::VarMap; use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor}; use safetensors::{slice::IndexOp, tensor::SafeTensors}; use std::collections::HashMap; -use std::sync::Arc; +use std::rc::Rc; -// TODO: Maybe we would want the storage to be generic, e.g. with Box to avoid too many -// generics. -enum Tensors<'a> { - SafeTensorWithRouting { - routing: HashMap, - safetensors: Vec>, - }, - Npz(candle::npy::NpzTensors), - TensorMap(HashMap), - Zeros, - VarMap(VarMap), +/// A structure used to retrieve variables, these variables can either come from storage or be +/// generated via some form of initialization. +/// +/// The way to retrieve variables is defined in the backend embedded in the `VarBuilder`. +pub struct VarBuilderArgs<'a, B: Backend> { + data: Rc>, + path: Vec, + _phantom: std::marker::PhantomData<&'a B>, } -struct TensorData<'a> { - tensors: Tensors<'a>, +impl<'a, B: Backend> Clone for VarBuilderArgs<'a, B> { + fn clone(&self) -> Self { + Self { + data: self.data.clone(), + path: self.path.clone(), + _phantom: self._phantom, + } + } +} + +/// A simple `VarBuilder`, this is less generic than `VarBuilderArgs` but should cover most common +/// use cases. 
+pub type VarBuilder<'a> = VarBuilderArgs<'a, Box>; + +struct TensorData { + backend: B, pub dtype: DType, pub device: Device, } -impl<'a> TensorData<'a> { - fn from_safetensors(safetensors: Vec>, dtype: DType, device: &Device) -> Self { - let mut routing = HashMap::new(); - for (index, sf) in safetensors.iter().enumerate() { - for k in sf.names() { - routing.insert(k.to_string(), index); - } - } - let tensors = Tensors::SafeTensorWithRouting { - routing, - safetensors, +/// A trait that defines how tensor data is retrieved. +/// +/// Typically this would use disk storage in some specific format, or random initialization. +/// Note that there is a specialized version of this trait (`SimpleBackend`) that can be used most +/// of the time. The main restriction is that it doesn't allow for specific args (besides +/// initialization hints). +pub trait Backend { + type Hints: Default; + + /// Retrieve a tensor with some target shape. + fn get( + &self, + s: Shape, + name: &str, + h: Self::Hints, + dtype: DType, + dev: &Device, + ) -> Result; +} + +pub trait SimpleBackend { + /// Retrieve a tensor based on a target name and shape. + fn get( + &self, + s: Shape, + name: &str, + h: crate::Init, + dtype: DType, + dev: &Device, + ) -> Result; +} + +impl<'a> Backend for Box { + type Hints = crate::Init; + fn get( + &self, + s: Shape, + name: &str, + h: Self::Hints, + dtype: DType, + dev: &Device, + ) -> Result { + self.as_ref().get(s, name, h, dtype, dev) + } +} + +impl<'a, B: Backend> VarBuilderArgs<'a, B> { + pub fn new_with_args(backend: B, dtype: DType, dev: &Device) -> Self { + let data = TensorData { + backend, + dtype, + device: dev.clone(), }; Self { - tensors, - device: device.clone(), - dtype, - } - } - - fn zeros(dtype: DType, device: &Device) -> Self { - Self { - tensors: Tensors::Zeros, - device: device.clone(), - dtype, - } - } - - fn from_tensors(tensors: HashMap, dtype: DType, device: &Device) -> Self { - Self { - tensors: Tensors::TensorMap(tensors), - device: device.clone(), - dtype, - } - } - - fn from_npz>(file: P, dtype: DType, device: &Device) -> Result { - let npz = candle::npy::NpzTensors::new(file)?; - Ok(Self { - tensors: Tensors::Npz(npz), - device: device.clone(), - dtype, - }) - } - - fn from_varmap(varmap: &VarMap, dtype: DType, device: &Device) -> Self { - Self { - tensors: Tensors::VarMap(varmap.clone()), - device: device.clone(), - dtype, - } - } -} - -#[derive(Clone)] -pub struct VarBuilder<'a> { - data: Arc>, - path: Vec, -} - -impl<'a> VarBuilder<'a> { - /// Create a `VarBuilder` accessing data frome the safetensors storage. The initial path is - /// set to the root path and sub-paths can be created via the `push_prefix` method.
- pub fn from_safetensors(st: Vec>, dtype: DType, device: &Device) -> Self { - let data = TensorData::from_safetensors(st, dtype, device); - Self { - data: Arc::new(data), + data: Rc::new(data), path: vec![], + _phantom: std::marker::PhantomData, } } - pub fn zeros(dtype: DType, device: &Device) -> Self { - let data = TensorData::zeros(dtype, device); - Self { - data: Arc::new(data), - path: vec![], - } - } - - pub fn from_tensors(ts: HashMap, dtype: DType, device: &Device) -> Self { - let data = TensorData::from_tensors(ts, dtype, device); - Self { - data: Arc::new(data), - path: vec![], - } - } - - pub fn from_varmap(varmap: &VarMap, dtype: DType, device: &Device) -> Self { - let data = TensorData::from_varmap(varmap, dtype, device); - Self { - data: Arc::new(data), - path: vec![], - } - } - - pub fn from_npz>( - file: P, - dtype: DType, - device: &Device, - ) -> Result { - let data = TensorData::from_npz(file, dtype, device)?; - Ok(Self { - data: Arc::new(data), - path: vec![], - }) - } - pub fn push_prefix(&self, s: S) -> Self { let mut path = self.path.clone(); path.push(s.to_string()); Self { data: self.data.clone(), path, + _phantom: std::marker::PhantomData, } } @@ -150,130 +116,108 @@ impl<'a> VarBuilder<'a> { pub fn dtype(&self) -> DType { self.data.dtype } -} -impl<'a> VarBuilder<'a> { - /// Get part of a tensor, typically used to do Tensor Parallelism sharding. - /// - /// If the tensor is of size (1024, 1024). - /// - /// `dim` corresponds to the dimension to slice into - /// `rank` is the rank of the current process - /// `world_size` is the total number of ranks in the process group - /// - /// `get_sharded("tensor", 0, 0, 2)` means `tensor.i((..512))` - /// `get_sharded("tensor", 0, 1, 2)` means `tensor.i((512..))` - /// `get_sharded("tensor", 1, 0, 2)` means `tensor.i((.., ..512))` - pub fn get_sharded( - &self, - tensor_name: &str, - dim: usize, - rank: usize, - world_size: usize, - ) -> Result { - let data = self.data.as_ref(); - let path = self.path(tensor_name); - let tensor = match &self.data.tensors { - Tensors::SafeTensorWithRouting { - routing, - safetensors, - } => { - let index = routing.get(&path).ok_or_else(|| { - Error::CannotFindTensor { - path: path.to_string(), - } - .bt() - })?; - - let view = safetensors[*index].tensor(&path)?; - let dtype = view.dtype(); - let mut shape = view.shape().to_vec(); - let size = shape[dim]; - - if size % world_size != 0 { - return Err(Error::ShapeMismatchSplit { - shape: shape.into(), - dim, - n_parts: world_size, - }); - } - let block_size = size / world_size; - let start = rank * block_size; - let stop = (rank + 1) * block_size; - - // Everything is expressed in tensor dimension - // bytes offsets is handled automatically for safetensors. - - let iterator = if dim == 0 { - view.slice(start..stop).map_err(|_| Error::Msg(format!("Cannot slice tensor {tensor_name} ({shape:?} along dim {dim} with {start}..{stop}")))? - } else if dim == 1 { - view.slice((.., start..stop)).map_err(|_| Error::Msg(format!("Cannot slice tensor {tensor_name} ({shape:?} along dim {dim} with {start}..{stop}")))? - } else { - candle::bail!("Get sharded on dimensions != 0 or 1") - }; - - shape[dim] = block_size; - - let dtype: DType = dtype.try_into()?; - - let raw: Vec = iterator.into_iter().flatten().cloned().collect(); - Tensor::from_raw_buffer(&raw, dtype, &shape, &data.device)? 
- } - _ => candle::bail!("get_sharded is only available for safetensors"), - }; - Ok(tensor) + fn path(&self, tensor_name: &str) -> String { + if self.path.is_empty() { + tensor_name.to_string() + } else { + [&self.path.join("."), tensor_name].join(".") + } } /// Retrieve the tensor associated with the given name at the current path. - pub fn get>(&self, s: S, tensor_name: &str) -> Result { - let data = self.data.as_ref(); - let s: Shape = s.into(); - let path = self.path(tensor_name); - let tensor = match &self.data.tensors { - Tensors::Zeros => Tensor::zeros(&s, data.dtype, &data.device)?.contiguous()?, - Tensors::TensorMap(ts) => ts - .get(&path) - .ok_or_else(|| { - Error::CannotFindTensor { - path: path.to_string(), - } - .bt() - })? - .clone(), - Tensors::VarMap(varmap) => { - let data = varmap.data().lock().unwrap(); - data.get(&path) - .ok_or_else(|| { - Error::CannotFindTensor { - path: path.to_string(), - } - .bt() - })? - .as_tensor() - .clone() - } - Tensors::Npz(npz) => npz.get(&path)?.ok_or_else(|| { + pub fn get_with_hints>( + &self, + s: S, + name: &str, + hints: B::Hints, + ) -> Result { + let path = self.path(name); + self.data + .backend + .get(s.into(), &path, hints, self.data.dtype, &self.data.device) + } + + /// Retrieve the tensor associated with the given name at the current path. + pub fn get>(&self, s: S, name: &str) -> Result { + self.get_with_hints(s, name, Default::default()) + } +} + +struct Zeros; +impl SimpleBackend for Zeros { + fn get(&self, s: Shape, _: &str, _: crate::Init, dtype: DType, dev: &Device) -> Result { + Tensor::zeros(s, dtype, dev) + } +} + +impl SimpleBackend for HashMap { + fn get( + &self, + s: Shape, + name: &str, + _: crate::Init, + dtype: DType, + dev: &Device, + ) -> Result { + let tensor = self + .get(name) + .ok_or_else(|| { Error::CannotFindTensor { - path: path.to_string(), + path: name.to_string(), } .bt() - })?, - Tensors::SafeTensorWithRouting { - routing, - safetensors, - } => { - let index = routing.get(&path).ok_or_else(|| { - Error::CannotFindTensor { - path: path.to_string(), - } - .bt() - })?; - safetensors[*index] - .tensor(&path)? - .load(&data.device)? - .to_dtype(data.dtype)? + })? + .clone(); + if tensor.shape() != &s { + Err(candle::Error::UnexpectedShape { + msg: format!("shape mismatch for {name}"), + expected: s, + got: tensor.shape().clone(), } - }; + .bt())? + } + tensor.to_device(dev)?.to_dtype(dtype) + } +} + +impl SimpleBackend for VarMap { + fn get( + &self, + s: Shape, + name: &str, + h: crate::Init, + dtype: DType, + dev: &Device, + ) -> Result { + VarMap::get(self, s, name, h, dtype, dev) + } +} + +struct SafeTensorWithRouting<'a> { + routing: HashMap, + safetensors: Vec>, +} + +impl<'a> SimpleBackend for SafeTensorWithRouting<'a> { + fn get( + &self, + s: Shape, + path: &str, + _: crate::Init, + dtype: DType, + dev: &Device, + ) -> Result { + let index = self.routing.get(path).ok_or_else(|| { + Error::CannotFindTensor { + path: path.to_string(), + } + .bt() + })?; + let tensor = self.safetensors[*index] + .tensor(path)? + .load(dev)? + .to_dtype(dtype)?; if tensor.shape() != &s { Err(candle::Error::UnexpectedShape { msg: format!("shape mismatch for {path}"), @@ -284,32 +228,201 @@ impl<'a> VarBuilder<'a> { } Ok(tensor) } +} - /// Retrieve the tensor associated with the given name at the current path or initialize a new - /// tensor if it's missing. - /// - /// Tensor initialization is only available if the `VarBuilder` is backed by a `VarMap`. 
- pub fn get_or_init>( +impl SimpleBackend for candle::npy::NpzTensors { + fn get( &self, - s: S, - tensor_name: &str, - init: crate::Init, + s: Shape, + path: &str, + _: crate::Init, + dtype: DType, + dev: &Device, ) -> Result { - let data = self.data.as_ref(); - match &self.data.tensors { - Tensors::VarMap(varmap) => { - let path = self.path(tensor_name); - varmap.get(s, &path, init, data.dtype, &data.device) + let tensor = match self.get(path)? { + None => Err(Error::CannotFindTensor { + path: path.to_string(), } - _ => self.get(s, tensor_name), + .bt())?, + Some(tensor) => tensor, + }; + let tensor = tensor.to_device(dev)?.to_dtype(dtype)?; + if tensor.shape() != &s { + Err(candle::Error::UnexpectedShape { + msg: format!("shape mismatch for {path}"), + expected: s, + got: tensor.shape().clone(), + } + .bt())? + } + Ok(tensor) + } +} + +impl<'a> VarBuilder<'a> { + fn new(backend: Box, dtype: DType, device: Device) -> Self { + let data = TensorData { + backend, + dtype, + device, + }; + Self { + data: Rc::new(data), + path: vec![], + _phantom: std::marker::PhantomData, } } - fn path(&self, tensor_name: &str) -> String { - if self.path.is_empty() { - tensor_name.to_string() - } else { - [&self.path.join("."), tensor_name].join(".") + pub fn zeros(dtype: DType, dev: &Device) -> Self { + Self::new(Box::new(Zeros), dtype, dev.clone()) + } + + pub fn from_tensors(ts: HashMap, dtype: DType, dev: &Device) -> Self { + Self::new(Box::new(ts), dtype, dev.clone()) + } + + pub fn from_varmap(varmap: &VarMap, dtype: DType, dev: &Device) -> Self { + Self::new(Box::new(varmap.clone()), dtype, dev.clone()) + } + + pub fn from_safetensors(safetensors: Vec>, dtype: DType, dev: &Device) -> Self { + let mut routing = HashMap::new(); + for (index, sf) in safetensors.iter().enumerate() { + for k in sf.names() { + routing.insert(k.to_string(), index); + } + } + let tensors = SafeTensorWithRouting { + routing, + safetensors, + }; + Self::new(Box::new(tensors), dtype, dev.clone()) + } + + pub fn from_npz>(p: P, dtype: DType, dev: &Device) -> Result { + let npz = candle::npy::NpzTensors::new(p)?; + Ok(Self::new(Box::new(npz), dtype, dev.clone())) + } +} + +pub struct ShardedSafeTensors<'a>(SafeTensorWithRouting<'a>); +pub type ShardedVarBuilder<'a> = VarBuilderArgs<'a, ShardedSafeTensors<'a>>; + +impl<'a> ShardedSafeTensors<'a> { + pub fn var_builder( + safetensors: Vec>, + dtype: DType, + dev: &Device, + ) -> ShardedVarBuilder<'a> { + let mut routing = HashMap::new(); + for (index, sf) in safetensors.iter().enumerate() { + for k in sf.names() { + routing.insert(k.to_string(), index); + } + } + let tensors = SafeTensorWithRouting { + routing, + safetensors, + }; + let backend = ShardedSafeTensors(tensors); + VarBuilderArgs::new_with_args(backend, dtype, dev) + } +} + +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub struct Shard { + pub dim: usize, + pub rank: usize, + pub world_size: usize, +} + +impl Default for Shard { + fn default() -> Self { + Self { + dim: 0, + rank: 0, + world_size: 1, } } } + +/// Get part of a tensor, typically used to do Tensor Parallelism sharding. +/// +/// If the tensor is of size (1024, 1024). 
+/// +/// `dim` corresponds to the dimension to slice into +/// `rank` is the rank of the current process +/// `world_size` is the total number of ranks in the process group +/// +/// `get_sharded("tensor", 0, 0, 2)` means `tensor.i((..512))` +/// `get_sharded("tensor", 0, 1, 2)` means `tensor.i((512..))` +/// `get_sharded("tensor", 1, 0, 2)` means `tensor.i((.., ..512))` +impl<'a> Backend for ShardedSafeTensors<'a> { + type Hints = Shard; + + fn get( + &self, + _target_shape: Shape, // The size is not checked for ShardedTensors + path: &str, + h: Self::Hints, + dtype: DType, + dev: &Device, + ) -> Result { + let Shard { + dim, + rank, + world_size, + } = h; + let SafeTensorWithRouting { + routing, + safetensors, + } = &self.0; + let index = routing.get(path).ok_or_else(|| { + Error::CannotFindTensor { + path: path.to_string(), + } + .bt() + })?; + + let view = safetensors[*index].tensor(path)?; + let view_dtype = view.dtype(); + let mut shape = view.shape().to_vec(); + let size = shape[dim]; + + if size % world_size != 0 { + return Err(Error::ShapeMismatchSplit { + shape: shape.into(), + dim, + n_parts: world_size, + }); + } + let block_size = size / world_size; + let start = rank * block_size; + let stop = (rank + 1) * block_size; + + // Everything is expressed in tensor dimension + // bytes offsets is handled automatically for safetensors. + + let iterator = if dim == 0 { + view.slice(start..stop).map_err(|_| { + Error::Msg(format!( + "Cannot slice tensor {path} ({shape:?} along dim {dim} with {start}..{stop}" + )) + })? + } else if dim == 1 { + view.slice((.., start..stop)).map_err(|_| { + Error::Msg(format!( + "Cannot slice tensor {path} ({shape:?} along dim {dim} with {start}..{stop}" + )) + })? + } else { + candle::bail!("Get sharded on dimensions != 0 or 1") + }; + + shape[dim] = block_size; + + let view_dtype: DType = view_dtype.try_into()?; + let raw: Vec = iterator.into_iter().flatten().cloned().collect(); + Tensor::from_raw_buffer(&raw, view_dtype, &shape, dev)?.to_dtype(dtype) + } +}
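
Usage sketch (illustrative, not part of the diff above): the snippet below shows how the reworked API is meant to be called after this change. It relies only on items visible in this patch (`VarBuilder::zeros`, `pp`, `get_with_hints`, `Init`, and the `Shard` hint used by the sharded safetensors backend); the "layer1" prefix, the "weight" tensor name and the shapes are made-up placeholders.

use candle::{DType, Device, Result};
use candle_nn::var_builder::Shard;
use candle_nn::{Init, VarBuilder};

fn varbuilder_demo() -> Result<()> {
    let dev = Device::Cpu;

    // A builder whose backend materializes zero tensors on demand, handy for tests.
    let vb = VarBuilder::zeros(DType::F32, &dev);

    // `get_with_hints` replaces the old `get_or_init`: for the simple backends the hint
    // is an `Init`, which only matters for backends that actually create tensors (VarMap).
    let ws = vb.pp("layer1").get_with_hints((10, 20), "weight", Init::Const(0.))?;
    assert_eq!(ws.dims2()?, (10, 20));

    // With a `ShardedVarBuilder` the hint is a `Shard` instead, e.g.
    // `vb.get_with_hints((), "weight", shard)?` loads only this rank's slice; the target
    // shape is ignored by that backend, which is why `()` is passed in the llama example.
    let _shard = Shard {
        dim: 0,
        rank: 0,
        world_size: 2,
    };

    Ok(())
}

The `Backend`/`SimpleBackend` split is what lets a single `get_with_hints` entry point cover both cases: the zeros, tensor-map, npz, varmap and safetensors backends implement `SimpleBackend` with a fixed `Init` hint, while `ShardedSafeTensors` implements `Backend` directly with `Shard` hints.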