From 54a6c40f2715d4ba6018d047bd3ec678fc0f3664 Mon Sep 17 00:00:00 2001
From: laurent
Date: Wed, 28 Jun 2023 14:00:49 +0100
Subject: [PATCH] Propagate the changes on the cpu backend.

---
 candle-core/src/cpu_backend.rs | 159 ++++++++++++++++-----------------
 candle-core/src/layout.rs      |  11 +++
 2 files changed, 89 insertions(+), 81 deletions(-)

diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 8a0cdb31..4f63ea98 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -24,25 +24,25 @@ fn wcond(
     f: &[T],
     layout_f: &Layout,
 ) -> Vec<T> {
-    if layout.is_contiguous() && layout_t.is_contiguous() && layout_f.is_contiguous() {
-        let elem_count = layout.shape().elem_count();
-        let offset = layout.start_offset();
-        let offset_t = layout_t.start_offset();
-        let offset_f = layout_f.start_offset();
-        let pred = &pred[offset..offset + elem_count];
-        let t = &t[offset_t..offset_t + elem_count];
-        let f = &f[offset_f..offset_f + elem_count];
-        pred.iter()
-            .zip(t.iter().zip(f.iter()))
-            .map(|(&p, (&t, &f))| if p > 0 { t } else { f })
-            .collect::<Vec<_>>()
-    } else {
-        let it_p = StridedIndex::new(layout);
-        let it_t = StridedIndex::new(layout_t);
-        let it_f = StridedIndex::new(layout_f);
-        it_p.zip(it_t.zip(it_f))
+    match (
+        layout.contiguous_offsets(),
+        layout_t.contiguous_offsets(),
+        layout_f.contiguous_offsets(),
+    ) {
+        (Some((o1, o2)), Some((o_t1, o_t2)), Some((o_f1, o_f2))) => {
+            let pred = &pred[o1..o2];
+            let t = &t[o_t1..o_t2];
+            let f = &f[o_f1..o_f2];
+            pred.iter()
+                .zip(t.iter().zip(f.iter()))
+                .map(|(&p, (&t, &f))| if p > 0 { t } else { f })
+                .collect::<Vec<_>>()
+        }
+        _ => layout
+            .strided_index()
+            .zip(layout_t.strided_index().zip(layout_f.strided_index()))
             .map(|(i_p, (i_t, i_f))| if pred[i_p] > 0 { t[i_t] } else { f[i_f] })
-            .collect::<Vec<_>>()
+            .collect::<Vec<_>>(),
     }
 }
 
@@ -62,42 +62,38 @@ macro_rules! map1 {
 fn sum_impl1<T: Copy + num_traits::NumAssign>(
     src: &[T],
     dst_shape: &Shape,
-    src_dims: &[usize],
-    stride: &[usize],
+    src_layout: &Layout,
     to_dst_index: impl Fn(usize) -> usize,
 ) -> Result<Vec<T>> {
     let mut dst = vec![T::zero(); dst_shape.elem_count()];
-    for (unstr_index, src_index) in StridedIndex::new(src_dims, stride).enumerate() {
+    for (unstr_index, src_index) in src_layout.strided_index().enumerate() {
        dst[to_dst_index(unstr_index)] += src[src_index];
     }
     Ok(dst)
 }
 
 fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(vs: &[T], layout: &Layout, mut f: F) -> Vec<U> {
-    if shape.is_contiguous(stride) {
-        vs[..shape.elem_count()].iter().map(|&v| f(v)).collect()
-    } else {
-        StridedIndex::new(shape.dims(), stride)
-            .map(|i| f(vs[i]))
-            .collect()
+    match layout.contiguous_offsets() {
+        Some((o1, o2)) => vs[o1..o2].iter().map(|&v| f(v)).collect(),
+        None => layout.strided_index().map(|i| f(vs[i])).collect(),
     }
 }
 
 // This function maps over two strided index sequences.
 fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
     shape: &Shape,
-    lhs_stride: &[usize],
-    rhs_stride: &[usize],
+    lhs_layout: &Layout,
+    rhs_layout: &Layout,
     lhs: &[T],
     rhs: &[T],
     mut f: F,
 ) -> Vec<T> {
     let dims = shape.dims();
-    if shape.is_contiguous(lhs_stride) && shape.is_contiguous(rhs_stride) {
+    if lhs_layout.is_contiguous() && rhs_layout.is_contiguous() {
         (0..shape.elem_count()).map(|i| f(lhs[i], rhs[i])).collect()
     } else {
-        let lhs_index = StridedIndex::new(dims, lhs_stride);
-        let rhs_index = StridedIndex::new(dims, rhs_stride);
+        let lhs_index = lhs_layout.strided_index();
+        let rhs_index = rhs_layout.strided_index();
         lhs_index
             .zip(rhs_index)
             .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
@@ -114,7 +110,7 @@ fn take_impl1(
 ) -> Result<Vec<T>> {
     // TODO: Optimize for the case where ids are contiguous.
     let mut values = Vec::with_capacity(layout.shape().elem_count() * hidden_size);
-    for index in StridedIndex::new(layout) {
+    for index in layout.strided_index() {
         let index = ids[index].try_into()?;
         if index >= vocab_size {
             return Err(Error::InvalidIndex {
@@ -135,18 +131,19 @@ fn copy_strided_src_(
     dst_offset: usize,
     src_l: &Layout,
 ) {
-    let src = &src[src_l.start_offset()..];
-    if src_l.is_contiguous() {
-        let elem_to_copy = (dst.len() - dst_offset).min(src.len());
-        dst[dst_offset..dst_offset + elem_to_copy].copy_from_slice(&src[..elem_to_copy])
-    } else {
-        let src_indexes = StridedIndex::new(src_l);
-        for (dst_index, src_index) in src_indexes.enumerate() {
-            let dst_index = dst_index + dst_offset;
-            if dst_index >= dst.len() {
-                break;
+    match src_l.contiguous_offsets() {
+        Some((o_dst1, o_dst2)) => {
+            let elem_to_copy = (dst.len() - dst_offset).min(o_dst2 - o_dst1);
+            dst[dst_offset..dst_offset + elem_to_copy].copy_from_slice(&src[o_dst1..o_dst2])
+        }
+        None => {
+            for (dst_index, src_index) in src_l.strided_index().enumerate() {
+                let dst_index = dst_index + dst_offset;
+                if dst_index >= dst.len() {
+                    break;
+                }
+                dst[dst_index] = src[src_index]
             }
-            dst[dst_index] = src[src_index]
         }
     }
 }
@@ -235,114 +232,114 @@ impl CpuStorage {
         D::cpu_storage_as_mut_slice(self)
     }
 
-    pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
+    pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
         // TODO: find a way around the quadratic number of cases below.
         match (self, dtype) {
             (Self::U32(storage), DType::BF16) => {
-                let data = unary_map(storage, shape, stride, |v| bf16::from_f32(v as f32));
+                let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
                 Ok(Self::BF16(data))
             }
             (Self::BF16(storage), DType::BF16) => {
-                let data = unary_map(storage, shape, stride, |v| v);
+                let data = unary_map(storage, layout, |v| v);
                 Ok(Self::BF16(data))
             }
             (Self::F16(storage), DType::BF16) => {
-                let data = unary_map(storage, shape, stride, |v| bf16::from_f32(v.to_f32()));
+                let data = unary_map(storage, layout, |v| bf16::from_f32(v.to_f32()));
                 Ok(Self::BF16(data))
             }
             (Self::F32(storage), DType::BF16) => {
-                let data = unary_map(storage, shape, stride, bf16::from_f32);
+                let data = unary_map(storage, layout, bf16::from_f32);
                 Ok(Self::BF16(data))
             }
             (Self::F64(storage), DType::BF16) => {
-                let data = unary_map(storage, shape, stride, bf16::from_f64);
+                let data = unary_map(storage, layout, bf16::from_f64);
                 Ok(Self::BF16(data))
             }
             (Self::U32(storage), DType::F16) => {
-                let data = unary_map(storage, shape, stride, |v| f16::from_f32(v as f32));
+                let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
                 Ok(Self::F16(data))
             }
             (Self::BF16(storage), DType::F16) => {
-                let data = unary_map(storage, shape, stride, |v| f16::from_f32(v.to_f32()));
+                let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32()));
                 Ok(Self::F16(data))
             }
             (Self::F16(storage), DType::F16) => {
-                let data = unary_map(storage, shape, stride, |v| v);
+                let data = unary_map(storage, layout, |v| v);
                 Ok(Self::F16(data))
             }
             (Self::F32(storage), DType::F16) => {
-                let data = unary_map(storage, shape, stride, f16::from_f32);
+                let data = unary_map(storage, layout, f16::from_f32);
                 Ok(Self::F16(data))
             }
             (Self::F64(storage), DType::F16) => {
-                let data = unary_map(storage, shape, stride, f16::from_f64);
+                let data = unary_map(storage, layout, f16::from_f64);
                 Ok(Self::F16(data))
             }
             (Self::U32(storage), DType::F32) => {
-                let data = unary_map(storage, shape, stride, |v| v as f32);
+                let data = unary_map(storage, layout, |v| v as f32);
                 Ok(Self::F32(data))
             }
             (Self::BF16(storage), DType::F32) => {
-                let data = unary_map(storage, shape, stride, |v| v.to_f32());
+                let data = unary_map(storage, layout, |v| v.to_f32());
                 Ok(Self::F32(data))
             }
             (Self::F16(storage), DType::F32) => {
-                let data = unary_map(storage, shape, stride, |v| v.to_f32());
+                let data = unary_map(storage, layout, |v| v.to_f32());
                 Ok(Self::F32(data))
             }
             (Self::F32(storage), DType::F32) => {
-                let data = unary_map(storage, shape, stride, |v| v);
+                let data = unary_map(storage, layout, |v| v);
                 Ok(Self::F32(data))
             }
             (Self::F64(storage), DType::F32) => {
-                let data = unary_map(storage, shape, stride, |v| v as f32);
+                let data = unary_map(storage, layout, |v| v as f32);
                 Ok(Self::F32(data))
             }
             (Self::U32(storage), DType::U32) => {
-                let data = unary_map(storage, shape, stride, |v| v);
+                let data = unary_map(storage, layout, |v| v);
                 Ok(Self::U32(data))
             }
             (Self::BF16(storage), DType::U32) => {
-                let data = unary_map(storage, shape, stride, |v| v.to_f32() as u32);
+                let data = unary_map(storage, layout, |v| v.to_f32() as u32);
                 Ok(Self::U32(data))
             }
             (Self::F16(storage), DType::U32) => {
-                let data = unary_map(storage, shape, stride, |v| v.to_f32() as u32);
+                let data = unary_map(storage, layout, |v| v.to_f32() as u32);
                 Ok(Self::U32(data))
             }
             (Self::F32(storage), DType::U32) => {
-                let data = unary_map(storage, shape, stride, |v| v as u32);
+                let data = unary_map(storage, layout, |v| v as u32);
                 Ok(Self::U32(data))
             }
             (Self::F64(storage), DType::U32) => {
-                let data = unary_map(storage, shape, stride, |v| v as u32);
+                let data = unary_map(storage, layout, |v| v as u32);
                 Ok(Self::U32(data))
             }
             (Self::U32(storage), DType::F64) => {
-                let data = unary_map(storage, shape, stride, |v| v as f64);
+                let data = unary_map(storage, layout, |v| v as f64);
                 Ok(Self::F64(data))
             }
             (Self::BF16(storage), DType::F64) => {
-                let data = unary_map(storage, shape, stride, |v| v.to_f64());
+                let data = unary_map(storage, layout, |v| v.to_f64());
                 Ok(Self::F64(data))
             }
             (Self::F16(storage), DType::F64) => {
-                let data = unary_map(storage, shape, stride, |v| v.to_f64());
+                let data = unary_map(storage, layout, |v| v.to_f64());
                 Ok(Self::F64(data))
             }
             (Self::F32(storage), DType::F64) => {
-                let data = unary_map(storage, shape, stride, |v| v as f64);
+                let data = unary_map(storage, layout, |v| v as f64);
                 Ok(Self::F64(data))
             }
             (Self::F64(storage), DType::F64) => {
-                let data = unary_map(storage, shape, stride, |v| v);
+                let data = unary_map(storage, layout, |v| v);
                 Ok(Self::F64(data))
             }
         }
     }
 
-    pub(crate) fn sum(&self, shape: &Shape, stride: &[usize], sum_dims: &[usize]) -> Result<Self> {
-        let src_dims = shape.dims();
+    pub(crate) fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
+        let src_dims = layout.dims();
         let mut dst_dims = src_dims.to_vec();
         for &sum_dim in sum_dims.iter() {
             dst_dims[sum_dim] = 1;
@@ -368,7 +365,7 @@ impl CpuStorage {
             dst_index
         };
         // TODO: Maybe provide an implementation with higher precision accumulators?
-        map1!(self, sum_impl1, &dst_shape, src_dims, stride, to_dst_index)
+        map1!(self, sum_impl1, &dst_shape, layout, to_dst_index)
     }
 
     pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
@@ -516,28 +513,28 @@ impl CpuStorage {
         &self,
         rhs: &Self,
         shape: &Shape,
-        lhs_stride: &[usize],
-        rhs_stride: &[usize],
+        lhs_layout: &Layout,
+        rhs_layout: &Layout,
     ) -> Result<Self> {
         match (self, rhs) {
             (Self::BF16(lhs), Self::BF16(rhs)) => {
-                let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::bf16);
+                let data = binary_map(shape, lhs_layout, rhs_layout, lhs, rhs, B::bf16);
                 Ok(Self::BF16(data))
             }
             (Self::F16(lhs), Self::F16(rhs)) => {
-                let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f16);
+                let data = binary_map(shape, lhs_layout, rhs_layout, lhs, rhs, B::f16);
                 Ok(Self::F16(data))
             }
             (Self::F32(lhs), Self::F32(rhs)) => {
-                let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f32);
+                let data = binary_map(shape, lhs_layout, rhs_layout, lhs, rhs, B::f32);
                 Ok(Self::F32(data))
             }
             (Self::F64(lhs), Self::F64(rhs)) => {
-                let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f64);
+                let data = binary_map(shape, lhs_layout, rhs_layout, lhs, rhs, B::f64);
                 Ok(Self::F64(data))
             }
             (Self::U32(lhs), Self::U32(rhs)) => {
-                let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::u32);
+                let data = binary_map(shape, lhs_layout, rhs_layout, lhs, rhs, B::u32);
                 Ok(Self::U32(data))
             }
             _ => {
@@ -555,7 +552,7 @@ impl CpuStorage {
         &self,
         dst: &mut Self,
         dst_offset: usize,
-        src_l: Layout,
+        src_l: &Layout,
     ) -> Result<()> {
         match (self, dst) {
             (Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
diff --git a/candle-core/src/layout.rs b/candle-core/src/layout.rs
index 08d34c4b..6ba0d79a 100644
--- a/candle-core/src/layout.rs
+++ b/candle-core/src/layout.rs
@@ -39,6 +39,17 @@ impl Layout {
         self.start_offset
     }
 
+    /// Returns the appropriate start and stop offset if the data is stored in a C
+    /// contiguous (aka row major) way.
+    pub fn contiguous_offsets(&self) -> Option<(usize, usize)> {
+        if self.is_contiguous() {
+            let start_o = self.start_offset;
+            Some((start_o, start_o + self.shape.elem_count()))
+        } else {
+            None
+        }
+    }
+
     /// Returns true if the data is stored in a C contiguous (aka row major) way.
     pub fn is_contiguous(&self) -> bool {
         self.shape.is_contiguous(&self.stride)
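
Below the patch, a minimal, self-contained sketch of the access pattern the diff converges on: ask the layout for its contiguous offsets, take a slice-based fast path when they exist, and otherwise fall back to strided indexing. `MiniLayout` and `main` are hypothetical, illustration-only names rather than candle APIs; the real `Layout` also exposes `strided_index()`, which the fallback arm below only hints at.

// Sketch of the contiguous-fast-path pattern, assuming a simplified stand-in layout type.
struct MiniLayout {
    dims: Vec<usize>,
    stride: Vec<usize>,
    start_offset: usize,
}

impl MiniLayout {
    fn elem_count(&self) -> usize {
        self.dims.iter().product()
    }

    // C-contiguous check: the strides must be exactly the row-major strides of `dims`.
    fn is_contiguous(&self) -> bool {
        let mut expected = 1;
        for (&d, &s) in self.dims.iter().zip(self.stride.iter()).rev() {
            if s != expected {
                return false;
            }
            expected *= d;
        }
        true
    }

    // Mirrors the `contiguous_offsets` helper added in layout.rs: the (start, end)
    // range into the flat buffer when the data is stored contiguously.
    fn contiguous_offsets(&self) -> Option<(usize, usize)> {
        if self.is_contiguous() {
            Some((self.start_offset, self.start_offset + self.elem_count()))
        } else {
            None
        }
    }
}

// The unary-map pattern from cpu_backend.rs: slice directly on the contiguous fast
// path; the strided fallback (layout.strided_index() in candle) is elided here.
fn unary_map<T: Copy, U, F: FnMut(T) -> U>(vs: &[T], layout: &MiniLayout, mut f: F) -> Vec<U> {
    match layout.contiguous_offsets() {
        Some((o1, o2)) => vs[o1..o2].iter().map(|&v| f(v)).collect(),
        None => unimplemented!("the real code iterates over layout.strided_index()"),
    }
}

fn main() {
    // A 2x3 row-major view starting one element into the buffer.
    let layout = MiniLayout { dims: vec![2, 3], stride: vec![3, 1], start_offset: 1 };
    let data = [0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let doubled = unary_map(&data, &layout, |v| v * 2.0);
    println!("{doubled:?}"); // [2.0, 4.0, 6.0, 8.0, 10.0, 12.0]
}

The draw of `contiguous_offsets` over a bare `is_contiguous` check, as seen in the wcond and copy hunks above, is that it hands back the exact start..end range in one call, so each call site no longer recomputes `start_offset() + elem_count()` by hand.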