TP sharding v2

This commit is contained in:
Nicolas Patry
2023-07-21 15:10:51 +00:00
parent 209f06d7c3
commit 1735e4831e
9 changed files with 833 additions and 18 deletions

View File

@ -1,6 +1,7 @@
use crate::{DType, Device, Error, Result, Tensor, WithDType};
use safetensors::slice::SliceIterator;
use safetensors::tensor as st;
pub use safetensors::tensor::SafeTensors;
use safetensors::tensor::{Dtype, SafeTensors};
use std::borrow::Cow;
impl From<DType> for st::Dtype {
@ -63,15 +64,15 @@ impl Tensor {
}
}
fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
let v = view.data();
fn convert_slice<T: WithDType>(data: &[u8], shape: &[usize], device: &Device) -> Result<Tensor> {
let size_in_bytes = T::DTYPE.size_in_bytes();
let elem_count = v.len() / size_in_bytes;
if (v.as_ptr() as usize) % size_in_bytes == 0 {
let elem_count = data.len() / size_in_bytes;
if (data.as_ptr() as usize) % size_in_bytes == 0 {
// SAFETY This is safe because we just checked that this
// was correctly aligned.
let data: &[T] = unsafe { std::slice::from_raw_parts(v.as_ptr() as *const T, elem_count) };
Tensor::from_slice(data, view.shape(), device)
let data: &[T] =
unsafe { std::slice::from_raw_parts(data.as_ptr() as *const T, elem_count) };
Tensor::from_slice(data, shape, device)
} else {
// XXX: We need to specify `T` here, otherwise the compiler will infer u8 because of the following cast
// Making this vector too small to fit a full f16/f32/f64 weights, resulting in out-of-bounds access
@ -81,13 +82,17 @@ fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<
// We're downgrading the `c` pointer from T to u8, which removes alignment
// constraints.
unsafe {
std::ptr::copy_nonoverlapping(v.as_ptr(), c.as_mut_ptr() as *mut u8, v.len());
std::ptr::copy_nonoverlapping(data.as_ptr(), c.as_mut_ptr() as *mut u8, data.len());
c.set_len(elem_count)
}
Tensor::from_slice(&c, view.shape(), device)
Tensor::from_slice(&c, shape, device)
}
}
fn convert_<T: WithDType>(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
convert_slice::<T>(view.data(), view.shape(), device)
}
fn convert_back_<T: WithDType>(mut vs: Vec<T>) -> Vec<u8> {
let size_in_bytes = T::DTYPE.size_in_bytes();
let length = vs.len() * size_in_bytes;
@ -112,6 +117,26 @@ impl<'a> Load for st::TensorView<'a> {
}
}
impl Tensor {
pub fn from_safetensors_slice(
iterator: SliceIterator,
dtype: Dtype,
shape: &[usize],
device: &Device,
) -> Result<Self> {
let data: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
match dtype {
st::Dtype::U8 => convert_slice::<u8>(&data, shape, device),
st::Dtype::U32 => convert_slice::<u8>(&data, shape, device),
st::Dtype::BF16 => convert_slice::<half::bf16>(&data, shape, device),
st::Dtype::F16 => convert_slice::<half::f16>(&data, shape, device),
st::Dtype::F32 => convert_slice::<f32>(&data, shape, device),
st::Dtype::F64 => convert_slice::<f64>(&data, shape, device),
dtype => Err(Error::UnsupportedSafeTensorDtype(dtype)),
}
}
}
pub fn convert(view: &st::TensorView<'_>, device: &Device) -> Result<Tensor> {
match view.dtype() {
st::Dtype::U8 => convert_::<u8>(view, device),