From 952eca6b540078b1f30b58d9eb930f8e32d903cb Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Thu, 27 Jul 2023 16:59:32 +0200
Subject: [PATCH] Fixing slice errors + comments.

---
 candle-core/src/error.rs     |  7 +++++++
 candle-nn/src/var_builder.rs | 25 ++++++++++++++++++++++---
 2 files changed, 29 insertions(+), 3 deletions(-)

diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index f9e69122..30d06239 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -79,6 +79,13 @@ pub enum Error {
         nth_shape: Shape,
     },
 
+    #[error("Cannot divide tensor of shape {shape:?} equally along dim {dim} into {n_parts} parts")]
+    ShapeMismatchSplit {
+        shape: Shape,
+        dim: usize,
+        n_parts: usize,
+    },
+
     #[error("{op} can only be performed on a single dimension")]
     OnlySingleDimension { op: &'static str, dims: Vec<usize> },
 
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index 1466f6d0..3133f210 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -135,6 +135,17 @@ impl<'a> VarBuilder<'a> {
 }
 
 impl<'a> VarBuilder<'a> {
+    /// Get part of a tensor, typically used for Tensor Parallelism sharding.
+    ///
+    /// Suppose the tensor has shape (1024, 1024).
+    ///
+    /// `dim` is the dimension to slice along,
+    /// `rank` is the rank of the current process, and
+    /// `world_size` is the total number of ranks in the process group.
+    ///
+    /// `get_sharded("tensor", 0, 0, 2)` means `tensor.i((..512))`
+    /// `get_sharded("tensor", 0, 1, 2)` means `tensor.i((512..))`
+    /// `get_sharded("tensor", 1, 0, 2)` means `tensor.i((.., ..512))`
     pub fn get_sharded(
         &self,
         tensor_name: &str,
@@ -164,16 +175,24 @@ impl<'a> VarBuilder<'a> {
         let dtype = view.dtype();
         let mut shape = view.shape().to_vec();
         let size = shape[dim];
+
+        if size % world_size != 0 {
+            return Err(Error::ShapeMismatchSplit {
+                shape: shape.into(),
+                dim,
+                n_parts: world_size,
+            });
+        }
         let block_size = size / world_size;
         let start = rank * block_size;
         let stop = (rank + 1) * block_size;
         let iterator = if dim == 0 {
-            view.slice(start..stop).unwrap()
+            view.slice(start..stop).map_err(|_| Error::Msg(format!("Cannot slice tensor {tensor_name} ({shape:?}) along dim {dim} with {start}..{stop}")))?
         } else if dim == 1 {
-            view.slice((.., start..stop)).unwrap()
+            view.slice((.., start..stop)).map_err(|_| Error::Msg(format!("Cannot slice tensor {tensor_name} ({shape:?}) along dim {dim} with {start}..{stop}")))?
         } else {
-            unimplemented!("Get sharded on dimensions != 0 or 1");
+            candle::bail!("Get sharded on dimensions != 0 or 1")
         };
         shape[dim] = block_size;
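
Note for reviewers: below is a minimal standalone sketch of the shard-range
arithmetic that get_sharded now performs, including the new divisibility
check. shard_range is a hypothetical helper written only to illustrate the
logic; it is not part of the candle API.

// Hypothetical helper, for illustration only (not candle API).
// Computes the [start, stop) range owned by `rank` along a dimension of
// length `size`, split evenly across `world_size` ranks.
fn shard_range(size: usize, rank: usize, world_size: usize) -> Result<(usize, usize), String> {
    // Same check the patch adds before slicing: uneven splits are an error.
    if size % world_size != 0 {
        return Err(format!(
            "Cannot divide dim of size {size} equally into {world_size} parts"
        ));
    }
    let block_size = size / world_size;
    Ok((rank * block_size, (rank + 1) * block_size))
}

fn main() {
    // Matches the doc-comment examples for a (1024, 1024) tensor:
    assert_eq!(shard_range(1024, 0, 2), Ok((0, 512)));    // rank 0 -> ..512
    assert_eq!(shard_range(1024, 1, 2), Ok((512, 1024))); // rank 1 -> 512..
    // An uneven split now surfaces as an error instead of a silent truncation:
    assert!(shard_range(1000, 0, 3).is_err());
}

With world_size = 1 the range degenerates to the full dimension, so
non-distributed loading goes through the same code path unchanged.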