From b706f32839c13aa68686f800dd2102d07770fcbd Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Tue, 18 Jul 2023 10:52:16 +0100
Subject: [PATCH] Add Shape try into (#189)

* Add the TryInto trait for shapes.

* Use the vectorized operations in block mode too.
---
 candle-core/src/cpu_backend.rs         | 24 +++++----
 candle-core/src/shape.rs               | 70 ++++++++++++++------
 candle-examples/examples/bert/model.rs |  1 -
 3 files changed, 54 insertions(+), 41 deletions(-)

diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 91ccd972..00bd3033 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -168,24 +168,30 @@ fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
             block_start_index,
             block_len,
         } => {
-            let mut result = vec![];
-            result.reserve(layout.shape().elem_count());
+            let el_count = layout.shape().elem_count();
             // Specialize the case where block_len is one to avoid the second loop.
             if block_len == 1 {
+                let mut result = Vec::with_capacity(el_count);
                 for index in block_start_index {
                     let v = unsafe { vs.get_unchecked(index) };
                     result.push(f(*v))
                 }
+                result
             } else {
-                // TODO: Use f_vec here.
-                for index in block_start_index {
-                    for offset in 0..block_len {
-                        let v = unsafe { vs.get_unchecked(index + offset) };
-                        result.push(f(*v))
-                    }
+                let mut ys: Vec<U> = Vec::with_capacity(el_count);
+                let ys_to_set = ys.spare_capacity_mut();
+                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
+                let mut dst_index = 0;
+                for src_index in block_start_index {
+                    let vs = &vs[src_index..src_index + block_len];
+                    let ys = &mut ys_to_set[dst_index..dst_index + block_len];
+                    f_vec(vs, ys);
+                    dst_index += block_len;
                 }
+                // SAFETY: values are all set by f_vec.
+                unsafe { ys.set_len(el_count) };
+                ys
             }
-            result
         }
     }
 }
diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index d3f8db01..982f9db0 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -73,16 +73,24 @@ impl From<Vec<usize>> for Shape {
 
 macro_rules! extract_dims {
     ($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => {
-        pub fn $fn_name(&self) -> Result<$out_type> {
-            if self.0.len() != $cnt {
-                Err(Error::UnexpectedNumberOfDims {
-                    expected: $cnt,
-                    got: self.0.len(),
-                    shape: self.clone(),
+        impl Shape {
+            pub fn $fn_name(&self) -> Result<$out_type> {
+                if self.0.len() != $cnt {
+                    Err(Error::UnexpectedNumberOfDims {
+                        expected: $cnt,
+                        got: self.0.len(),
+                        shape: self.clone(),
+                    }
+                    .bt())
+                } else {
+                    Ok($dims(&self.0))
                 }
-                .bt())
-            } else {
-                Ok($dims(&self.0))
+            }
+        }
+        impl std::convert::TryInto<$out_type> for Shape {
+            type Error = crate::Error;
+            fn try_into(self) -> std::result::Result<$out_type, Self::Error> {
+                self.$fn_name()
             }
         }
     };
@@ -109,28 +117,6 @@ impl Shape {
         self.0.iter().product()
     }
 
-    extract_dims!(r0, 0, |_: &Vec<usize>| (), ());
-    extract_dims!(r1, 1, |d: &[usize]| d[0], usize);
-    extract_dims!(r2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
-    extract_dims!(
-        r3,
-        3,
-        |d: &[usize]| (d[0], d[1], d[2]),
-        (usize, usize, usize)
-    );
-    extract_dims!(
-        r4,
-        4,
-        |d: &[usize]| (d[0], d[1], d[2], d[3]),
-        (usize, usize, usize, usize)
-    );
-    extract_dims!(
-        r5,
-        5,
-        |d: &[usize]| (d[0], d[1], d[2], d[3], d[4]),
-        (usize, usize, usize, usize, usize)
-    );
-
     /// The strides given in number of elements for a contiguous n-dimensional
     /// arrays using this shape.
     pub(crate) fn stride_contiguous(&self) -> Vec<usize> {
@@ -342,6 +328,28 @@ impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
     }
 }
 
+extract_dims!(r0, 0, |_: &Vec<usize>| (), ());
+extract_dims!(r1, 1, |d: &[usize]| d[0], usize);
+extract_dims!(r2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
+extract_dims!(
+    r3,
+    3,
+    |d: &[usize]| (d[0], d[1], d[2]),
+    (usize, usize, usize)
+);
+extract_dims!(
+    r4,
+    4,
+    |d: &[usize]| (d[0], d[1], d[2], d[3]),
+    (usize, usize, usize, usize)
+);
+extract_dims!(
+    r5,
+    5,
+    |d: &[usize]| (d[0], d[1], d[2], d[3], d[4]),
+    (usize, usize, usize, usize, usize)
+);
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/candle-examples/examples/bert/model.rs b/candle-examples/examples/bert/model.rs
index 7323db4b..fa0e8c76 100644
--- a/candle-examples/examples/bert/model.rs
+++ b/candle-examples/examples/bert/model.rs
@@ -315,7 +315,6 @@ impl BertSelfAttention {
         new_x_shape.pop();
         new_x_shape.push(self.num_attention_heads);
         new_x_shape.push(self.attention_head_size);
-        // Be cautious about the transposition if adding a batch dim!
        let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
        xs.contiguous()
    }
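
Usage note (not part of the patch): a minimal sketch of how the TryInto conversions generated by the reworked extract_dims! macro could be called, assuming the crate is imported as candle_core (inside the workspace it may be aliased to candle) and that Shape keeps its existing From<(usize, usize)> constructor and the r2/r3 accessors shown above. The variable names and the extra clone() are illustrative only.

    use candle_core::{Result, Shape};

    fn main() -> Result<()> {
        // A rank-2 shape built from a tuple.
        let shape = Shape::from((2, 3));

        // Existing accessor generated by `extract_dims!`.
        let (d0, d1) = shape.r2()?;
        assert_eq!((d0, d1), (2, 3));

        // New in this patch: the same extraction through `TryInto`,
        // which consumes the shape, hence the `clone` in this example.
        let dims: (usize, usize) = shape.clone().try_into()?;
        assert_eq!(dims, (2, 3));

        // A rank mismatch surfaces as `Error::UnexpectedNumberOfDims`.
        let bad: std::result::Result<(usize, usize, usize), _> = shape.try_into();
        assert!(bad.is_err());
        Ok(())
    }

Taking Shape by value mirrors the standard TryInto signature; callers that only hold a &Shape can either clone or keep using the rN accessors, which borrow.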