Fixing slice errors + comments.

This commit is contained in:
Nicolas Patry
2023-07-27 16:59:32 +02:00
parent 25a2086e8f
commit 952eca6b54
2 changed files with 29 additions and 3 deletions

View File

@ -135,6 +135,17 @@ impl<'a> VarBuilder<'a> {
}
impl<'a> VarBuilder<'a> {
/// Get part of a tensor, typically used to do Tensor Parallelism sharding.
///
/// If the tensor is of size (1024, 1024).
///
/// `dim` corresponds to the dimension to slice into
/// `rank` is the rank of the current process
/// `world_size` is the total number of ranks in the process group
///
/// `get_sharded("tensor", 0, 0, 2)` means `tensor.i((..512))`
/// `get_sharded("tensor", 0, 1, 2)` means `tensor.i((512..))`
/// `get_sharded("tensor", 1, 0, 2)` means `tensor.i((.., ..512))`
pub fn get_sharded(
&self,
tensor_name: &str,
@ -164,16 +175,24 @@ impl<'a> VarBuilder<'a> {
let dtype = view.dtype();
let mut shape = view.shape().to_vec();
let size = shape[dim];
if size % world_size != 0 {
return Err(Error::ShapeMismatchSplit {
shape: shape.into(),
dim,
n_parts: world_size,
});
}
let block_size = size / world_size;
let start = rank * block_size;
let stop = (rank + 1) * block_size;
let iterator = if dim == 0 {
view.slice(start..stop).unwrap()
view.slice(start..stop).map_err(|_| Error::Msg(format!("Cannot slice tensor {tensor_name} ({shape:?} along dim {dim} with {start}..{stop}")))?
} else if dim == 1 {
view.slice((.., start..stop)).unwrap()
view.slice((.., start..stop)).map_err(|_| Error::Msg(format!("Cannot slice tensor {tensor_name} ({shape:?} along dim {dim} with {start}..{stop}")))?
} else {
unimplemented!("Get sharded on dimensions != 0 or 1");
candle::bail!("Get sharded on dimensions != 0 or 1")
};
shape[dim] = block_size;