From 2ae368e98eee085e5ceb2371c4271b39279f0cec Mon Sep 17 00:00:00 2001
From: laurent
Date: Wed, 28 Jun 2023 21:06:56 +0100
Subject: [PATCH 1/5] Switch from a macro to a trait to make things more generic.

---
 candle-core/src/cpu_backend.rs | 181 +++++++++++++++++----------------
 candle-core/src/dtype.rs       |  17 ++--
 2 files changed, 107 insertions(+), 91 deletions(-)

diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index f1547b3c..990a4b70 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -1,10 +1,8 @@
 use crate::op::{BinaryOp, UnaryOp};
-use crate::{DType, Error, Layout, Result, Shape};
+use crate::{DType, Error, Layout, Result, Shape, WithDType};
 use gemm::{gemm, Parallelism};
 use half::{bf16, f16};
 
-// TODO: Think about whether we would be better off with a dtype and
-// a buffer as an owned slice of bytes.
 // TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
 // intercept the oom errors to avoid panicking and provide a proper error.
 #[derive(Debug, Clone)]
@@ -16,6 +14,24 @@ pub enum CpuStorage {
     F64(Vec<f64>),
 }
 
+trait Map1 {
+    fn f<T: WithDType + Copy + num_traits::NumAssign>(
+        &self,
+        vs: &[T],
+        layout: &Layout,
+    ) -> Result<Vec<T>>;
+
+    fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
+        match vs {
+            CpuStorage::U32(vs) => Ok(CpuStorage::U32(self.f(vs, layout)?)),
+            CpuStorage::BF16(vs) => Ok(CpuStorage::BF16(self.f(vs, layout)?)),
+            CpuStorage::F16(vs) => Ok(CpuStorage::F16(self.f(vs, layout)?)),
+            CpuStorage::F32(vs) => Ok(CpuStorage::F32(self.f(vs, layout)?)),
+            CpuStorage::F64(vs) => Ok(CpuStorage::F64(self.f(vs, layout)?)),
+        }
+    }
+}
+
 fn wcond<T: Copy>(
     pred: &[u32],
     layout: &Layout,
@@ -46,30 +62,30 @@ fn wcond<T: Copy>(
     }
 }
 
-macro_rules! map1 {
-    ($v: expr, $fn: ident, $( $args:expr ),*) => {{
-        let v = match $v {
-            CpuStorage::BF16(__s) => CpuStorage::BF16($fn::<bf16>(__s, $($args),*)?),
-            CpuStorage::F16(__s) => CpuStorage::F16($fn::<f16>(__s, $($args),*)?),
-            CpuStorage::F32(__s) => CpuStorage::F32($fn::<f32>(__s, $($args),*)?),
-            CpuStorage::F64(__s) => CpuStorage::F64($fn::<f64>(__s, $($args),*)?),
-            CpuStorage::U32(__s) => CpuStorage::U32($fn::<u32>(__s, $($args),*)?),
-        };
-        Ok(v)
-    }};
-}
+struct Sum<'a> {
+    dst_shape: &'a Shape,
+    sum_dims_and_stride: Vec<(usize, usize)>,
+}
 
-fn sum_impl1<T: Copy + num_traits::NumAssign>(
-    src: &[T],
-    dst_shape: &Shape,
-    src_layout: &Layout,
-    to_dst_index: impl Fn(usize) -> usize,
-) -> Result<Vec<T>> {
-    let mut dst = vec![T::zero(); dst_shape.elem_count()];
-    for (unstr_index, src_index) in src_layout.strided_index().enumerate() {
-        dst[to_dst_index(unstr_index)] += src[src_index];
-    }
-    Ok(dst)
-}
+impl<'a> Map1 for Sum<'a> {
+    fn f<T: WithDType + Copy + num_traits::NumAssign>(
+        &self,
+        src: &[T],
+        src_layout: &Layout,
+    ) -> Result<Vec<T>> {
+        let mut dst = vec![T::zero(); self.dst_shape.elem_count()];
+        for (unstr_index, src_index) in src_layout.strided_index().enumerate() {
+            let mut dst_index = unstr_index;
+            // Set the sum_dims indexes to 0.
+            for &(dim, stride) in self.sum_dims_and_stride.iter() {
+                // The compiler is able to optimize the following in a single divmod op.
+                let (pre, post) = (dst_index / stride, dst_index % stride);
+                dst_index = (pre / dim) * stride + post;
+            }
+            dst[dst_index] += src[src_index];
+        }
+        Ok(dst)
+    }
+}
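
The dst_index remapping in Sum::f above is the one subtle step in this refactor, so a worked example may help; the shape below is hypothetical and only illustrates the divmod trick, assuming row-major (C-contiguous) ordering:

    // Summing a (2, 3, 4) tensor over dim 1: sum_dims_and_stride = [(3, 4)]
    // (dim size 3, product of the trailing dims 4), dst shape (2, 1, 4).
    let (dim, stride) = (3, 4);
    let unstr_index = 1 * 12 + 2 * 4 + 3; // element (1, 2, 3), flat index 23
    let (pre, post) = (unstr_index / stride, unstr_index % stride); // (5, 3)
    // pre = i0 * dim + i1 = 5; the integer division by dim drops i1.
    let dst_index = (pre / dim) * stride + post; // (5 / 3) * 4 + 3 = 7
    assert_eq!(dst_index, 7); // element (1, 0, 3) of the destination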
 
 fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(vs: &[T], layout: &Layout, mut f: F) -> Vec<U> {
@@ -101,23 +117,48 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
     }
 }
 
-fn take_impl1<T: Copy>(vs: &[T], ids: &[u32], layout: &Layout, rhs_l: &Layout) -> Result<Vec<T>> {
-    // TODO: Optimize for the case where ids are contiguous.
-    let (vocab_size, hidden_size) = rhs_l.shape().r2()?;
-    let mut values = Vec::with_capacity(layout.shape().elem_count() * hidden_size);
-    for index in layout.strided_index() {
-        let index = ids[index].try_into()?;
-        if index >= vocab_size {
-            return Err(Error::InvalidIndex {
-                index,
-                vocab_size,
-                op: "take",
-            });
-        } else {
-            values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
-        }
-    }
-    Ok(values)
-}
+struct Affine(f64, f64);
+
+impl Map1 for Affine {
+    fn f<T: WithDType + Copy + num_traits::NumAssign>(
+        &self,
+        vs: &[T],
+        layout: &Layout,
+    ) -> Result<Vec<T>> {
+        let mul = T::from_f64(self.0);
+        let add = T::from_f64(self.1);
+        Ok(unary_map(vs, layout, |v| v * mul + add))
+    }
+}
+
+struct Embedding<'a> {
+    vocab_size: usize,
+    hidden_size: usize,
+    ids: &'a [u32],
+    ids_l: &'a Layout,
+}
+
+impl<'a> Map1 for Embedding<'a> {
+    fn f<T: WithDType + Copy + num_traits::NumAssign>(
+        &self,
+        vs: &[T],
+        layout: &Layout,
+    ) -> Result<Vec<T>> {
+        // TODO: We assume that vs is contiguous here.
+        let vs = &vs[layout.start_offset()..];
+        let mut values = Vec::with_capacity(self.ids_l.shape().elem_count() * self.hidden_size);
+        // TODO: Optimize for the case where ids are contiguous.
+        for index in self.ids_l.strided_index() {
+            let index = self.ids[index].try_into()?;
+            if index >= self.vocab_size {
+                return Err(Error::InvalidIndex {
+                    index,
+                    vocab_size: self.vocab_size,
+                    op: "take",
+                });
+            } else {
+                let hidden_size = self.hidden_size;
+                values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
+            }
+        }
+        Ok(values)
+    }
+}
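
With Affine and Embedding written as Map1 implementors, a call site names the operation once and lets map do the per-dtype dispatch. A minimal sketch of the new calling convention (crate-internal only, and assuming a Layout::contiguous-style constructor; the values are made up):

    // Hypothetical: y = 2x + 1 over a small contiguous f32 storage.
    let storage = CpuStorage::F32(vec![1.0, 2.0, 3.0]);
    let layout = Layout::contiguous(&Shape::from((3,)));
    let scaled = Affine(2.0, 1.0).map(&storage, &layout)?;
    // scaled is CpuStorage::F32(vec![3.0, 5.0, 7.0]); a U32 storage takes the
    // same path, with WithDType::from_f64 converting the two constants.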
 
 fn copy_strided_src_<T: Copy>(
@@ -348,19 +389,11 @@ impl CpuStorage {
             .iter()
             .map(|&d| (src_dims[d], src_dims[d + 1..].iter().product::<usize>()))
             .collect();
-        let to_dst_index = |unstr_index: usize| {
-            // TODO: Optimize, the following does lots of slow division.
-            let mut dst_index = unstr_index;
-            // Set the sum_dims indexes to 0.
-            for &(dim, stride) in sum_dims_and_stride.iter() {
-                // The compiler is able to optimize the following in a single divmod op.
-                let (pre, post) = (dst_index / stride, dst_index % stride);
-                dst_index = (pre / dim) * stride + post;
-            }
-            dst_index
-        };
-        // TODO: Maybe provide an implementation with higher precision accumulators?
-        map1!(self, sum_impl1, &dst_shape, layout, to_dst_index)
+        Sum {
+            dst_shape: &dst_shape,
+            sum_dims_and_stride,
+        }
+        .map(self, layout)
     }
 
     pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
@@ -447,36 +480,7 @@ impl CpuStorage {
     }
 
     pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
-        match self {
-            Self::U32(storage) => {
-                let mul = mul as u32;
-                let add = add as u32;
-                let data = unary_map(storage, layout, |v| v * mul + add);
-                Ok(Self::U32(data))
-            }
-            Self::BF16(storage) => {
-                let mul = bf16::from_f64(mul);
-                let add = bf16::from_f64(add);
-                let data = unary_map(storage, layout, |v| v * mul + add);
-                Ok(Self::BF16(data))
-            }
-            Self::F16(storage) => {
-                let mul = f16::from_f64(mul);
-                let add = f16::from_f64(add);
-                let data = unary_map(storage, layout, |v| v * mul + add);
-                Ok(Self::F16(data))
-            }
-            Self::F32(storage) => {
-                let mul = mul as f32;
-                let add = add as f32;
-                let data = unary_map(storage, layout, |v| v * mul + add);
-                Ok(Self::F32(data))
-            }
-            Self::F64(storage) => {
-                let data = unary_map(storage, layout, |v| v * mul + add);
-                Ok(Self::F64(data))
-            }
-        }
+        Affine(mul, add).map(self, layout)
     }
 
     pub(crate) fn unary_impl<B: UnaryOp>(&self, layout: &Layout) -> Result<Self> {
@@ -605,9 +609,16 @@ impl CpuStorage {
         }
     }
 
-    pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
+    pub(crate) fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
         let ids = self.as_slice::<u32>()?;
-        map1!(rhs, take_impl1, ids, layout, rhs_l)
+        let (vocab_size, hidden_size) = rhs_l.shape().r2()?;
+        Embedding {
+            vocab_size,
+            hidden_size,
+            ids,
+            ids_l,
+        }
+        .map(rhs, rhs_l)
     }
 
     pub(crate) fn matmul(
diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs
index fdbfdbba..87e3c9b4 100644
--- a/candle-core/src/dtype.rs
+++ b/candle-core/src/dtype.rs
@@ -34,6 +34,7 @@ impl DType {
 pub trait WithDType: Sized + Copy {
     const DTYPE: DType;
 
+    fn from_f64(v: f64) -> Self;
     fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
 
     fn to_cpu_storage(data: &[Self]) -> CpuStorage {
@@ -45,10 +46,14 @@ macro_rules! with_dtype {
-    ($ty:ty, $dtype:ident) => {
+    ($ty:ty, $dtype:ident, $from_f64:expr) => {
         impl WithDType for $ty {
             const DTYPE: DType = DType::$dtype;
 
+            fn from_f64(v: f64) -> Self {
+                $from_f64(v)
+            }
+
             fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
                 CpuStorage::$dtype(data)
             }
@@ -77,8 +82,8 @@ macro_rules! with_dtype {
         }
     };
 }
-with_dtype!(u32, U32);
-with_dtype!(half::f16, F16);
-with_dtype!(half::bf16, BF16);
-with_dtype!(f32, F32);
-with_dtype!(f64, F64);
+with_dtype!(u32, U32, |v: f64| v as u32);
+with_dtype!(half::f16, F16, half::f16::from_f64);
+with_dtype!(half::bf16, BF16, half::bf16::from_f64);
+with_dtype!(f32, F32, |v: f64| v as f32);
+with_dtype!(f64, F64, |v: f64| v);

From 46c07b924c90dc1fb6ff2d432e6fe16c3da09d72 Mon Sep 17 00:00:00 2001
From: laurent
Date: Wed, 28 Jun 2023 21:10:54 +0100
Subject: [PATCH 2/5] Tweak some comment.

---
 candle-core/src/cpu_backend.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 990a4b70..7409a90a 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -397,7 +397,7 @@ impl CpuStorage {
     }
 
     pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
-        // [self] stores data in a contiguous way.
+        // [self] stores data in a contiguous way starting at offset 0.
         let dims = shape.dims();
         let elem_per_slice = dims[dim];
         let prod_pre_dim = dims[..dim].iter().product();
From 1328b5cb206dce9135c3d73dd6a5bd08f5e16d6a Mon Sep 17 00:00:00 2001
From: laurent
Date: Wed, 28 Jun 2023 21:38:01 +0100
Subject: [PATCH 3/5] Add map2.

---
 candle-core/src/cpu_backend.rs | 259 ++++++++++++++++-----------------
 1 file changed, 127 insertions(+), 132 deletions(-)

diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 7409a90a..7170e470 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -32,33 +32,66 @@ trait Map1 {
     }
 }
 
-fn wcond<T: Copy>(
-    pred: &[u32],
-    layout: &Layout,
-    t: &[T],
-    layout_t: &Layout,
-    f: &[T],
-    layout_f: &Layout,
-) -> Vec<T> {
-    match (
-        layout.contiguous_offsets(),
-        layout_t.contiguous_offsets(),
-        layout_f.contiguous_offsets(),
-    ) {
-        (Some((o1, o2)), Some((o_t1, o_t2)), Some((o_f1, o_f2))) => {
-            let pred = &pred[o1..o2];
-            let t = &t[o_t1..o_t2];
-            let f = &f[o_f1..o_f2];
-            pred.iter()
-                .zip(t.iter().zip(f.iter()))
-                .map(|(&p, (&t, &f))| if p > 0 { t } else { f })
-                .collect::<Vec<_>>()
-        }
-        _ => layout
-            .strided_index()
-            .zip(layout_t.strided_index().zip(layout_f.strided_index()))
-            .map(|(i_p, (i_t, i_f))| if pred[i_p] > 0 { t[i_t] } else { f[i_f] })
-            .collect::<Vec<_>>(),
-    }
-}
+type C = CpuStorage;
+
+trait Map2 {
+    const OP: &'static str;
+    fn f<T: 'static + WithDType + num_traits::Num + Copy>(
+        &self,
+        v1: &[T],
+        l1: &Layout,
+        v2: &[T],
+        l2: &Layout,
+    ) -> Result<Vec<T>>;
+
+    fn map(
+        &self,
+        v1: &CpuStorage,
+        l1: &Layout,
+        v2: &CpuStorage,
+        l2: &Layout,
+    ) -> Result<CpuStorage> {
+        match (v1, v2) {
+            (C::U32(v1), C::U32(v2)) => Ok(C::U32(self.f(v1, l1, v2, l2)?)),
+            (C::BF16(v1), C::BF16(v2)) => Ok(C::BF16(self.f(v1, l1, v2, l2)?)),
+            (C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
+            (C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
+            (C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
+            _ => Err(Error::DTypeMismatchBinaryOp {
+                lhs: v1.dtype(),
+                rhs: v2.dtype(),
+                op: Self::OP,
+            }),
+        }
+    }
+}
+
+struct WCond<'a>(&'a [u32], &'a Layout);
+
+impl<'a> Map2 for WCond<'a> {
+    const OP: &'static str = "where";
+    fn f<T: 'static + WithDType + num_traits::Num + Copy>(
+        &self,
+        t: &[T],
+        t_l: &Layout,
+        f: &[T],
+        f_l: &Layout,
+    ) -> Result<Vec<T>> {
+        let vs = match (
+            self.1.contiguous_offsets(),
+            t_l.contiguous_offsets(),
+            f_l.contiguous_offsets(),
+        ) {
+            (Some((o1, o2)), Some((o_t1, o_t2)), Some((o_f1, o_f2))) => {
+                let pred = &self.0[o1..o2];
+                let t = &t[o_t1..o_t2];
+                let f = &f[o_f1..o_f2];
+                pred.iter()
+                    .zip(t.iter().zip(f.iter()))
+                    .map(|(&p, (&t, &f))| if p > 0 { t } else { f })
+                    .collect::<Vec<_>>()
+            }
+            _ => self
+                .1
+                .strided_index()
+                .zip(t_l.strided_index().zip(f_l.strided_index()))
+                .map(|(i_p, (i_t, i_f))| if self.0[i_p] > 0 { t[i_t] } else { f[i_f] })
+                .collect::<Vec<_>>(),
+        };
+        Ok(vs)
+    }
+}
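
Map2 is the binary counterpart: an implementor defines f over a single element type, and map supplies both the dtype dispatch and the mismatch error, with OP naming the operation in that error. As an illustration only (not part of the patch), an elementwise add over contiguous inputs could be written as:

    // Hypothetical Map2 implementor; real elementwise ops keep going through
    // binary_map so that strided layouts are handled.
    struct Add;

    impl Map2 for Add {
        const OP: &'static str = "add";
        fn f<T: 'static + WithDType + num_traits::Num + Copy>(
            &self,
            v1: &[T],
            l1: &Layout,
            v2: &[T],
            l2: &Layout,
        ) -> Result<Vec<T>> {
            // Assumes both inputs are contiguous from their start offsets.
            let v1 = &v1[l1.start_offset()..];
            let v2 = &v2[l2.start_offset()..];
            Ok(v1.iter().zip(v2.iter()).map(|(&a, &b)| a + b).collect())
        }
    }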
@@ -184,73 +217,79 @@ fn copy_strided_src_<T: Copy>(
     }
 }
 
-fn matmul<T: 'static + num_traits::Num + Copy>(
-    lhs: &[T],
-    rhs: &[T],
-    (b, m, n, k): (usize, usize, usize, usize),
-    lhs_l: &Layout,
-    rhs_l: &Layout,
-) -> Result<Vec<T>> {
-    let lhs = &lhs[lhs_l.start_offset()..];
-    let rhs = &rhs[rhs_l.start_offset()..];
-    let a_skip: usize = m * k;
-    let b_skip: usize = n * k;
-    let c_skip: usize = m * n;
-
-    let lhs_stride = lhs_l.stride();
-    let rhs_stride = rhs_l.stride();
-    let rank = lhs_stride.len();
-    let lhs_cs = lhs_stride[rank - 1];
-    let lhs_rs = lhs_stride[rank - 2];
-
-    let rhs_cs = rhs_stride[rank - 1];
-    let rhs_rs = rhs_stride[rank - 2];
-
-    if lhs_stride.len() > 2 {
-        let lhs_batch_stride = &lhs_stride[..rank - 2];
-        let rhs_batch_stride = &rhs_stride[..rank - 2];
-
-        if lhs_batch_stride != [a_skip] || rhs_batch_stride != [b_skip] {
-            // Temporary error before we support arbitrary striding.
-            return Err(Error::UnexpectedStriding);
-        }
-    }
-
-    let dst_shape: Shape = (m, n).into();
-    let dst_strides = dst_shape.stride_contiguous();
-    let dst_rs = dst_strides[0];
-    let dst_cs = dst_strides[1];
-
-    let mut dst = vec![T::zero(); b * m * n];
-    for step in 0..b {
-        let lhs_p = &lhs[step * a_skip..];
-        let rhs_p = &rhs[step * b_skip..];
-        let dst_p = &mut dst[step * c_skip..];
-        unsafe {
-            gemm(
-                /* m: usize = */ m,
-                /* n: usize = */ n,
-                /* k: usize = */ k,
-                /* dst: *mut T = */ dst_p.as_mut_ptr(),
-                /* dst_cs: isize = */ dst_cs as isize,
-                /* dst_rs: isize = */ dst_rs as isize,
-                /* read_dst: bool = */ false,
-                /* lhs: *const T = */ lhs_p.as_ptr(),
-                /* lhs_cs: isize = */ lhs_cs as isize,
-                /* lhs_rs: isize = */ lhs_rs as isize,
-                /* rhs: *const T = */ rhs_p.as_ptr(),
-                /* rhs_cs: isize = */ rhs_cs as isize,
-                /* rhs_rs: isize = */ rhs_rs as isize,
-                /* alpha: T = */ T::zero(),
-                /* beta: T = */ T::one(),
-                /* conj_dst: bool = */ false,
-                /* conj_lhs: bool = */ false,
-                /* conj_rhs: bool = */ false,
-                Parallelism::Rayon(crate::utils::get_num_threads()),
-            )
-        }
-    }
-    Ok(dst)
-}
+struct MatMul((usize, usize, usize, usize));
+
+impl Map2 for MatMul {
+    const OP: &'static str = "mat_mul";
+    fn f<T: 'static + WithDType + num_traits::Num + Copy>(
+        &self,
+        lhs: &[T],
+        lhs_l: &Layout,
+        rhs: &[T],
+        rhs_l: &Layout,
+    ) -> Result<Vec<T>> {
+        let (b, m, n, k) = self.0;
+        let lhs = &lhs[lhs_l.start_offset()..];
+        let rhs = &rhs[rhs_l.start_offset()..];
+        let a_skip: usize = m * k;
+        let b_skip: usize = n * k;
+        let c_skip: usize = m * n;
+
+        let lhs_stride = lhs_l.stride();
+        let rhs_stride = rhs_l.stride();
+        let rank = lhs_stride.len();
+        let lhs_cs = lhs_stride[rank - 1];
+        let lhs_rs = lhs_stride[rank - 2];
+
+        let rhs_cs = rhs_stride[rank - 1];
+        let rhs_rs = rhs_stride[rank - 2];
+
+        if lhs_stride.len() > 2 {
+            let lhs_batch_stride = &lhs_stride[..rank - 2];
+            let rhs_batch_stride = &rhs_stride[..rank - 2];
+
+            if lhs_batch_stride != [a_skip] || rhs_batch_stride != [b_skip] {
+                // Temporary error before we support arbitrary striding.
+                return Err(Error::UnexpectedStriding);
+            }
+        }
+
+        let dst_shape: Shape = (m, n).into();
+        let dst_strides = dst_shape.stride_contiguous();
+        let dst_rs = dst_strides[0];
+        let dst_cs = dst_strides[1];
+
+        let mut dst = vec![T::zero(); b * m * n];
+        for step in 0..b {
+            let lhs_p = &lhs[step * a_skip..];
+            let rhs_p = &rhs[step * b_skip..];
+            let dst_p = &mut dst[step * c_skip..];
+            unsafe {
+                gemm(
+                    /* m: usize = */ m,
+                    /* n: usize = */ n,
+                    /* k: usize = */ k,
+                    /* dst: *mut T = */ dst_p.as_mut_ptr(),
+                    /* dst_cs: isize = */ dst_cs as isize,
+                    /* dst_rs: isize = */ dst_rs as isize,
+                    /* read_dst: bool = */ false,
+                    /* lhs: *const T = */ lhs_p.as_ptr(),
+                    /* lhs_cs: isize = */ lhs_cs as isize,
+                    /* lhs_rs: isize = */ lhs_rs as isize,
+                    /* rhs: *const T = */ rhs_p.as_ptr(),
+                    /* rhs_cs: isize = */ rhs_cs as isize,
+                    /* rhs_rs: isize = */ rhs_rs as isize,
+                    /* alpha: T = */ T::zero(),
+                    /* beta: T = */ T::one(),
+                    /* conj_dst: bool = */ false,
+                    /* conj_lhs: bool = */ false,
+                    /* conj_rhs: bool = */ false,
+                    Parallelism::Rayon(crate::utils::get_num_threads()),
+                )
+            }
+        }
+        Ok(dst)
+    }
+}
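
Passing row and column strides to gemm separately is what lets the kernel consume transposed operands without materializing a copy. A hypothetical worked case, not taken from the patch:

    // lhs of shape (b, m, k) = (2, 3, 4), contiguous, so lhs_stride = [12, 4, 1]:
    //   a_skip = m * k = 12                 // elements between batch matrices
    //   lhs_rs = lhs_stride[rank - 2] = 4   // step from one row to the next
    //   lhs_cs = lhs_stride[rank - 1] = 1   // step from one column to the next
    // A transposed lhs simply swaps lhs_rs and lhs_cs; only the batch-stride
    // check above restricts the accepted layouts for now.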
@@ -574,39 +613,13 @@ impl CpuStorage {
     pub(crate) fn where_cond(
         &self,
         layout: &Layout,
         t: &Self,
-        layout_t: &Layout,
+        t_l: &Layout,
         f: &Self,
-        layout_f: &Layout,
+        f_l: &Layout,
     ) -> Result<Self> {
         // TODO: Support types that could be cast to a boolean.
         let pred = self.as_slice::<u32>()?;
-        match (t, f) {
-            (Self::BF16(t), Self::BF16(f)) => {
-                let data = wcond(pred, layout, t, layout_t, f, layout_f);
-                Ok(Self::BF16(data))
-            }
-            (Self::F16(t), Self::F16(f)) => {
-                let data = wcond(pred, layout, t, layout_t, f, layout_f);
-                Ok(Self::F16(data))
-            }
-            (Self::F32(t), Self::F32(f)) => {
-                let data = wcond(pred, layout, t, layout_t, f, layout_f);
-                Ok(Self::F32(data))
-            }
-            (Self::F64(t), Self::F64(f)) => {
-                let data = wcond(pred, layout, t, layout_t, f, layout_f);
-                Ok(Self::F64(data))
-            }
-            (Self::U32(t), Self::U32(f)) => {
-                let data = wcond(pred, layout, t, layout_t, f, layout_f);
-                Ok(Self::U32(data))
-            }
-            _ => Err(Error::DTypeMismatchBinaryOp {
-                lhs: t.dtype(),
-                rhs: f.dtype(),
-                op: "where_cond",
-            }),
-        }
+        WCond(pred, layout).map(t, t_l, f, f_l)
     }
 
     pub(crate) fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
@@ -628,25 +641,7 @@ impl CpuStorage {
         lhs_l: &Layout,
         rhs_l: &Layout,
     ) -> Result<Self> {
-        match (self, rhs) {
-            (CpuStorage::F16(lhs), CpuStorage::F16(rhs)) => {
-                let dst = matmul(lhs, rhs, bmnk, lhs_l, rhs_l)?;
-                Ok(Self::F16(dst))
-            }
-            (CpuStorage::F32(lhs), CpuStorage::F32(rhs)) => {
-                let dst = matmul(lhs, rhs, bmnk, lhs_l, rhs_l)?;
-                Ok(Self::F32(dst))
-            }
-            (CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => {
-                let dst = matmul(lhs, rhs, bmnk, lhs_l, rhs_l)?;
-                Ok(Self::F64(dst))
-            }
-            _ => Err(Error::DTypeMismatchBinaryOp {
-                lhs: self.dtype(),
-                rhs: rhs.dtype(),
-                op: "matmul",
-            }),
-        }
+        MatMul(bmnk).map(self, lhs_l, rhs, rhs_l)
     }

From c583ee0f2cd62d1d820a57e248d5851c5f18145d Mon Sep 17 00:00:00 2001
From: laurent
Date: Wed, 28 Jun 2023 21:56:44 +0100
Subject: [PATCH 4/5] Factor some code out.

---
 candle-core/src/cpu_backend.rs | 111 ++++++++++-----------------------
 candle-core/src/dtype.rs       |  19 ++++--
 2 files changed, 47 insertions(+), 83 deletions(-)

diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 7170e470..136eeaba 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -292,6 +292,35 @@ impl Map2 for MatMul {
     }
 }
 
+fn divide_by_sum_over_dim<T: WithDType + num_traits::NumAssign>(
+    s: &mut [T],
+    shape: &Shape,
+    dim: usize,
+) -> Result<()> {
+    // [self] stores data in a contiguous way starting at offset 0.
+    let dims = shape.dims();
+    let elem_per_slice = dims[dim];
+    let prod_pre_dim = dims[..dim].iter().product();
+    let prod_post_dim = dims[dim + 1..].iter().product();
+    for pre_idx in 0..prod_pre_dim {
+        for post_idx in 0..prod_post_dim {
+            let mut sum = 0f64;
+            let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
+            for _ in 0..elem_per_slice {
+                sum += s[idx].to_f64();
+                idx += prod_post_dim
+            }
+            let sum = T::from_f64(sum);
+            let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
+            for _ in 0..elem_per_slice {
+                s[idx] /= sum;
+                idx += prod_post_dim
+            }
+        }
+    }
+    Ok(())
+}
+
 impl CpuStorage {
     pub fn dtype(&self) -> DType {
         match self {
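
The pre/post index walk in the factored-out helper is easier to see on a concrete case (hypothetical, not from the patch):

    // shape (2, 3, 4), dim = 1:
    //   elem_per_slice = 3, prod_pre_dim = 2, prod_post_dim = 4
    // For pre_idx = 0 and post_idx = 1, idx starts at 0 * 4 * 3 + 1 = 1 and
    // steps by prod_post_dim, visiting 1, 5, 9 -- exactly the slice s[0, :, 1].
    // The first pass accumulates that slice's sum in f64, the second divides
    // each element by it, so every slice along dim then sums to one.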
@@ -437,85 +466,13 @@ impl CpuStorage {
     pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
         // [self] stores data in a contiguous way starting at offset 0.
-        let dims = shape.dims();
-        let elem_per_slice = dims[dim];
-        let prod_pre_dim = dims[..dim].iter().product();
-        let prod_post_dim = dims[dim + 1..].iter().product();
         match self {
-            Self::BF16(storage) => {
-                for pre_idx in 0..prod_pre_dim {
-                    for post_idx in 0..prod_post_dim {
-                        let mut sum = 0f64;
-                        let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
-                        for _ in 0..elem_per_slice {
-                            sum += storage[idx].to_f64();
-                            idx += prod_post_dim
-                        }
-                        let sum = bf16::from_f64(sum);
-                        let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
-                        for _ in 0..elem_per_slice {
-                            storage[idx] /= sum;
-                            idx += prod_post_dim
-                        }
-                    }
-                }
-            }
-            Self::F16(storage) => {
-                for pre_idx in 0..prod_pre_dim {
-                    for post_idx in 0..prod_post_dim {
-                        let mut sum = 0f64;
-                        let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
-                        for _ in 0..elem_per_slice {
-                            sum += storage[idx].to_f64();
-                            idx += prod_post_dim
-                        }
-                        let sum = f16::from_f64(sum);
-                        let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
-                        for _ in 0..elem_per_slice {
-                            storage[idx] /= sum;
-                            idx += prod_post_dim
-                        }
-                    }
-                }
-            }
-            Self::F32(storage) => {
-                for pre_idx in 0..prod_pre_dim {
-                    for post_idx in 0..prod_post_dim {
-                        let mut sum = 0f64;
-                        let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
-                        for _ in 0..elem_per_slice {
-                            sum += storage[idx] as f64;
-                            idx += prod_post_dim
-                        }
-                        let sum = sum as f32;
-                        let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
-                        for _ in 0..elem_per_slice {
-                            storage[idx] /= sum;
-                            idx += prod_post_dim
-                        }
-                    }
-                }
-            }
-            Self::F64(storage) => {
-                for pre_idx in 0..prod_pre_dim {
-                    for post_idx in 0..prod_post_dim {
-                        let mut sum = 0f64;
-                        let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
-                        for _ in 0..elem_per_slice {
-                            sum += storage[idx];
-                            idx += prod_post_dim
-                        }
-                        let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx;
-                        for _ in 0..elem_per_slice {
-                            storage[idx] /= sum;
-                            idx += prod_post_dim
-                        }
-                    }
-                }
-            }
-            Self::U32(_) => {}
+            Self::BF16(s) => divide_by_sum_over_dim(s, shape, dim),
+            Self::F16(s) => divide_by_sum_over_dim(s, shape, dim),
+            Self::F32(s) => divide_by_sum_over_dim(s, shape, dim),
+            Self::F64(s) => divide_by_sum_over_dim(s, shape, dim),
+            Self::U32(_) => Ok(()),
         }
-        Ok(())
     }
 
     pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs
index 87e3c9b4..89655324 100644
--- a/candle-core/src/dtype.rs
+++ b/candle-core/src/dtype.rs
@@ -35,6 +35,7 @@ pub trait WithDType: Sized + Copy {
     const DTYPE: DType;
 
     fn from_f64(v: f64) -> Self;
+    fn to_f64(self) -> f64;
     fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
 
     fn to_cpu_storage(data: &[Self]) -> CpuStorage {
@@ -46,7 +47,7 @@ pub trait WithDType: Sized + Copy {
 macro_rules! with_dtype {
-    ($ty:ty, $dtype:ident, $from_f64:expr) => {
+    ($ty:ty, $dtype:ident, $from_f64:expr, $to_f64:expr) => {
         impl WithDType for $ty {
             const DTYPE: DType = DType::$dtype;
 
             fn from_f64(v: f64) -> Self {
                 $from_f64(v)
             }
 
+            fn to_f64(self) -> f64 {
+                $to_f64(self)
+            }
+
             fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
                 CpuStorage::$dtype(data)
             }
@@ -82,8 +87,10 @@ macro_rules! with_dtype {
         }
     };
 }
-with_dtype!(u32, U32, |v: f64| v as u32);
-with_dtype!(half::f16, F16, half::f16::from_f64);
-with_dtype!(half::bf16, BF16, half::bf16::from_f64);
-with_dtype!(f32, F32, |v: f64| v as f32);
-with_dtype!(f64, F64, |v: f64| v);
+use half::{bf16, f16};
+
+with_dtype!(u32, U32, |v: f64| v as u32, |v: u32| v as f64);
+with_dtype!(f16, F16, f16::from_f64, f16::to_f64);
+with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64);
+with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64);
+with_dtype!(f64, F64, |v: f64| v, |v: f64| v);
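
A note on the conversions this patch threads through with_dtype!: from_f64 and to_f64 are deliberately lossy for the narrower types. A quick sketch of the resulting behavior (crate-internal, values chosen for illustration):

    // u32 uses an `as` cast, which truncates toward zero.
    assert_eq!(<u32 as WithDType>::from_f64(3.9), 3);
    // f16 round-trips through f64 with reduced precision.
    let x = f16::from_f64(0.1);
    assert!((x.to_f64() - 0.1).abs() < 1e-3);
    // Accumulating in f64 and converting back once is exactly the pattern
    // divide_by_sum_over_dim relies on above.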
From eaa3ce359e8d378e08ea147bd0291a75b4ec76d0 Mon Sep 17 00:00:00 2001
From: laurent
Date: Wed, 28 Jun 2023 22:02:23 +0100
Subject: [PATCH 5/5] Cosmetic change.

---
 candle-core/src/cpu_backend.rs | 50 +++++++----------------
 1 file changed, 10 insertions(+), 40 deletions(-)

diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 136eeaba..a47d7c18 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -604,52 +604,22 @@ impl CpuStorage {
     pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self {
         let elem_count = shape.elem_count();
         match dtype {
-            DType::U32 => {
-                let data = vec![1u32; elem_count];
-                Self::U32(data)
-            }
-            DType::BF16 => {
-                let data = vec![bf16::ONE; elem_count];
-                Self::BF16(data)
-            }
-            DType::F16 => {
-                let data = vec![f16::ONE; elem_count];
-                Self::F16(data)
-            }
-            DType::F32 => {
-                let data = vec![1f32; elem_count];
-                Self::F32(data)
-            }
-            DType::F64 => {
-                let data = vec![1f64; elem_count];
-                Self::F64(data)
-            }
+            DType::U32 => Self::U32(vec![1u32; elem_count]),
+            DType::BF16 => Self::BF16(vec![bf16::ONE; elem_count]),
+            DType::F16 => Self::F16(vec![f16::ONE; elem_count]),
+            DType::F32 => Self::F32(vec![1f32; elem_count]),
+            DType::F64 => Self::F64(vec![1f64; elem_count]),
         }
     }
 
     pub(crate) fn zeros_impl(shape: &Shape, dtype: DType) -> Self {
         let elem_count = shape.elem_count();
         match dtype {
-            DType::U32 => {
-                let data = vec![0u32; elem_count];
-                Self::U32(data)
-            }
-            DType::BF16 => {
-                let data = vec![bf16::ZERO; elem_count];
-                Self::BF16(data)
-            }
-            DType::F16 => {
-                let data = vec![f16::ZERO; elem_count];
-                Self::F16(data)
-            }
-            DType::F32 => {
-                let data = vec![0f32; elem_count];
-                Self::F32(data)
-            }
-            DType::F64 => {
-                let data = vec![0f64; elem_count];
-                Self::F64(data)
-            }
+            DType::U32 => Self::U32(vec![0u32; elem_count]),
+            DType::BF16 => Self::BF16(vec![bf16::ZERO; elem_count]),
+            DType::F16 => Self::F16(vec![f16::ZERO; elem_count]),
+            DType::F32 => Self::F32(vec![0f32; elem_count]),
+            DType::F64 => Self::F64(vec![0f64; elem_count]),
         }
     }
 }