mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add Shape try into (#189)
* Add the TryInto trait for shapes. * Use the vectorized operations in block mode too.
This commit is contained in:
@ -168,24 +168,30 @@ fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
|
||||
block_start_index,
|
||||
block_len,
|
||||
} => {
|
||||
let mut result = vec![];
|
||||
result.reserve(layout.shape().elem_count());
|
||||
let el_count = layout.shape().elem_count();
|
||||
// Specialize the case where block_len is one to avoid the second loop.
|
||||
if block_len == 1 {
|
||||
let mut result = Vec::with_capacity(el_count);
|
||||
for index in block_start_index {
|
||||
let v = unsafe { vs.get_unchecked(index) };
|
||||
result.push(f(*v))
|
||||
}
|
||||
} else {
|
||||
// TODO: Use f_vec here.
|
||||
for index in block_start_index {
|
||||
for offset in 0..block_len {
|
||||
let v = unsafe { vs.get_unchecked(index + offset) };
|
||||
result.push(f(*v))
|
||||
}
|
||||
}
|
||||
}
|
||||
result
|
||||
} else {
|
||||
let mut ys: Vec<U> = Vec::with_capacity(el_count);
|
||||
let ys_to_set = ys.spare_capacity_mut();
|
||||
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
|
||||
let mut dst_index = 0;
|
||||
for src_index in block_start_index {
|
||||
let vs = &vs[src_index..src_index + block_len];
|
||||
let ys = &mut ys_to_set[dst_index..dst_index + block_len];
|
||||
f_vec(vs, ys);
|
||||
dst_index += block_len;
|
||||
}
|
||||
// SAFETY: values are all set by f_vec.
|
||||
unsafe { ys.set_len(el_count) };
|
||||
ys
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -73,6 +73,7 @@ impl From<Vec<usize>> for Shape {
|
||||
|
||||
macro_rules! extract_dims {
|
||||
($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => {
|
||||
impl Shape {
|
||||
pub fn $fn_name(&self) -> Result<$out_type> {
|
||||
if self.0.len() != $cnt {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
@ -85,6 +86,13 @@ macro_rules! extract_dims {
|
||||
Ok($dims(&self.0))
|
||||
}
|
||||
}
|
||||
}
|
||||
impl std::convert::TryInto<$out_type> for Shape {
|
||||
type Error = crate::Error;
|
||||
fn try_into(self) -> std::result::Result<$out_type, Self::Error> {
|
||||
self.$fn_name()
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@ -109,28 +117,6 @@ impl Shape {
|
||||
self.0.iter().product()
|
||||
}
|
||||
|
||||
extract_dims!(r0, 0, |_: &Vec<usize>| (), ());
|
||||
extract_dims!(r1, 1, |d: &[usize]| d[0], usize);
|
||||
extract_dims!(r2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
|
||||
extract_dims!(
|
||||
r3,
|
||||
3,
|
||||
|d: &[usize]| (d[0], d[1], d[2]),
|
||||
(usize, usize, usize)
|
||||
);
|
||||
extract_dims!(
|
||||
r4,
|
||||
4,
|
||||
|d: &[usize]| (d[0], d[1], d[2], d[3]),
|
||||
(usize, usize, usize, usize)
|
||||
);
|
||||
extract_dims!(
|
||||
r5,
|
||||
5,
|
||||
|d: &[usize]| (d[0], d[1], d[2], d[3], d[4]),
|
||||
(usize, usize, usize, usize, usize)
|
||||
);
|
||||
|
||||
/// The strides given in number of elements for a contiguous n-dimensional
|
||||
/// arrays using this shape.
|
||||
pub(crate) fn stride_contiguous(&self) -> Vec<usize> {
|
||||
@ -342,6 +328,28 @@ impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
|
||||
}
|
||||
}
|
||||
|
||||
extract_dims!(r0, 0, |_: &Vec<usize>| (), ());
|
||||
extract_dims!(r1, 1, |d: &[usize]| d[0], usize);
|
||||
extract_dims!(r2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
|
||||
extract_dims!(
|
||||
r3,
|
||||
3,
|
||||
|d: &[usize]| (d[0], d[1], d[2]),
|
||||
(usize, usize, usize)
|
||||
);
|
||||
extract_dims!(
|
||||
r4,
|
||||
4,
|
||||
|d: &[usize]| (d[0], d[1], d[2], d[3]),
|
||||
(usize, usize, usize, usize)
|
||||
);
|
||||
extract_dims!(
|
||||
r5,
|
||||
5,
|
||||
|d: &[usize]| (d[0], d[1], d[2], d[3], d[4]),
|
||||
(usize, usize, usize, usize, usize)
|
||||
);
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
@ -315,7 +315,6 @@ impl BertSelfAttention {
|
||||
new_x_shape.pop();
|
||||
new_x_shape.push(self.num_attention_heads);
|
||||
new_x_shape.push(self.attention_head_size);
|
||||
// Be cautious about the transposition if adding a batch dim!
|
||||
let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
|
||||
xs.contiguous()
|
||||
}
|
||||
|
Reference in New Issue
Block a user