Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 19:18:50 +00:00)
TP sharding v2
@@ -1,7 +1,6 @@
-use candle::{
-    safetensors::{Load, SafeTensors},
-    DType, Device, Error, Result, Shape, Tensor,
-};
+use candle::{safetensors::Load, DType, Device, Error, Result, Shape, Tensor};
+use safetensors::slice::IndexOp;
+use safetensors::tensor::SafeTensors;
 use std::collections::HashMap;
 use std::sync::Arc;
 
@@ -71,7 +70,7 @@ impl<'a> TensorData<'a> {
 #[derive(Clone)]
 pub struct VarBuilder<'a> {
     data: Arc<TensorData<'a>>,
-    path: Vec<String>,
+    pub path: Vec<String>,
 }
 
 impl<'a> VarBuilder<'a> {
@@ -137,6 +136,55 @@ impl<'a> VarBuilder<'a> {
 }
 
 impl<'a> VarBuilder<'a> {
+    pub fn get_sharded(
+        &self,
+        tensor_name: &str,
+        dim: usize,
+        rank: usize,
+        world_size: usize,
+    ) -> Result<Tensor> {
+        let data = self.data.as_ref();
+        let path = if self.path.is_empty() {
+            tensor_name.to_string()
+        } else {
+            [&self.path.join("."), tensor_name].join(".")
+        };
+        let tensor = match &self.data.tensors {
+            Tensors::SafeTensorWithRouting {
+                routing,
+                safetensors,
+            } => {
+                let index = routing.get(&path).ok_or_else(|| {
+                    Error::CannotFindTensor {
+                        path: path.to_string(),
+                    }
+                    .bt()
+                })?;
+
+                let view = safetensors[*index].tensor(&path)?;
+                let dtype = view.dtype();
+                let mut shape = view.shape().to_vec();
+                let size = shape[dim];
+                let block_size = size / world_size;
+                let start = rank * block_size;
+                let stop = (rank + 1) * block_size;
+
+                let iterator = if dim == 0 {
+                    view.slice(start..stop).unwrap()
+                } else if dim == 1 {
+                    view.slice((.., start..stop)).unwrap()
+                } else {
+                    unimplemented!("Get sharded on dimensions != 0 or 1");
+                };
+
+                shape[dim] = block_size;
+
+                Tensor::from_safetensors_slice(iterator, dtype, &shape, &data.device)?
+            }
+            _ => unimplemented!(),
+        };
+        Ok(tensor)
+    }
     pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> Result<Tensor> {
         let data = self.data.as_ref();
         let s: Shape = s.into();
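
For context, a minimal usage sketch (not part of this commit): in a tensor-parallel run, each process would call get_sharded with its own rank and world_size, loading column-parallel weights split along dim 0 and row-parallel weights split along dim 1. The helper name and the tensor names below are illustrative assumptions, not code from this repository.

// Hypothetical helper: assumes a VarBuilder `vb` built from safetensors files and a
// rank/world_size supplied by the launcher. Tensor names are illustrative only.
fn load_tp_attention_weights(
    vb: &VarBuilder<'_>,
    rank: usize,
    world_size: usize,
) -> Result<(Tensor, Tensor)> {
    // Column-parallel projection: shard the output dimension (dim 0), so each
    // rank holds out_features / world_size rows of the weight.
    let wq = vb.get_sharded("attention.wq.weight", 0, rank, world_size)?;
    // Row-parallel projection: shard the input dimension (dim 1); the partial
    // matmul results are then summed (all-reduced) across ranks.
    let wo = vb.get_sharded("attention.wo.weight", 1, rank, world_size)?;
    Ok((wq, wo))
}

For example, with world_size = 2 and a 4096 x 4096 weight sharded on dim 0, block_size = 2048, so rank 0 reads rows 0..2048 and rank 1 reads rows 2048..4096. Note that size / world_size uses integer division, so a dimension that is not divisible by the world size silently drops the remainder.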