TP sharding v2

Nicolas Patry
2023-07-21 15:10:51 +00:00
parent 209f06d7c3
commit 1735e4831e
9 changed files with 833 additions and 18 deletions

@@ -1,7 +1,6 @@
-use candle::{
-    safetensors::{Load, SafeTensors},
-    DType, Device, Error, Result, Shape, Tensor,
-};
+use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor};
+use safetensors::slice::IndexOp;
+use safetensors::tensor::SafeTensors;
 use std::collections::HashMap;
 use std::sync::Arc;
@@ -71,7 +70,7 @@ impl<'a> TensorData<'a> {
 #[derive(Clone)]
 pub struct VarBuilder<'a> {
     data: Arc<TensorData<'a>>,
-    path: Vec<String>,
+    pub path: Vec<String>,
 }

 impl<'a> VarBuilder<'a> {
@@ -137,6 +136,55 @@ impl<'a> VarBuilder<'a> {
 }

 impl<'a> VarBuilder<'a> {
+    pub fn get_sharded(
+        &self,
+        tensor_name: &str,
+        dim: usize,
+        rank: usize,
+        world_size: usize,
+    ) -> Result<Tensor> {
+        let data = self.data.as_ref();
+        let path = if self.path.is_empty() {
+            tensor_name.to_string()
+        } else {
+            [&self.path.join("."), tensor_name].join(".")
+        };
+        let tensor = match &self.data.tensors {
+            Tensors::SafeTensorWithRouting {
+                routing,
+                safetensors,
+            } => {
+                let index = routing.get(&path).ok_or_else(|| {
+                    Error::CannotFindTensor {
+                        path: path.to_string(),
+                    }
+                    .bt()
+                })?;
+                let view = safetensors[*index].tensor(&path)?;
+                let dtype = view.dtype();
+                let mut shape = view.shape().to_vec();
+                // Each rank takes one contiguous block along `dim`; this
+                // assumes `shape[dim]` is divisible by `world_size`.
+                let size = shape[dim];
+                let block_size = size / world_size;
+                let start = rank * block_size;
+                let stop = (rank + 1) * block_size;
+                // Slice through the safetensors view so that only this
+                // rank's shard is read out of the memory-mapped file.
+                let iterator = if dim == 0 {
+                    view.slice(start..stop).unwrap()
+                } else if dim == 1 {
+                    view.slice((.., start..stop)).unwrap()
+                } else {
+                    unimplemented!("Get sharded on dimensions != 0 or 1");
+                };
+                shape[dim] = block_size;
+                Tensor::from_safetensors_slice(iterator, dtype, &shape, &data.device)?
+            }
+            _ => unimplemented!(),
+        };
+        Ok(tensor)
+    }
     pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> Result<Tensor> {
         let data = self.data.as_ref();
         let s: Shape = s.into();
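
For context, a minimal sketch (not part of this commit) of how a caller might use the new `get_sharded` method during tensor-parallel loading; the `load_column_parallel_weight` helper and the `"weight"` tensor name are illustrative assumptions, not APIs from this diff:

use candle::{Result, Tensor};

// Sketch only: load this rank's shard of a column-parallel weight matrix.
// Assumes the tensor's dim 0 is divisible by `world_size`, as `get_sharded`
// expects, and that `VarBuilder` is the type defined in this file.
fn load_column_parallel_weight(
    vb: &VarBuilder<'_>,
    rank: usize,
    world_size: usize,
) -> Result<Tensor> {
    // Rank r receives rows [r * block, (r + 1) * block) of the full matrix,
    // so no rank ever materializes the whole tensor in memory.
    vb.get_sharded("weight", 0, rank, world_size)
}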