mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Fixing slice errors + comments.
This commit is contained in:
@ -135,6 +135,17 @@ impl<'a> VarBuilder<'a> {
|
||||
}
|
||||
|
||||
impl<'a> VarBuilder<'a> {
|
||||
/// Get part of a tensor, typically used to do Tensor Parallelism sharding.
|
||||
///
|
||||
/// The examples below assume a tensor named "tensor" of size (1024, 1024).
|
||||
///
|
||||
/// `dim` is the dimension to slice along (only dimensions 0 and 1 are supported)
|
||||
/// `rank` is the rank of the current process
|
||||
/// `world_size` is the total number of ranks in the process group; the size of the tensor along `dim` must be divisible by it
|
||||
///
|
||||
/// `get_sharded("tensor", 0, 0, 2)` means `tensor.i((..512))`
|
||||
/// `get_sharded("tensor", 0, 1, 2)` means `tensor.i((512..))`
|
||||
/// `get_sharded("tensor", 1, 0, 2)` means `tensor.i((.., ..512))`
|
||||
pub fn get_sharded(
|
||||
&self,
|
||||
tensor_name: &str,
|
||||
@ -164,16 +175,24 @@ impl<'a> VarBuilder<'a> {
|
||||
let dtype = view.dtype();
|
||||
let mut shape = view.shape().to_vec();
|
||||
let size = shape[dim];
|
||||
|
||||
if size % world_size != 0 {
|
||||
return Err(Error::ShapeMismatchSplit {
|
||||
shape: shape.into(),
|
||||
dim,
|
||||
n_parts: world_size,
|
||||
});
|
||||
}
|
||||
let block_size = size / world_size;
|
||||
let start = rank * block_size;
|
||||
let stop = (rank + 1) * block_size;
|
||||
|
||||
let iterator = if dim == 0 {
|
||||
view.slice(start..stop).unwrap()
|
||||
view.slice(start..stop).map_err(|_| Error::Msg(format!("Cannot slice tensor {tensor_name} ({shape:?} along dim {dim} with {start}..{stop}")))?
|
||||
} else if dim == 1 {
|
||||
view.slice((.., start..stop)).unwrap()
|
||||
view.slice((.., start..stop)).map_err(|_| Error::Msg(format!("Cannot slice tensor {tensor_name} ({shape:?} along dim {dim} with {start}..{stop}")))?
|
||||
} else {
|
||||
unimplemented!("Get sharded on dimensions != 0 or 1");
|
||||
candle::bail!("Get sharded on dimensions != 0 or 1")
|
||||
};
|
||||
|
||||
shape[dim] = block_size;
|
||||
|
Reference in New Issue
Block a user