Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 18:48:51 +00:00
VarBuilder cleanup (#627)
* VarBuilder cleanup.
* Implement the basic varbuilders.
* Add the sharded code.
* Proper support for tensor sharding.
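For reference, a minimal sketch of how the sharded loading introduced here is meant to be used, based on the `ShardedSafeTensors::var_builder`, `Shard` and `get_with_hints` calls that appear in the diff below. The checkpoint file name and tensor path are illustrative only and not part of the commit:

use anyhow::Result;
use candle::{DType, Device};
use candle_nn::var_builder::{Shard, ShardedSafeTensors};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Hypothetical checkpoint file, read and parsed with the safetensors crate.
    let data = std::fs::read("model.safetensors")?;
    let st = safetensors::SafeTensors::deserialize(&data)?;
    // The sharded var-builder added by this commit.
    let vb = ShardedSafeTensors::var_builder(vec![st], DType::F16, &device);
    // Rank 0 of a 2-way tensor-parallel group takes the first half along dim 0.
    let shard = Shard { dim: 0, rank: 0, world_size: 2 };
    let weight = vb.pp("lm_head").get_with_hints((), "weight", shard)?;
    println!("sharded weight shape: {:?}", weight.shape());
    Ok(())
}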
@@ -13,7 +13,6 @@ use anyhow::{bail, Error as E, Result};
use clap::Parser;
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use cudarc::driver::safe::CudaDevice;
use cudarc::nccl::safe::{Comm, Id};
@@ -211,7 +210,7 @@ fn main() -> Result<()> {
.map(|h| Ok(h.deserialize()?))
.collect::<Result<Vec<_>>>()?;
let vb = VarBuilder::from_safetensors(tensors, dtype, &device);
let vb = candle_nn::var_builder::ShardedSafeTensors::var_builder(tensors, dtype, &device);
let llama = Llama::load(vb, &cache, &config, comm)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
@@ -1,6 +1,6 @@
use candle::backend::BackendStorage;
use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shape, Tensor, D};
use candle_nn::{rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
use candle_nn::{Embedding, Linear, Module, RmsNorm};
use cudarc::nccl::safe::{Comm, ReduceOp};
use half::f16;
use serde::Deserialize;
@@ -9,6 +9,8 @@ use std::sync::{Arc, Mutex};
use super::MAX_SEQ_LEN;
use candle_nn::var_builder::ShardedVarBuilder as VarBuilder;
struct TensorParallelColumnLinear {
linear: Linear,
}
@@ -82,11 +84,19 @@ impl TensorParallelRowLinear {
}
}
fn shard(dim: usize, rank: usize, world_size: usize) -> candle_nn::var_builder::Shard {
candle_nn::var_builder::Shard {
dim,
rank,
world_size,
}
}
impl TensorParallelColumnLinear {
fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
let rank = comm.rank();
let size = comm.world_size();
let weight = vb.get_sharded("weight", 0, rank, size)?;
let weight = vb.get_with_hints((), "weight", shard(0, rank, size))?;
Ok(Self::new(Linear::new(weight, None)))
}
@@ -95,8 +105,8 @@ impl TensorParallelColumnLinear {
let size = comm.world_size();
let weights: Vec<_> = prefixes
.iter()
.map(|p| vb.pp(p).get_sharded("weight", 0, rank, size).unwrap())
.collect();
.map(|p| vb.pp(p).get_with_hints((), "weight", shard(0, rank, size)))
.collect::<Result<Vec<_>>>()?;
let weight = Tensor::cat(&weights, 0)?;
Ok(Self::new(Linear::new(weight, None)))
}
@@ -106,7 +116,7 @@ impl TensorParallelRowLinear {
fn load(vb: VarBuilder, comm: Rc<Comm>) -> Result<Self> {
let rank = comm.rank();
let size = comm.world_size();
let weight = vb.get_sharded("weight", 1, rank, size)?;
let weight = vb.get_with_hints((), "weight", shard(1, rank, size))?;
Ok(Self::new(Linear::new(weight, None), comm))
}
}
@@ -128,21 +138,6 @@ fn default_rope() -> f32 {
10_000.0
}
impl Config {
pub fn config_7b() -> Self {
Self {
intermediate_size: 11008,
vocab_size: 32000,
num_hidden_layers: 32,
num_attention_heads: 32,
hidden_size: 4096,
num_key_value_heads: 32,
rms_norm_eps: 1e-5,
rope_theta: 10_000.0,
}
}
}
#[derive(Clone)]
pub struct Cache {
#[allow(clippy::type_complexity)]
@@ -352,6 +347,11 @@ struct Block {
mlp: Mlp,
}
fn rms_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<RmsNorm> {
let weight = vb.get_with_hints(size, "weight", shard(0, 0, 1))?;
Ok(RmsNorm::new(weight, eps))
}
impl Block {
fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
Self {
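As an aside, not part of this diff: a tiny sketch of what a `Shard { dim, rank, world_size }` hint selects, written with plain tensor ops so the column/row splits above are easier to follow. The helper name and the toy tensor are illustrative only:

use candle::{Device, Result, Tensor};

// What a Shard { dim, rank, world_size } hint selects, expressed with plain tensor ops.
fn manual_shard(t: &Tensor, dim: usize, rank: usize, world_size: usize) -> Result<Tensor> {
    let size = t.dim(dim)?;
    assert_eq!(size % world_size, 0, "the sharded dimension must split evenly across ranks");
    let block = size / world_size;
    // Rank r keeps the slice [r * block, (r + 1) * block) along `dim`.
    t.narrow(dim, rank * block, block)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let t = Tensor::arange(0f32, 16., &dev)?.reshape((4, 4))?;
    // Rank 1 of 2 along dim 0: the lower half of the rows.
    let part = manual_shard(&t, 0, 1, 2)?;
    println!("{part}");
    Ok(())
}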
@@ -14,8 +14,8 @@ const IMAGE_DIM: usize = 784;
const LABELS: usize = 10;
fn linear_z(in_dim: usize, out_dim: usize, vs: VarBuilder) -> Result<Linear> {
let ws = vs.get_or_init((out_dim, in_dim), "weight", candle_nn::init::ZERO)?;
let bs = vs.get_or_init(out_dim, "bias", candle_nn::init::ZERO)?;
let ws = vs.get_with_hints((out_dim, in_dim), "weight", candle_nn::init::ZERO)?;
let bs = vs.get_with_hints(out_dim, "bias", candle_nn::init::ZERO)?;
Ok(Linear::new(ws, Some(bs)))
}
@@ -368,7 +368,7 @@ impl<'a> Layer<'a> {
self.cnt += 1;
}
fn next(&mut self) -> VarBuilder<'a> {
fn next(&mut self) -> VarBuilder {
let vb = self.vb.pp(&self.cnt.to_string());
self.cnt += 1;
vb
@@ -179,11 +179,11 @@ pub fn batch_norm<C: Into<BatchNormConfig>>(
if config.eps < 0. {
candle::bail!("batch-norm eps cannot be negative {}", config.eps)
}
let running_mean = vb.get_or_init(num_features, "running_mean", crate::Init::Const(0.))?;
let running_var = vb.get_or_init(num_features, "running_var", crate::Init::Const(1.))?;
let running_mean = vb.get_with_hints(num_features, "running_mean", crate::Init::Const(0.))?;
let running_var = vb.get_with_hints(num_features, "running_var", crate::Init::Const(1.))?;
let weight_and_bias = if config.affine {
let weight = vb.get_or_init(num_features, "weight", crate::Init::Const(1.))?;
let bias = vb.get_or_init(num_features, "bias", crate::Init::Const(0.))?;
let weight = vb.get_with_hints(num_features, "weight", crate::Init::Const(1.))?;
let bias = vb.get_with_hints(num_features, "bias", crate::Init::Const(0.))?;
Some((weight, bias))
} else {
None
@@ -124,7 +124,7 @@ pub fn conv1d(
vs: crate::VarBuilder,
) -> Result<Conv1d> {
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
let ws = vs.get_or_init(
let ws = vs.get_with_hints(
(out_channels, in_channels / cfg.groups, kernel_size),
"weight",
init_ws,
@@ -134,7 +134,7 @@ pub fn conv1d(
lo: -bound,
up: bound,
};
let bs = vs.get_or_init(out_channels, "bias", init_bs)?;
let bs = vs.get_with_hints(out_channels, "bias", init_bs)?;
Ok(Conv1d::new(ws, Some(bs), cfg))
}
@@ -146,7 +146,7 @@ pub fn conv2d(
vs: crate::VarBuilder,
) -> Result<Conv2d> {
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
let ws = vs.get_or_init(
let ws = vs.get_with_hints(
(
out_channels,
in_channels / cfg.groups,
@@ -161,7 +161,7 @@ pub fn conv2d(
lo: -bound,
up: bound,
};
let bs = vs.get_or_init(out_channels, "bias", init_bs)?;
let bs = vs.get_with_hints(out_channels, "bias", init_bs)?;
Ok(Conv2d::new(ws, Some(bs), cfg))
}
@@ -173,7 +173,7 @@ pub fn conv2d_no_bias(
vs: crate::VarBuilder,
) -> Result<Conv2d> {
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
let ws = vs.get_or_init(
let ws = vs.get_with_hints(
(
out_channels,
in_channels / cfg.groups,
@@ -32,7 +32,7 @@ impl crate::Module for Embedding {
}
pub fn embedding(in_size: usize, out_size: usize, vb: crate::VarBuilder) -> Result<Embedding> {
let embeddings = vb.get_or_init(
let embeddings = vb.get_with_hints(
(in_size, out_size),
"weight",
crate::Init::Randn {
@@ -79,7 +79,7 @@ pub fn group_norm(
eps: f64,
vb: crate::VarBuilder,
) -> Result<GroupNorm> {
let weight = vb.get_or_init(num_channels, "weight", crate::Init::Const(1.))?;
let bias = vb.get_or_init(num_channels, "bias", crate::Init::Const(0.))?;
let weight = vb.get_with_hints(num_channels, "weight", crate::Init::Const(1.))?;
let bias = vb.get_with_hints(num_channels, "bias", crate::Init::Const(0.))?;
GroupNorm::new(weight, bias, num_channels, num_groups, eps)
}
@@ -139,3 +139,9 @@ impl Init {
}
}
}
impl Default for Init {
fn default() -> Self {
Self::Const(0.)
}
}
@@ -128,9 +128,9 @@ pub fn layer_norm<C: Into<LayerNormConfig>>(
vb: crate::VarBuilder,
) -> Result<LayerNorm> {
let config = config.into();
let weight = vb.get_or_init(size, "weight", crate::Init::Const(1.))?;
let weight = vb.get_with_hints(size, "weight", crate::Init::Const(1.))?;
let bias = if config.affine {
Some(vb.get_or_init(size, "bias", crate::Init::Const(0.))?)
Some(vb.get_with_hints(size, "bias", crate::Init::Const(0.))?)
} else {
None
};
@@ -50,18 +50,18 @@ impl super::Module for Linear {
/// This uses some default names for weight and biases, namely `"weight"` and `"bias"`.
pub fn linear(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
let ws = vs.get_or_init((out_dim, in_dim), "weight", init_ws)?;
let ws = vs.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
let bound = 1. / (in_dim as f64).sqrt();
let init_bs = crate::Init::Uniform {
lo: -bound,
up: bound,
};
let bs = vs.get_or_init(out_dim, "bias", init_bs)?;
let bs = vs.get_with_hints(out_dim, "bias", init_bs)?;
Ok(Linear::new(ws, Some(bs)))
}
pub fn linear_no_bias(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
let ws = vs.get_or_init((out_dim, in_dim), "weight", init_ws)?;
let ws = vs.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
Ok(Linear::new(ws, None))
}
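The `get_or_init` to `get_with_hints` renames above keep the training path working: with a `VarMap`-backed builder the hint is an `Init` and missing variables are created on the fly. A small usage sketch, not part of the diff; the layer name and shapes are arbitrary:

use candle::{DType, Device, Result, Tensor};
use candle_nn::{Module, VarBuilder, VarMap};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // VarMap-backed builder: get_with_hints initializes and registers missing variables.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    let layer = candle_nn::linear(4, 2, vb.pp("fc1"))?;
    let xs = Tensor::zeros((1, 4), DType::F32, &dev)?;
    let ys = layer.forward(&xs)?;
    println!("output shape: {:?}", ys.shape());
    Ok(())
}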
@@ -2,139 +2,105 @@ use crate::VarMap;
use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor};
use safetensors::{slice::IndexOp, tensor::SafeTensors};
use std::collections::HashMap;
use std::sync::Arc;
use std::rc::Rc;
// TODO: Maybe we would want the storage to be generic, e.g. with Box<dyn> to avoid too many
// generics.
enum Tensors<'a> {
SafeTensorWithRouting {
routing: HashMap<String, usize>,
safetensors: Vec<SafeTensors<'a>>,
},
Npz(candle::npy::NpzTensors),
TensorMap(HashMap<String, Tensor>),
Zeros,
VarMap(VarMap),
/// A structure used to retrieve variables, these variables can either come from storage or be
/// generated via some form of initialization.
///
/// The way to retrieve variables is defined in the backend embedded in the `VarBuilder`.
pub struct VarBuilderArgs<'a, B: Backend> {
data: Rc<TensorData<B>>,
path: Vec<String>,
_phantom: std::marker::PhantomData<&'a B>,
}
struct TensorData<'a> {
tensors: Tensors<'a>,
impl<'a, B: Backend> Clone for VarBuilderArgs<'a, B> {
fn clone(&self) -> Self {
Self {
data: self.data.clone(),
path: self.path.clone(),
_phantom: self._phantom,
}
}
}
/// A simple `VarBuilder`, this is less generic than `VarBuilderArgs` but should cover most common
/// use cases.
pub type VarBuilder<'a> = VarBuilderArgs<'a, Box<dyn SimpleBackend + 'a>>;
struct TensorData<B: Backend> {
backend: B,
pub dtype: DType,
pub device: Device,
}
impl<'a> TensorData<'a> {
fn from_safetensors(safetensors: Vec<SafeTensors<'a>>, dtype: DType, device: &Device) -> Self {
let mut routing = HashMap::new();
for (index, sf) in safetensors.iter().enumerate() {
for k in sf.names() {
routing.insert(k.to_string(), index);
}
}
let tensors = Tensors::SafeTensorWithRouting {
routing,
safetensors,
/// A trait that defines how tensor data is retrieved.
///
/// Typically this would use disk storage in some specific format, or random initialization.
/// Note that there is a speciliazed version of this trait (`SimpleBackend`) that can be used most
/// of the time. The main restriction is that it doesn't allow for specific args (besides
/// initialization hints).
pub trait Backend {
type Hints: Default;
/// Retrieve a tensor with some target shape.
fn get(
&self,
s: Shape,
name: &str,
h: Self::Hints,
dtype: DType,
dev: &Device,
) -> Result<Tensor>;
}
pub trait SimpleBackend {
/// Retrieve a tensor based on a target name and shape.
fn get(
&self,
s: Shape,
name: &str,
h: crate::Init,
dtype: DType,
dev: &Device,
) -> Result<Tensor>;
}
impl<'a> Backend for Box<dyn SimpleBackend + 'a> {
type Hints = crate::Init;
fn get(
&self,
s: Shape,
name: &str,
h: Self::Hints,
dtype: DType,
dev: &Device,
) -> Result<Tensor> {
self.as_ref().get(s, name, h, dtype, dev)
}
}
impl<'a, B: Backend> VarBuilderArgs<'a, B> {
pub fn new_with_args(backend: B, dtype: DType, dev: &Device) -> Self {
let data = TensorData {
backend,
dtype,
device: dev.clone(),
};
Self {
tensors,
device: device.clone(),
dtype,
}
}
fn zeros(dtype: DType, device: &Device) -> Self {
Self {
tensors: Tensors::Zeros,
device: device.clone(),
dtype,
}
}
fn from_tensors(tensors: HashMap<String, Tensor>, dtype: DType, device: &Device) -> Self {
Self {
tensors: Tensors::TensorMap(tensors),
device: device.clone(),
dtype,
}
}
fn from_npz<P: AsRef<std::path::Path>>(file: P, dtype: DType, device: &Device) -> Result<Self> {
let npz = candle::npy::NpzTensors::new(file)?;
Ok(Self {
tensors: Tensors::Npz(npz),
device: device.clone(),
dtype,
})
}
fn from_varmap(varmap: &VarMap, dtype: DType, device: &Device) -> Self {
Self {
tensors: Tensors::VarMap(varmap.clone()),
device: device.clone(),
dtype,
}
}
}
#[derive(Clone)]
pub struct VarBuilder<'a> {
data: Arc<TensorData<'a>>,
path: Vec<String>,
}
impl<'a> VarBuilder<'a> {
/// Create a `VarBuilder` accessing data frome the safetensors storage. The initial path is
/// set to the root path and sub-paths can be created via the `push_prefix` method.
pub fn from_safetensors(st: Vec<SafeTensors<'a>>, dtype: DType, device: &Device) -> Self {
let data = TensorData::from_safetensors(st, dtype, device);
Self {
data: Arc::new(data),
data: Rc::new(data),
path: vec![],
_phantom: std::marker::PhantomData,
}
}
pub fn zeros(dtype: DType, device: &Device) -> Self {
let data = TensorData::zeros(dtype, device);
Self {
data: Arc::new(data),
path: vec![],
}
}
pub fn from_tensors(ts: HashMap<String, Tensor>, dtype: DType, device: &Device) -> Self {
let data = TensorData::from_tensors(ts, dtype, device);
Self {
data: Arc::new(data),
path: vec![],
}
}
pub fn from_varmap(varmap: &VarMap, dtype: DType, device: &Device) -> Self {
let data = TensorData::from_varmap(varmap, dtype, device);
Self {
data: Arc::new(data),
path: vec![],
}
}
pub fn from_npz<P: AsRef<std::path::Path>>(
file: P,
dtype: DType,
device: &Device,
) -> Result<Self> {
let data = TensorData::from_npz(file, dtype, device)?;
Ok(Self {
data: Arc::new(data),
path: vec![],
})
}
pub fn push_prefix<S: ToString>(&self, s: S) -> Self {
let mut path = self.path.clone();
path.push(s.to_string());
Self {
data: self.data.clone(),
path,
_phantom: std::marker::PhantomData,
}
}
@@ -150,130 +116,108 @@ impl<'a> VarBuilder<'a> {
pub fn dtype(&self) -> DType {
self.data.dtype
}
}
impl<'a> VarBuilder<'a> {
/// Get part of a tensor, typically used to do Tensor Parallelism sharding.
///
/// If the tensor is of size (1024, 1024).
///
/// `dim` corresponds to the dimension to slice into
/// `rank` is the rank of the current process
/// `world_size` is the total number of ranks in the process group
///
/// `get_sharded("tensor", 0, 0, 2)` means `tensor.i((..512))`
/// `get_sharded("tensor", 0, 1, 2)` means `tensor.i((512..))`
/// `get_sharded("tensor", 1, 0, 2)` means `tensor.i((.., ..512))`
pub fn get_sharded(
&self,
tensor_name: &str,
dim: usize,
rank: usize,
world_size: usize,
) -> Result<Tensor> {
let data = self.data.as_ref();
let path = self.path(tensor_name);
let tensor = match &self.data.tensors {
Tensors::SafeTensorWithRouting {
routing,
safetensors,
} => {
let index = routing.get(&path).ok_or_else(|| {
Error::CannotFindTensor {
path: path.to_string(),
}
.bt()
})?;
let view = safetensors[*index].tensor(&path)?;
let dtype = view.dtype();
let mut shape = view.shape().to_vec();
let size = shape[dim];
if size % world_size != 0 {
return Err(Error::ShapeMismatchSplit {
shape: shape.into(),
dim,
n_parts: world_size,
});
}
let block_size = size / world_size;
let start = rank * block_size;
let stop = (rank + 1) * block_size;
// Everything is expressed in tensor dimension
// bytes offsets is handled automatically for safetensors.
let iterator = if dim == 0 {
view.slice(start..stop).map_err(|_| Error::Msg(format!("Cannot slice tensor {tensor_name} ({shape:?} along dim {dim} with {start}..{stop}")))?
} else if dim == 1 {
view.slice((.., start..stop)).map_err(|_| Error::Msg(format!("Cannot slice tensor {tensor_name} ({shape:?} along dim {dim} with {start}..{stop}")))?
} else {
candle::bail!("Get sharded on dimensions != 0 or 1")
};
shape[dim] = block_size;
let dtype: DType = dtype.try_into()?;
let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
Tensor::from_raw_buffer(&raw, dtype, &shape, &data.device)?
}
_ => candle::bail!("get_sharded is only available for safetensors"),
};
Ok(tensor)
fn path(&self, tensor_name: &str) -> String {
if self.path.is_empty() {
tensor_name.to_string()
} else {
[&self.path.join("."), tensor_name].join(".")
}
}
/// Retrieve the tensor associated with the given name at the current path.
pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> Result<Tensor> {
let data = self.data.as_ref();
let s: Shape = s.into();
let path = self.path(tensor_name);
let tensor = match &self.data.tensors {
Tensors::Zeros => Tensor::zeros(&s, data.dtype, &data.device)?.contiguous()?,
Tensors::TensorMap(ts) => ts
.get(&path)
.ok_or_else(|| {
Error::CannotFindTensor {
path: path.to_string(),
}
.bt()
})?
.clone(),
Tensors::VarMap(varmap) => {
let data = varmap.data().lock().unwrap();
data.get(&path)
.ok_or_else(|| {
Error::CannotFindTensor {
path: path.to_string(),
}
.bt()
})?
.as_tensor()
.clone()
}
Tensors::Npz(npz) => npz.get(&path)?.ok_or_else(|| {
pub fn get_with_hints<S: Into<Shape>>(
&self,
s: S,
name: &str,
hints: B::Hints,
) -> Result<Tensor> {
let path = self.path(name);
self.data
.backend
.get(s.into(), &path, hints, self.data.dtype, &self.data.device)
}
/// Retrieve the tensor associated with the given name at the current path.
pub fn get<S: Into<Shape>>(&self, s: S, name: &str) -> Result<Tensor> {
self.get_with_hints(s, name, Default::default())
}
}
struct Zeros;
impl SimpleBackend for Zeros {
fn get(&self, s: Shape, _: &str, _: crate::Init, dtype: DType, dev: &Device) -> Result<Tensor> {
Tensor::zeros(s, dtype, dev)
}
}
impl SimpleBackend for HashMap<String, Tensor> {
fn get(
&self,
s: Shape,
name: &str,
_: crate::Init,
dtype: DType,
dev: &Device,
) -> Result<Tensor> {
let tensor = self
.get(name)
.ok_or_else(|| {
Error::CannotFindTensor {
path: path.to_string(),
path: name.to_string(),
}
.bt()
})?,
Tensors::SafeTensorWithRouting {
routing,
safetensors,
} => {
let index = routing.get(&path).ok_or_else(|| {
Error::CannotFindTensor {
path: path.to_string(),
}
.bt()
})?;
safetensors[*index]
.tensor(&path)?
.load(&data.device)?
.to_dtype(data.dtype)?
})?
.clone();
if tensor.shape() != &s {
Err(candle::Error::UnexpectedShape {
msg: format!("shape mismatch for {name}"),
expected: s,
got: tensor.shape().clone(),
}
};
.bt())?
}
tensor.to_device(dev)?.to_dtype(dtype)
}
}
impl SimpleBackend for VarMap {
fn get(
&self,
s: Shape,
name: &str,
h: crate::Init,
dtype: DType,
dev: &Device,
) -> Result<Tensor> {
VarMap::get(self, s, name, h, dtype, dev)
}
}
struct SafeTensorWithRouting<'a> {
routing: HashMap<String, usize>,
safetensors: Vec<SafeTensors<'a>>,
}
impl<'a> SimpleBackend for SafeTensorWithRouting<'a> {
fn get(
&self,
s: Shape,
path: &str,
_: crate::Init,
dtype: DType,
dev: &Device,
) -> Result<Tensor> {
let index = self.routing.get(path).ok_or_else(|| {
Error::CannotFindTensor {
path: path.to_string(),
}
.bt()
})?;
let tensor = self.safetensors[*index]
.tensor(path)?
.load(dev)?
.to_dtype(dtype)?;
if tensor.shape() != &s {
Err(candle::Error::UnexpectedShape {
msg: format!("shape mismatch for {path}"),
@@ -284,32 +228,201 @@ impl<'a> VarBuilder<'a> {
}
Ok(tensor)
}
}
/// Retrieve the tensor associated with the given name at the current path or initialize a new
/// tensor if it's missing.
///
/// Tensor initialization is only available if the `VarBuilder` is backed by a `VarMap`.
pub fn get_or_init<S: Into<Shape>>(
impl SimpleBackend for candle::npy::NpzTensors {
fn get(
&self,
s: S,
tensor_name: &str,
init: crate::Init,
s: Shape,
path: &str,
_: crate::Init,
dtype: DType,
dev: &Device,
) -> Result<Tensor> {
let data = self.data.as_ref();
match &self.data.tensors {
Tensors::VarMap(varmap) => {
let path = self.path(tensor_name);
varmap.get(s, &path, init, data.dtype, &data.device)
let tensor = match self.get(path)? {
None => Err(Error::CannotFindTensor {
path: path.to_string(),
}
_ => self.get(s, tensor_name),
.bt())?,
Some(tensor) => tensor,
};
let tensor = tensor.to_device(dev)?.to_dtype(dtype)?;
if tensor.shape() != &s {
Err(candle::Error::UnexpectedShape {
msg: format!("shape mismatch for {path}"),
expected: s,
got: tensor.shape().clone(),
}
.bt())?
}
Ok(tensor)
}
}
impl<'a> VarBuilder<'a> {
fn new(backend: Box<dyn SimpleBackend + 'a>, dtype: DType, device: Device) -> Self {
let data = TensorData {
backend,
dtype,
device,
};
Self {
data: Rc::new(data),
path: vec![],
_phantom: std::marker::PhantomData,
}
}
fn path(&self, tensor_name: &str) -> String {
if self.path.is_empty() {
tensor_name.to_string()
} else {
[&self.path.join("."), tensor_name].join(".")
pub fn zeros(dtype: DType, dev: &Device) -> Self {
Self::new(Box::new(Zeros), dtype, dev.clone())
}
pub fn from_tensors(ts: HashMap<String, Tensor>, dtype: DType, dev: &Device) -> Self {
Self::new(Box::new(ts), dtype, dev.clone())
}
pub fn from_varmap(varmap: &VarMap, dtype: DType, dev: &Device) -> Self {
Self::new(Box::new(varmap.clone()), dtype, dev.clone())
}
pub fn from_safetensors(safetensors: Vec<SafeTensors<'a>>, dtype: DType, dev: &Device) -> Self {
let mut routing = HashMap::new();
for (index, sf) in safetensors.iter().enumerate() {
for k in sf.names() {
routing.insert(k.to_string(), index);
}
}
let tensors = SafeTensorWithRouting {
routing,
safetensors,
};
Self::new(Box::new(tensors), dtype, dev.clone())
}
pub fn from_npz<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
let npz = candle::npy::NpzTensors::new(p)?;
Ok(Self::new(Box::new(npz), dtype, dev.clone()))
}
}
pub struct ShardedSafeTensors<'a>(SafeTensorWithRouting<'a>);
pub type ShardedVarBuilder<'a> = VarBuilderArgs<'a, ShardedSafeTensors<'a>>;
impl<'a> ShardedSafeTensors<'a> {
pub fn var_builder(
safetensors: Vec<SafeTensors<'a>>,
dtype: DType,
dev: &Device,
) -> ShardedVarBuilder<'a> {
let mut routing = HashMap::new();
for (index, sf) in safetensors.iter().enumerate() {
for k in sf.names() {
routing.insert(k.to_string(), index);
}
}
let tensors = SafeTensorWithRouting {
routing,
safetensors,
};
let backend = ShardedSafeTensors(tensors);
VarBuilderArgs::new_with_args(backend, dtype, dev)
}
}
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct Shard {
pub dim: usize,
pub rank: usize,
pub world_size: usize,
}
impl Default for Shard {
fn default() -> Self {
Self {
dim: 0,
rank: 0,
world_size: 1,
}
}
}
/// Get part of a tensor, typically used to do Tensor Parallelism sharding.
///
/// If the tensor is of size (1024, 1024).
///
/// `dim` corresponds to the dimension to slice into
/// `rank` is the rank of the current process
/// `world_size` is the total number of ranks in the process group
///
/// `get_sharded("tensor", 0, 0, 2)` means `tensor.i((..512))`
/// `get_sharded("tensor", 0, 1, 2)` means `tensor.i((512..))`
/// `get_sharded("tensor", 1, 0, 2)` means `tensor.i((.., ..512))`
impl<'a> Backend for ShardedSafeTensors<'a> {
type Hints = Shard;
fn get(
&self,
_target_shape: Shape, // The size is not checked for ShardedTensors
path: &str,
h: Self::Hints,
dtype: DType,
dev: &Device,
) -> Result<Tensor> {
let Shard {
dim,
rank,
world_size,
} = h;
let SafeTensorWithRouting {
routing,
safetensors,
} = &self.0;
let index = routing.get(path).ok_or_else(|| {
Error::CannotFindTensor {
path: path.to_string(),
}
.bt()
})?;
let view = safetensors[*index].tensor(path)?;
let view_dtype = view.dtype();
let mut shape = view.shape().to_vec();
let size = shape[dim];
if size % world_size != 0 {
return Err(Error::ShapeMismatchSplit {
shape: shape.into(),
dim,
n_parts: world_size,
});
}
let block_size = size / world_size;
let start = rank * block_size;
let stop = (rank + 1) * block_size;
// Everything is expressed in tensor dimension
// bytes offsets is handled automatically for safetensors.
let iterator = if dim == 0 {
view.slice(start..stop).map_err(|_| {
Error::Msg(format!(
"Cannot slice tensor {path} ({shape:?} along dim {dim} with {start}..{stop}"
))
})?
} else if dim == 1 {
view.slice((.., start..stop)).map_err(|_| {
Error::Msg(format!(
"Cannot slice tensor {path} ({shape:?} along dim {dim} with {start}..{stop}"
))
})?
} else {
candle::bail!("Get sharded on dimensions != 0 or 1")
};
shape[dim] = block_size;
let view_dtype: DType = view_dtype.try_into()?;
let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
Tensor::from_raw_buffer(&raw, view_dtype, &shape, dev)?.to_dtype(dtype)
}
}
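Since `VarBuilder` is now an alias for `VarBuilderArgs` over a boxed `SimpleBackend`, a custom tensor source can be wired in through `new_with_args`. A minimal sketch, assuming `SimpleBackend` and `VarBuilderArgs::new_with_args` are public as shown above; the `Ones` backend and the path `layer1` are purely illustrative:

use candle::{DType, Device, Result, Shape, Tensor};
use candle_nn::var_builder::{SimpleBackend, VarBuilderArgs};

// A toy backend that ignores storage and returns a ones tensor for every variable.
struct Ones;

impl SimpleBackend for Ones {
    fn get(&self, s: Shape, _name: &str, _h: candle_nn::Init, dtype: DType, dev: &Device) -> Result<Tensor> {
        Tensor::ones(s, dtype, dev)
    }
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Box the backend so it matches the `Box<dyn SimpleBackend>` used by the `VarBuilder` alias.
    let backend: Box<dyn SimpleBackend> = Box::new(Ones);
    let vb = VarBuilderArgs::new_with_args(backend, DType::F32, &dev);
    let w = vb.pp("layer1").get((2, 3), "weight")?;
    println!("{w}");
    Ok(())
}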