Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 02:58:50 +00:00.
Fixing slice errors + comments.
```diff
@@ -79,6 +79,13 @@ pub enum Error {
         nth_shape: Shape,
     },
 
+    #[error("Cannot divide tensor of shape {shape:?} equally along dim {dim} into {n_parts}")]
+    ShapeMismatchSplit {
+        shape: Shape,
+        dim: usize,
+        n_parts: usize,
+    },
+
     #[error("{op} can only be performed on a single dimension")]
     OnlySingleDimension { op: &'static str, dims: Vec<usize> },
 
```
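The new `ShapeMismatchSplit` variant carries the full context (shape, dim, number of parts), and its `Display` impl comes from the `#[error(...)]` attribute via thiserror's derive. A minimal sketch of the rendered message, assuming the crate is imported as `candle` (as the `candle::bail!` call further down suggests) and that `Shape` converts from a tuple and debug-prints as a dimension list:

```rust
use candle::{Error, Shape};

fn main() {
    // Construct the variant by hand; in the library it is returned by
    // get_sharded when the dimension does not divide evenly (next hunks).
    let err = Error::ShapeMismatchSplit {
        shape: Shape::from((1023, 1024)), // assumed From<(usize, usize)> impl
        dim: 0,
        n_parts: 2,
    };
    // thiserror turns the #[error(...)] attribute into Display, so this
    // should print something like:
    // Cannot divide tensor of shape [1023, 1024] equally along dim 0 into 2
    println!("{err}");
}
```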
```diff
@@ -135,6 +135,17 @@ impl<'a> VarBuilder<'a> {
 }
 
 impl<'a> VarBuilder<'a> {
+    /// Get part of a tensor, typically used to do Tensor Parallelism sharding.
+    ///
+    /// If the tensor is of size (1024, 1024).
+    ///
+    /// `dim` corresponds to the dimension to slice into
+    /// `rank` is the rank of the current process
+    /// `world_size` is the total number of ranks in the process group
+    ///
+    /// `get_sharded("tensor", 0, 0, 2)` means `tensor.i((..512))`
+    /// `get_sharded("tensor", 0, 1, 2)` means `tensor.i((512..))`
+    /// `get_sharded("tensor", 1, 0, 2)` means `tensor.i((.., ..512))`
     pub fn get_sharded(
         &self,
         tensor_name: &str,
```
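The examples in the new doc comment come down to simple block arithmetic: each rank owns one contiguous `size / world_size` block along `dim`. A standalone sketch of that arithmetic (the `shard_range` helper is illustrative, not part of candle's API):

```rust
// Illustrative helper, not candle API: the [start, stop) block that
// get_sharded slices out for a given rank.
fn shard_range(size: usize, rank: usize, world_size: usize) -> (usize, usize) {
    let block_size = size / world_size; // assumes size % world_size == 0
    (rank * block_size, (rank + 1) * block_size)
}

fn main() {
    // A (1024, 1024) tensor split across 2 ranks along dim 0:
    assert_eq!(shard_range(1024, 0, 2), (0, 512));    // tensor.i((..512))
    assert_eq!(shard_range(1024, 1, 2), (512, 1024)); // tensor.i((512..))
}
```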
```diff
@@ -164,16 +175,24 @@ impl<'a> VarBuilder<'a> {
         let dtype = view.dtype();
         let mut shape = view.shape().to_vec();
         let size = shape[dim];
+
+        if size % world_size != 0 {
+            return Err(Error::ShapeMismatchSplit {
+                shape: shape.into(),
+                dim,
+                n_parts: world_size,
+            });
+        }
         let block_size = size / world_size;
         let start = rank * block_size;
         let stop = (rank + 1) * block_size;
 
         let iterator = if dim == 0 {
-            view.slice(start..stop).unwrap()
+            view.slice(start..stop).map_err(|_| Error::Msg(format!("Cannot slice tensor {tensor_name} ({shape:?}) along dim {dim} with {start}..{stop}")))?
         } else if dim == 1 {
-            view.slice((.., start..stop)).unwrap()
+            view.slice((.., start..stop)).map_err(|_| Error::Msg(format!("Cannot slice tensor {tensor_name} ({shape:?}) along dim {dim} with {start}..{stop}")))?
         } else {
-            unimplemented!("Get sharded on dimensions != 0 or 1");
+            candle::bail!("Get sharded on dimensions != 0 or 1")
         };
 
         shape[dim] = block_size;
```
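This hunk swaps panicking calls for recoverable errors: the new guard returns `ShapeMismatchSplit` when the dimension does not divide evenly, slice failures are converted to `Error::Msg` with `map_err(..)?` instead of `unwrap()`, and the unsupported-dimension arm uses `candle::bail!` rather than `unimplemented!`. A minimal standalone sketch of the guard-then-split pattern (names here are hypothetical, not candle code):

```rust
// Hypothetical standalone version of the divisibility guard.
fn split_evenly(size: usize, world_size: usize) -> Result<usize, String> {
    if size % world_size != 0 {
        // Fail early with context instead of panicking later on a bad slice.
        return Err(format!(
            "Cannot divide {size} equally into {world_size} parts"
        ));
    }
    Ok(size / world_size)
}

fn main() {
    assert_eq!(split_evenly(1024, 2), Ok(512)); // per-rank block size
    assert!(split_evenly(1023, 2).is_err());    // uneven split is rejected
}
```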