mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Merge pull request #216 from LaurentMazare/llama_multiprocess2
TP sharding v2
This commit is contained in:
@ -79,6 +79,13 @@ pub enum Error {
|
||||
nth_shape: Shape,
|
||||
},
|
||||
|
||||
#[error("Cannot divide tensor of shape {shape:?} equally along dim {dim} into {n_parts}")]
|
||||
ShapeMismatchSplit {
|
||||
shape: Shape,
|
||||
dim: usize,
|
||||
n_parts: usize,
|
||||
},
|
||||
|
||||
#[error("{op} can only be performed on a single dimension")]
|
||||
OnlySingleDimension { op: &'static str, dims: Vec<usize> },
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
use crate::{DType, Device, Error, Result, Tensor, WithDType};
|
||||
use safetensors::tensor as st;
|
||||
pub use safetensors::tensor::SafeTensors;
|
||||
use safetensors::tensor::SafeTensors;
|
||||
use std::borrow::Cow;
|
||||
|
||||
impl From<DType> for st::Dtype {
|
||||
@ -63,15 +63,15 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
|
||||
let v = view.data();
|
||||
fn convert_slice<T: WithDType>(data: &[u8], shape: &[usize], device: &Device) -> Result<Tensor> {
|
||||
let size_in_bytes = T::DTYPE.size_in_bytes();
|
||||
let elem_count = v.len() / size_in_bytes;
|
||||
if (v.as_ptr() as usize) % size_in_bytes == 0 {
|
||||
let elem_count = data.len() / size_in_bytes;
|
||||
if (data.as_ptr() as usize) % size_in_bytes == 0 {
|
||||
// SAFETY This is safe because we just checked that this
|
||||
// was correctly aligned.
|
||||
let data: &[T] = unsafe { std::slice::from_raw_parts(v.as_ptr() as *const T, elem_count) };
|
||||
Tensor::from_slice(data, view.shape(), device)
|
||||
let data: &[T] =
|
||||
unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
|
||||
Tensor::from_slice(data, shape, device)
|
||||
} else {
|
||||
// XXX: We need to specify `T` here, otherwise the compiler will infer u8 because of the following cast
|
||||
// Making this vector too small to fit a full f16/f32/f64 weights, resulting in out-of-bounds access
|
||||
@ -81,13 +81,17 @@ fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<
|
||||
// We're downgrading the `c` pointer from T to u8, which removes alignment
|
||||
// constraints.
|
||||
unsafe {
|
||||
std::ptr::copy_nonoverlapping(v.as_ptr(), c.as_mut_ptr() as *mut u8, v.len());
|
||||
std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
|
||||
c.set_len(elem_count)
|
||||
}
|
||||
Tensor::from_slice(&c, view.shape(), device)
|
||||
Tensor::from_slice(&c, shape, device)
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
|
||||
convert_slice::<T>(view.data(), view.shape(), device)
|
||||
}
|
||||
|
||||
fn convert_back_<T: WithDType>(mut vs: Vec<T>) -> Vec<u8> {
|
||||
let size_in_bytes = T::DTYPE.size_in_bytes();
|
||||
let length = vs.len() * size_in_bytes;
|
||||
@ -112,7 +116,25 @@ impl<'a> Load for st::TensorView<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
|
||||
impl Tensor {
|
||||
pub fn from_raw_buffer(
|
||||
data: &[u8],
|
||||
dtype: DType,
|
||||
shape: &[usize],
|
||||
device: &Device,
|
||||
) -> Result<Self> {
|
||||
match dtype {
|
||||
DType::U8 => convert_slice::<u8>(data, shape, device),
|
||||
DType::U32 => convert_slice::<u32>(data, shape, device),
|
||||
DType::BF16 => convert_slice::<half::bf16>(data, shape, device),
|
||||
DType::F16 => convert_slice::<half::f16>(data, shape, device),
|
||||
DType::F32 => convert_slice::<f32>(data, shape, device),
|
||||
DType::F64 => convert_slice::<f64>(data, shape, device),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
|
||||
match view.dtype() {
|
||||
st::Dtype::U8 => convert_::<u8>(view, device),
|
||||
st::Dtype::U32 => convert_::<u8>(view, device),
|
||||
@ -124,7 +146,7 @@ pub fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
|
||||
fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {
|
||||
// TODO: This makes an unnecessary copy when the tensor is on the cpu.
|
||||
let tensor = tensor.flatten_all()?;
|
||||
match tensor.dtype() {
|
||||
|
Reference in New Issue
Block a user