Add Shape try into (#189)

* Add the TryInto trait for shapes.

* Use the vectorized operations in block mode too.
This commit is contained in:
Laurent Mazare
2023-07-18 10:52:16 +01:00
committed by GitHub
parent d6313d2447
commit b706f32839
3 changed files with 54 additions and 41 deletions

View File

@ -168,24 +168,30 @@ fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
block_start_index,
block_len,
} => {
let mut result = vec![];
result.reserve(layout.shape().elem_count());
let el_count = layout.shape().elem_count();
// Specialize the case where block_len is one to avoid the second loop.
if block_len == 1 {
let mut result = Vec::with_capacity(el_count);
for index in block_start_index {
let v = unsafe { vs.get_unchecked(index) };
result.push(f(*v))
}
} else {
// TODO: Use f_vec here.
for index in block_start_index {
for offset in 0..block_len {
let v = unsafe { vs.get_unchecked(index + offset) };
result.push(f(*v))
}
}
}
result
} else {
let mut ys: Vec<U> = Vec::with_capacity(el_count);
let ys_to_set = ys.spare_capacity_mut();
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
let mut dst_index = 0;
for src_index in block_start_index {
let vs = &vs[src_index..src_index + block_len];
let ys = &mut ys_to_set[dst_index..dst_index + block_len];
f_vec(vs, ys);
dst_index += block_len;
}
// SAFETY: values are all set by f_vec.
unsafe { ys.set_len(el_count) };
ys
}
}
}
}

View File

@ -73,6 +73,7 @@ impl From<Vec<usize>> for Shape {
macro_rules! extract_dims {
($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => {
impl Shape {
pub fn $fn_name(&self) -> Result<$out_type> {
if self.0.len() != $cnt {
Err(Error::UnexpectedNumberOfDims {
@ -85,6 +86,13 @@ macro_rules! extract_dims {
Ok($dims(&self.0))
}
}
}
impl std::convert::TryInto<$out_type> for Shape {
type Error = crate::Error;
fn try_into(self) -> std::result::Result<$out_type, Self::Error> {
self.$fn_name()
}
}
};
}
@ -109,28 +117,6 @@ impl Shape {
self.0.iter().product()
}
extract_dims!(r0, 0, |_: &Vec<usize>| (), ());
extract_dims!(r1, 1, |d: &[usize]| d[0], usize);
extract_dims!(r2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
extract_dims!(
r3,
3,
|d: &[usize]| (d[0], d[1], d[2]),
(usize, usize, usize)
);
extract_dims!(
r4,
4,
|d: &[usize]| (d[0], d[1], d[2], d[3]),
(usize, usize, usize, usize)
);
extract_dims!(
r5,
5,
|d: &[usize]| (d[0], d[1], d[2], d[3], d[4]),
(usize, usize, usize, usize, usize)
);
/// The strides, given in number of elements, for a contiguous n-dimensional
/// array using this shape.
pub(crate) fn stride_contiguous(&self) -> Vec<usize> {
@ -342,6 +328,28 @@ impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
}
}
// Instantiate `extract_dims!` once per supported rank (0 through 5).
// Each invocation generates a `Shape` accessor (`r0` .. `r5`) that returns the
// dimensions as a fixed-size value (unit, a single `usize`, or a tuple) and
// yields `Error::UnexpectedNumberOfDims` when the shape does not have exactly
// the requested number of dimensions. The macro also emits a matching
// `TryInto<$out_type>` impl for `Shape` that forwards to that accessor.
extract_dims!(r0, 0, |_: &Vec<usize>| (), ());
extract_dims!(r1, 1, |d: &[usize]| d[0], usize);
extract_dims!(r2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
extract_dims!(
r3,
3,
|d: &[usize]| (d[0], d[1], d[2]),
(usize, usize, usize)
);
extract_dims!(
r4,
4,
|d: &[usize]| (d[0], d[1], d[2], d[3]),
(usize, usize, usize, usize)
);
extract_dims!(
r5,
5,
|d: &[usize]| (d[0], d[1], d[2], d[3], d[4]),
(usize, usize, usize, usize, usize)
);
#[cfg(test)]
mod tests {
use super::*;

View File

@ -315,7 +315,6 @@ impl BertSelfAttention {
new_x_shape.pop();
new_x_shape.push(self.num_attention_heads);
new_x_shape.push(self.attention_head_size);
// Be cautious about the transposition if adding a batch dim!
let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
xs.contiguous()
}