From 2ae368e98eee085e5ceb2371c4271b39279f0cec Mon Sep 17 00:00:00 2001
From: laurent
Date: Wed, 28 Jun 2023 21:06:56 +0100
Subject: [PATCH] Switch from a macro to a trait to make things more generic.

---
 candle-core/src/cpu_backend.rs | 181 +++++++++++++++++----------
 candle-core/src/dtype.rs       |  17 ++--
 2 files changed, 107 insertions(+), 91 deletions(-)

diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index f1547b3c..990a4b70 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -1,10 +1,8 @@
 use crate::op::{BinaryOp, UnaryOp};
-use crate::{DType, Error, Layout, Result, Shape};
+use crate::{DType, Error, Layout, Result, Shape, WithDType};
 use gemm::{gemm, Parallelism};
 use half::{bf16, f16};

-// TODO: Think about whether we would be better off with a dtype and
-// a buffer as an owned slice of bytes.
 // TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
 // intercept the oom errors to avoid panicking and provide a proper error.
 #[derive(Debug, Clone)]
@@ -16,6 +14,24 @@ pub enum CpuStorage {
     F64(Vec<f64>),
 }

+trait Map1 {
+    fn f<T: WithDType + Copy + num_traits::NumAssign>(
+        &self,
+        vs: &[T],
+        layout: &Layout,
+    ) -> Result<Vec<T>>;
+
+    fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
+        match vs {
+            CpuStorage::U32(vs) => Ok(CpuStorage::U32(self.f(vs, layout)?)),
+            CpuStorage::BF16(vs) => Ok(CpuStorage::BF16(self.f(vs, layout)?)),
+            CpuStorage::F16(vs) => Ok(CpuStorage::F16(self.f(vs, layout)?)),
+            CpuStorage::F32(vs) => Ok(CpuStorage::F32(self.f(vs, layout)?)),
+            CpuStorage::F64(vs) => Ok(CpuStorage::F64(self.f(vs, layout)?)),
+        }
+    }
+}
+
 fn wcond<T: Copy>(
     pred: &[u32],
     layout: &Layout,
@@ -46,30 +62,30 @@ fn wcond<T: Copy>(
     }
 }

-macro_rules! map1 {
-    ($v: expr, $fn: ident, $( $args:expr ),*) => {{
-        let v = match $v {
-            CpuStorage::BF16(__s) => CpuStorage::BF16($fn::<bf16>(__s, $($args),*)?),
-            CpuStorage::F16(__s) => CpuStorage::F16($fn::<f16>(__s, $($args),*)?),
-            CpuStorage::F32(__s) => CpuStorage::F32($fn::<f32>(__s, $($args),*)?),
-            CpuStorage::F64(__s) => CpuStorage::F64($fn::<f64>(__s, $($args),*)?),
-            CpuStorage::U32(__s) => CpuStorage::U32($fn::<u32>(__s, $($args),*)?),
-        };
-        Ok(v)
-    }};
+struct Sum<'a> {
+    dst_shape: &'a Shape,
+    sum_dims_and_stride: Vec<(usize, usize)>,
 }

-fn sum_impl1<T: Copy + num_traits::NumAssign>(
-    src: &[T],
-    dst_shape: &Shape,
-    src_layout: &Layout,
-    to_dst_index: impl Fn(usize) -> usize,
-) -> Result<Vec<T>> {
-    let mut dst = vec![T::zero(); dst_shape.elem_count()];
-    for (unstr_index, src_index) in src_layout.strided_index().enumerate() {
-        dst[to_dst_index(unstr_index)] += src[src_index];
+impl<'a> Map1 for Sum<'a> {
+    fn f<T: WithDType + Copy + num_traits::NumAssign>(
+        &self,
+        src: &[T],
+        src_layout: &Layout,
+    ) -> Result<Vec<T>> {
+        let mut dst = vec![T::zero(); self.dst_shape.elem_count()];
+        for (unstr_index, src_index) in src_layout.strided_index().enumerate() {
+            let mut dst_index = unstr_index;
+            // Set the sum_dims indexes to 0.
+            for &(dim, stride) in self.sum_dims_and_stride.iter() {
+                // The compiler is able to optimize the following in a single divmod op.
+                let (pre, post) = (dst_index / stride, dst_index % stride);
+                dst_index = (pre / dim) * stride + post;
+            }
+            dst[dst_index] += src[src_index];
+        }
+        Ok(dst)
     }
-    Ok(dst)
 }

 fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(vs: &[T], layout: &Layout, mut f: F) -> Vec<U> {
@@ -101,23 +117,48 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
     }
 }

-fn take_impl1<T: Copy>(vs: &[T], ids: &[u32], layout: &Layout, rhs_l: &Layout) -> Result<Vec<T>> {
-    // TODO: Optimize for the case where ids are contiguous.
-    let (vocab_size, hidden_size) = rhs_l.shape().r2()?;
-    let mut values = Vec::with_capacity(layout.shape().elem_count() * hidden_size);
-    for index in layout.strided_index() {
-        let index = ids[index].try_into()?;
-        if index >= vocab_size {
-            return Err(Error::InvalidIndex {
-                index,
-                vocab_size,
-                op: "take",
-            });
-        } else {
-            values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
-        }
+struct Affine(f64, f64);
+
+impl Map1 for Affine {
+    fn f<T: WithDType + Copy + num_traits::NumAssign>(
+        &self,
+        vs: &[T],
+        layout: &Layout,
+    ) -> Result<Vec<T>> {
+        let mul = T::from_f64(self.0);
+        let add = T::from_f64(self.1);
+        Ok(unary_map(vs, layout, |v| v * mul + add))
+    }
+}
+
+struct Embedding<'a> {
+    vocab_size: usize,
+    hidden_size: usize,
+    ids: &'a [u32],
+    ids_l: &'a Layout,
+}
+
+impl<'a> Map1 for Embedding<'a> {
+    fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
+        // TODO: We assume that vs is contiguous here.
+        let vs = &vs[layout.start_offset()..];
+        let mut values = Vec::with_capacity(self.ids_l.shape().elem_count() * self.hidden_size);
+        // TODO: Optimize for the case where ids are contiguous.
+        for index in self.ids_l.strided_index() {
+            let index = self.ids[index].try_into()?;
+            if index >= self.vocab_size {
+                return Err(Error::InvalidIndex {
+                    index,
+                    vocab_size: self.vocab_size,
+                    op: "take",
+                });
+            } else {
+                let hidden_size = self.hidden_size;
+                values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
+            }
+        }
+        Ok(values)
     }
-    Ok(values)
 }

 fn copy_strided_src_(
@@ -348,19 +389,11 @@ impl CpuStorage {
             .iter()
             .map(|&d| (src_dims[d], src_dims[d + 1..].iter().product::<usize>()))
             .collect();
-        let to_dst_index = |unstr_index: usize| {
-            // TODO: Optimize, the following does lots of slow division.
-            let mut dst_index = unstr_index;
-            // Set the sum_dims indexes to 0.
-            for &(dim, stride) in sum_dims_and_stride.iter() {
-                // The compiler is able to optimize the following in a single divmod op.
-                let (pre, post) = (dst_index / stride, dst_index % stride);
-                dst_index = (pre / dim) * stride + post;
-            }
-            dst_index
-        };
-        // TODO: Maybe provide an implementation with higher precision accumulators?
-        map1!(self, sum_impl1, &dst_shape, layout, to_dst_index)
+        Sum {
+            dst_shape: &dst_shape,
+            sum_dims_and_stride,
+        }
+        .map(self, layout)
     }

     pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
@@ -447,36 +480,7 @@ impl CpuStorage {
     }

     pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
-        match self {
-            Self::U32(storage) => {
-                let mul = mul as u32;
-                let add = add as u32;
-                let data = unary_map(storage, layout, |v| v * mul + add);
-                Ok(Self::U32(data))
-            }
-            Self::BF16(storage) => {
-                let mul = bf16::from_f64(mul);
-                let add = bf16::from_f64(add);
-                let data = unary_map(storage, layout, |v| v * mul + add);
-                Ok(Self::BF16(data))
-            }
-            Self::F16(storage) => {
-                let mul = f16::from_f64(mul);
-                let add = f16::from_f64(add);
-                let data = unary_map(storage, layout, |v| v * mul + add);
-                Ok(Self::F16(data))
-            }
-            Self::F32(storage) => {
-                let mul = mul as f32;
-                let add = add as f32;
-                let data = unary_map(storage, layout, |v| v * mul + add);
-                Ok(Self::F32(data))
-            }
-            Self::F64(storage) => {
-                let data = unary_map(storage, layout, |v| v * mul + add);
-                Ok(Self::F64(data))
-            }
-        }
+        Affine(mul, add).map(self, layout)
     }

     pub(crate) fn unary_impl<B: UnaryOp>(&self, layout: &Layout) -> Result<Self> {
@@ -605,9 +609,16 @@ impl CpuStorage {
         }
     }

-    pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
+    pub(crate) fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
         let ids = self.as_slice::<u32>()?;
-        map1!(rhs, take_impl1, ids, layout, rhs_l)
+        let (vocab_size, hidden_size) = rhs_l.shape().r2()?;
+        Embedding {
+            vocab_size,
+            hidden_size,
+            ids,
+            ids_l,
+        }
+        .map(rhs, rhs_l)
     }

     pub(crate) fn matmul(
diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs
index fdbfdbba..87e3c9b4 100644
--- a/candle-core/src/dtype.rs
+++ b/candle-core/src/dtype.rs
@@ -34,6 +34,7 @@ impl DType {
 pub trait WithDType: Sized + Copy {
     const DTYPE: DType;
+    fn from_f64(v: f64) -> Self;

     fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;

     fn to_cpu_storage(data: &[Self]) -> CpuStorage {
@@ -45,10 +46,14 @@ pub trait WithDType: Sized + Copy {
 }

 macro_rules! with_dtype {
-    ($ty:ty, $dtype:ident) => {
+    ($ty:ty, $dtype:ident, $from_f64:expr) => {
         impl WithDType for $ty {
             const DTYPE: DType = DType::$dtype;

+            fn from_f64(v: f64) -> Self {
+                $from_f64(v)
+            }
+
             fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
                 CpuStorage::$dtype(data)
             }
@@ -77,8 +82,8 @@ macro_rules! with_dtype {
         }
     };
 }
-with_dtype!(u32, U32);
-with_dtype!(half::f16, F16);
-with_dtype!(half::bf16, BF16);
-with_dtype!(f32, F32);
-with_dtype!(f64, F64);
+with_dtype!(u32, U32, |v: f64| v as u32);
+with_dtype!(half::f16, F16, half::f16::from_f64);
+with_dtype!(half::bf16, BF16, half::bf16::from_f64);
+with_dtype!(f32, F32, |v: f64| v as f32);
+with_dtype!(f64, F64, |v: f64| v);
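
The pattern the patch introduces is worth spelling out: each CPU op becomes a small struct that carries its parameters, implements one generic f<T> over element slices, and inherits a map method that performs the per-dtype dispatch over the CpuStorage variants, whereas the old map1! macro expanded that match at every call site and threaded extra arguments through positionally. The sketch below shows the same idea in isolation; it is not part of the patch, and Scale plus the trimmed-down CpuStorage, WithDType, and Map1 definitions (no Layout, no Result plumbing) are illustrative stand-ins rather than candle's real types.

use std::ops::Mul;

#[derive(Debug)]
enum CpuStorage {
    U32(Vec<u32>),
    F32(Vec<f32>),
    F64(Vec<f64>),
}

// Cut-down stand-in for candle's WithDType: just enough to build a typed scalar.
trait WithDType: Copy {
    fn from_f64(v: f64) -> Self;
}

impl WithDType for u32 {
    fn from_f64(v: f64) -> Self {
        v as u32
    }
}

impl WithDType for f32 {
    fn from_f64(v: f64) -> Self {
        v as f32
    }
}

impl WithDType for f64 {
    fn from_f64(v: f64) -> Self {
        v
    }
}

trait Map1 {
    // One generic implementation per op...
    fn f<T: WithDType + Mul<Output = T>>(&self, vs: &[T]) -> Vec<T>;

    // ...and a single dtype dispatcher, written once for every op.
    fn map(&self, vs: &CpuStorage) -> CpuStorage {
        match vs {
            CpuStorage::U32(vs) => CpuStorage::U32(self.f(vs)),
            CpuStorage::F32(vs) => CpuStorage::F32(self.f(vs)),
            CpuStorage::F64(vs) => CpuStorage::F64(self.f(vs)),
        }
    }
}

// Adding an op is now one struct (its parameters) plus one `f` impl.
struct Scale(f64);

impl Map1 for Scale {
    fn f<T: WithDType + Mul<Output = T>>(&self, vs: &[T]) -> Vec<T> {
        let m = T::from_f64(self.0);
        vs.iter().map(|&v| v * m).collect()
    }
}

fn main() {
    for storage in [
        CpuStorage::U32(vec![1, 2, 3]),
        CpuStorage::F32(vec![0.5, 1.5]),
        CpuStorage::F64(vec![3.0]),
    ] {
        // Same op, dispatched to the right element type by `map`.
        println!("{:?}", Scale(2.0).map(&storage));
    }
}

With this shape, a new element-wise op is one struct and one f implementation, and per-op state (such as Sum's dst_shape in the patch) travels in the struct instead of as additional macro arguments.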