mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Switch from a macro to a trait to make things more generic.
This commit is contained in:
@ -1,10 +1,8 @@
|
||||
use crate::op::{BinaryOp, UnaryOp};
|
||||
use crate::{DType, Error, Layout, Result, Shape};
|
||||
use crate::{DType, Error, Layout, Result, Shape, WithDType};
|
||||
use gemm::{gemm, Parallelism};
|
||||
use half::{bf16, f16};
|
||||
|
||||
// TODO: Think about whether we would be better off with a dtype and
|
||||
// a buffer as an owned slice of bytes.
|
||||
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
|
||||
// intercept the oom errors to avoid panicking and provide a proper error.
|
||||
#[derive(Debug, Clone)]
|
||||
@ -16,6 +14,24 @@ pub enum CpuStorage {
|
||||
F64(Vec<f64>),
|
||||
}
|
||||
|
||||
trait Map1 {
|
||||
fn f<T: WithDType + Copy + num_traits::NumAssign>(
|
||||
&self,
|
||||
vs: &[T],
|
||||
layout: &Layout,
|
||||
) -> Result<Vec<T>>;
|
||||
|
||||
fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
|
||||
match vs {
|
||||
CpuStorage::U32(vs) => Ok(CpuStorage::U32(self.f(vs, layout)?)),
|
||||
CpuStorage::BF16(vs) => Ok(CpuStorage::BF16(self.f(vs, layout)?)),
|
||||
CpuStorage::F16(vs) => Ok(CpuStorage::F16(self.f(vs, layout)?)),
|
||||
CpuStorage::F32(vs) => Ok(CpuStorage::F32(self.f(vs, layout)?)),
|
||||
CpuStorage::F64(vs) => Ok(CpuStorage::F64(self.f(vs, layout)?)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn wcond<T: Copy>(
|
||||
pred: &[u32],
|
||||
layout: &Layout,
|
||||
@ -46,30 +62,30 @@ fn wcond<T: Copy>(
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! map1 {
|
||||
($v: expr, $fn: ident, $( $args:expr ),*) => {{
|
||||
let v = match $v {
|
||||
CpuStorage::BF16(__s) => CpuStorage::BF16($fn::<bf16>(__s, $($args),*)?),
|
||||
CpuStorage::F16(__s) => CpuStorage::F16($fn::<f16>(__s, $($args),*)?),
|
||||
CpuStorage::F32(__s) => CpuStorage::F32($fn::<f32>(__s, $($args),*)?),
|
||||
CpuStorage::F64(__s) => CpuStorage::F64($fn::<f64>(__s, $($args),*)?),
|
||||
CpuStorage::U32(__s) => CpuStorage::U32($fn::<u32>(__s, $($args),*)?),
|
||||
};
|
||||
Ok(v)
|
||||
}};
|
||||
struct Sum<'a> {
|
||||
dst_shape: &'a Shape,
|
||||
sum_dims_and_stride: Vec<(usize, usize)>,
|
||||
}
|
||||
|
||||
fn sum_impl1<T: Copy + num_traits::NumAssign>(
|
||||
src: &[T],
|
||||
dst_shape: &Shape,
|
||||
src_layout: &Layout,
|
||||
to_dst_index: impl Fn(usize) -> usize,
|
||||
) -> Result<Vec<T>> {
|
||||
let mut dst = vec![T::zero(); dst_shape.elem_count()];
|
||||
for (unstr_index, src_index) in src_layout.strided_index().enumerate() {
|
||||
dst[to_dst_index(unstr_index)] += src[src_index];
|
||||
impl<'a> Map1 for Sum<'a> {
|
||||
fn f<T: WithDType + Copy + num_traits::NumAssign>(
|
||||
&self,
|
||||
src: &[T],
|
||||
src_layout: &Layout,
|
||||
) -> Result<Vec<T>> {
|
||||
let mut dst = vec![T::zero(); self.dst_shape.elem_count()];
|
||||
for (unstr_index, src_index) in src_layout.strided_index().enumerate() {
|
||||
let mut dst_index = unstr_index;
|
||||
// Set the sum_dims indexes to 0.
|
||||
for &(dim, stride) in self.sum_dims_and_stride.iter() {
|
||||
// The compiler is able to optimize the following in a single divmod op.
|
||||
let (pre, post) = (dst_index / stride, dst_index % stride);
|
||||
dst_index = (pre / dim) * stride + post;
|
||||
}
|
||||
dst[dst_index] += src[src_index];
|
||||
}
|
||||
Ok(dst)
|
||||
}
|
||||
Ok(dst)
|
||||
}
|
||||
|
||||
fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(vs: &[T], layout: &Layout, mut f: F) -> Vec<U> {
|
||||
@ -101,23 +117,48 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
|
||||
}
|
||||
}
|
||||
|
||||
fn take_impl1<T: Copy>(vs: &[T], ids: &[u32], layout: &Layout, rhs_l: &Layout) -> Result<Vec<T>> {
|
||||
// TODO: Optimize for the case where ids are contiguous.
|
||||
let (vocab_size, hidden_size) = rhs_l.shape().r2()?;
|
||||
let mut values = Vec::with_capacity(layout.shape().elem_count() * hidden_size);
|
||||
for index in layout.strided_index() {
|
||||
let index = ids[index].try_into()?;
|
||||
if index >= vocab_size {
|
||||
return Err(Error::InvalidIndex {
|
||||
index,
|
||||
vocab_size,
|
||||
op: "take",
|
||||
});
|
||||
} else {
|
||||
values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
|
||||
}
|
||||
struct Affine(f64, f64);
|
||||
|
||||
impl Map1 for Affine {
|
||||
fn f<T: WithDType + Copy + num_traits::NumAssign>(
|
||||
&self,
|
||||
vs: &[T],
|
||||
layout: &Layout,
|
||||
) -> Result<Vec<T>> {
|
||||
let mul = T::from_f64(self.0);
|
||||
let add = T::from_f64(self.1);
|
||||
Ok(unary_map(vs, layout, |v| v * mul + add))
|
||||
}
|
||||
}
|
||||
|
||||
struct Embedding<'a> {
|
||||
vocab_size: usize,
|
||||
hidden_size: usize,
|
||||
ids: &'a [u32],
|
||||
ids_l: &'a Layout,
|
||||
}
|
||||
|
||||
impl<'a> Map1 for Embedding<'a> {
|
||||
fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
|
||||
// TODO: We assume that vs is contiguous here.
|
||||
let vs = &vs[layout.start_offset()..];
|
||||
let mut values = Vec::with_capacity(self.ids_l.shape().elem_count() * self.hidden_size);
|
||||
// TODO: Optimize for the case where ids are contiguous.
|
||||
for index in self.ids_l.strided_index() {
|
||||
let index = self.ids[index].try_into()?;
|
||||
if index >= self.vocab_size {
|
||||
return Err(Error::InvalidIndex {
|
||||
index,
|
||||
vocab_size: self.vocab_size,
|
||||
op: "take",
|
||||
});
|
||||
} else {
|
||||
let hidden_size = self.hidden_size;
|
||||
values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
|
||||
}
|
||||
}
|
||||
Ok(values)
|
||||
}
|
||||
Ok(values)
|
||||
}
|
||||
|
||||
fn copy_strided_src_<T: Copy + std::fmt::Display>(
|
||||
@ -348,19 +389,11 @@ impl CpuStorage {
|
||||
.iter()
|
||||
.map(|&d| (src_dims[d], src_dims[d + 1..].iter().product::<usize>()))
|
||||
.collect();
|
||||
let to_dst_index = |unstr_index: usize| {
|
||||
// TODO: Optimize, the following does lots of slow division.
|
||||
let mut dst_index = unstr_index;
|
||||
// Set the sum_dims indexes to 0.
|
||||
for &(dim, stride) in sum_dims_and_stride.iter() {
|
||||
// The compiler is able to optimize the following in a single divmod op.
|
||||
let (pre, post) = (dst_index / stride, dst_index % stride);
|
||||
dst_index = (pre / dim) * stride + post;
|
||||
}
|
||||
dst_index
|
||||
};
|
||||
// TODO: Maybe provide an implementation with higher precision accumulators?
|
||||
map1!(self, sum_impl1, &dst_shape, layout, to_dst_index)
|
||||
Sum {
|
||||
dst_shape: &dst_shape,
|
||||
sum_dims_and_stride,
|
||||
}
|
||||
.map(self, layout)
|
||||
}
|
||||
|
||||
pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
|
||||
@ -447,36 +480,7 @@ impl CpuStorage {
|
||||
}
|
||||
|
||||
pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
|
||||
match self {
|
||||
Self::U32(storage) => {
|
||||
let mul = mul as u32;
|
||||
let add = add as u32;
|
||||
let data = unary_map(storage, layout, |v| v * mul + add);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
Self::BF16(storage) => {
|
||||
let mul = bf16::from_f64(mul);
|
||||
let add = bf16::from_f64(add);
|
||||
let data = unary_map(storage, layout, |v| v * mul + add);
|
||||
Ok(Self::BF16(data))
|
||||
}
|
||||
Self::F16(storage) => {
|
||||
let mul = f16::from_f64(mul);
|
||||
let add = f16::from_f64(add);
|
||||
let data = unary_map(storage, layout, |v| v * mul + add);
|
||||
Ok(Self::F16(data))
|
||||
}
|
||||
Self::F32(storage) => {
|
||||
let mul = mul as f32;
|
||||
let add = add as f32;
|
||||
let data = unary_map(storage, layout, |v| v * mul + add);
|
||||
Ok(Self::F32(data))
|
||||
}
|
||||
Self::F64(storage) => {
|
||||
let data = unary_map(storage, layout, |v| v * mul + add);
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
}
|
||||
Affine(mul, add).map(self, layout)
|
||||
}
|
||||
|
||||
pub(crate) fn unary_impl<B: UnaryOp>(&self, layout: &Layout) -> Result<Self> {
|
||||
@ -605,9 +609,16 @@ impl CpuStorage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
|
||||
pub(crate) fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
|
||||
let ids = self.as_slice::<u32>()?;
|
||||
map1!(rhs, take_impl1, ids, layout, rhs_l)
|
||||
let (vocab_size, hidden_size) = rhs_l.shape().r2()?;
|
||||
Embedding {
|
||||
vocab_size,
|
||||
hidden_size,
|
||||
ids,
|
||||
ids_l,
|
||||
}
|
||||
.map(rhs, rhs_l)
|
||||
}
|
||||
|
||||
pub(crate) fn matmul(
|
||||
|
@ -34,6 +34,7 @@ impl DType {
|
||||
pub trait WithDType: Sized + Copy {
|
||||
const DTYPE: DType;
|
||||
|
||||
fn from_f64(v: f64) -> Self;
|
||||
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
|
||||
|
||||
fn to_cpu_storage(data: &[Self]) -> CpuStorage {
|
||||
@ -45,10 +46,14 @@ pub trait WithDType: Sized + Copy {
|
||||
}
|
||||
|
||||
macro_rules! with_dtype {
|
||||
($ty:ty, $dtype:ident) => {
|
||||
($ty:ty, $dtype:ident, $from_f64:expr) => {
|
||||
impl WithDType for $ty {
|
||||
const DTYPE: DType = DType::$dtype;
|
||||
|
||||
fn from_f64(v: f64) -> Self {
|
||||
$from_f64(v)
|
||||
}
|
||||
|
||||
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
|
||||
CpuStorage::$dtype(data)
|
||||
}
|
||||
@ -77,8 +82,8 @@ macro_rules! with_dtype {
|
||||
}
|
||||
};
|
||||
}
|
||||
with_dtype!(u32, U32);
|
||||
with_dtype!(half::f16, F16);
|
||||
with_dtype!(half::bf16, BF16);
|
||||
with_dtype!(f32, F32);
|
||||
with_dtype!(f64, F64);
|
||||
with_dtype!(u32, U32, |v: f64| v as u32);
|
||||
with_dtype!(half::f16, F16, half::f16::from_f64);
|
||||
with_dtype!(half::bf16, BF16, half::bf16::from_f64);
|
||||
with_dtype!(f32, F32, |v: f64| v as f32);
|
||||
with_dtype!(f64, F64, |v: f64| v);
|
||||
|
Reference in New Issue
Block a user