From 1328b5cb206dce9135c3d73dd6a5bd08f5e16d6a Mon Sep 17 00:00:00 2001 From: laurent Date: Wed, 28 Jun 2023 21:56:44 +0100 Subject: [PATCH] Factor some code out. --- candle-core/src/cpu_backend.rs | 111 ++++++++++----------------------- candle-core/src/dtype.rs | 19 ++++-- 2 files changed, 47 insertions(+), 83 deletions(-) diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 7170e470..136eeaba 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -292,6 +292,35 @@ impl Map2 for MatMul { } } +fn divide_by_sum_over_dim( + s: &mut [T], + shape: &Shape, + dim: usize, +) -> Result<()> { + // [self] stores data in a contiguous way starting at offset 0. + let dims = shape.dims(); + let elem_per_slice = dims[dim]; + let prod_pre_dim = dims[..dim].iter().product(); + let prod_post_dim = dims[dim + 1..].iter().product(); + for pre_idx in 0..prod_pre_dim { + for post_idx in 0..prod_post_dim { + let mut sum = 0f64; + let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx; + for _ in 0..elem_per_slice { + sum += s[idx].to_f64(); + idx += prod_post_dim + } + let sum = T::from_f64(sum); + let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx; + for _ in 0..elem_per_slice { + s[idx] /= sum; + idx += prod_post_dim + } + } + } + Ok(()) +} + impl CpuStorage { pub fn dtype(&self) -> DType { match self { @@ -437,85 +466,13 @@ impl CpuStorage { pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> { // [self] stores data in a contiguous way starting at offset 0. - let dims = shape.dims(); - let elem_per_slice = dims[dim]; - let prod_pre_dim = dims[..dim].iter().product(); - let prod_post_dim = dims[dim + 1..].iter().product(); match self { - Self::BF16(storage) => { - for pre_idx in 0..prod_pre_dim { - for post_idx in 0..prod_post_dim { - let mut sum = 0f64; - let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx; - for _ in 0..elem_per_slice { - sum += storage[idx].to_f64(); - idx += prod_post_dim - } - let sum = bf16::from_f64(sum); - let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx; - for _ in 0..elem_per_slice { - storage[idx] /= sum; - idx += prod_post_dim - } - } - } - } - Self::F16(storage) => { - for pre_idx in 0..prod_pre_dim { - for post_idx in 0..prod_post_dim { - let mut sum = 0f64; - let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx; - for _ in 0..elem_per_slice { - sum += storage[idx].to_f64(); - idx += prod_post_dim - } - let sum = f16::from_f64(sum); - let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx; - for _ in 0..elem_per_slice { - storage[idx] /= sum; - idx += prod_post_dim - } - } - } - } - Self::F32(storage) => { - for pre_idx in 0..prod_pre_dim { - for post_idx in 0..prod_post_dim { - let mut sum = 0f64; - let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx; - for _ in 0..elem_per_slice { - sum += storage[idx] as f64; - idx += prod_post_dim - } - let sum = sum as f32; - let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx; - for _ in 0..elem_per_slice { - storage[idx] /= sum; - idx += prod_post_dim - } - } - } - } - Self::F64(storage) => { - for pre_idx in 0..prod_pre_dim { - for post_idx in 0..prod_post_dim { - let mut sum = 0f64; - let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx; - for _ in 0..elem_per_slice { - sum += storage[idx]; - idx += prod_post_dim - } - let mut idx = pre_idx * prod_post_dim * elem_per_slice + post_idx; - for _ in 0..elem_per_slice { - storage[idx] /= sum; - idx += prod_post_dim - } - } - } - } - Self::U32(_) => {} + Self::BF16(s) => divide_by_sum_over_dim(s, shape, dim), + Self::F16(s) => divide_by_sum_over_dim(s, shape, dim), + Self::F32(s) => divide_by_sum_over_dim(s, shape, dim), + Self::F64(s) => divide_by_sum_over_dim(s, shape, dim), + Self::U32(_) => Ok(()), } - Ok(()) } pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result { diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs index 87e3c9b4..89655324 100644 --- a/candle-core/src/dtype.rs +++ b/candle-core/src/dtype.rs @@ -35,6 +35,7 @@ pub trait WithDType: Sized + Copy { const DTYPE: DType; fn from_f64(v: f64) -> Self; + fn to_f64(self) -> f64; fn to_cpu_storage_owned(data: Vec) -> CpuStorage; fn to_cpu_storage(data: &[Self]) -> CpuStorage { @@ -46,7 +47,7 @@ pub trait WithDType: Sized + Copy { } macro_rules! with_dtype { - ($ty:ty, $dtype:ident, $from_f64:expr) => { + ($ty:ty, $dtype:ident, $from_f64:expr, $to_f64:expr) => { impl WithDType for $ty { const DTYPE: DType = DType::$dtype; @@ -54,6 +55,10 @@ macro_rules! with_dtype { $from_f64(v) } + fn to_f64(self) -> f64 { + $to_f64(self) + } + fn to_cpu_storage_owned(data: Vec) -> CpuStorage { CpuStorage::$dtype(data) } @@ -82,8 +87,10 @@ macro_rules! with_dtype { } }; } -with_dtype!(u32, U32, |v: f64| v as u32); -with_dtype!(half::f16, F16, half::f16::from_f64); -with_dtype!(half::bf16, BF16, half::bf16::from_f64); -with_dtype!(f32, F32, |v: f64| v as f32); -with_dtype!(f64, F64, |v: f64| v); +use half::{bf16, f16}; + +with_dtype!(u32, U32, |v: f64| v as u32, |v: u32| v as f64); +with_dtype!(f16, F16, f16::from_f64, f16::to_f64); +with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64); +with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64); +with_dtype!(f64, F64, |v: f64| v, |v: f64| v);