mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Add Shape try into (#189)
* Add the TryInto trait for shapes.
* Use the vectorized operations in block mode too.
This commit is contained in:
@ -168,24 +168,30 @@ fn unary_map_vec<T: Copy, U: Copy, F: FnMut(T) -> U, FV: FnMut(&[T], &mut [U])>(
|
|||||||
block_start_index,
|
block_start_index,
|
||||||
block_len,
|
block_len,
|
||||||
} => {
|
} => {
|
||||||
let mut result = vec![];
|
let el_count = layout.shape().elem_count();
|
||||||
result.reserve(layout.shape().elem_count());
|
|
||||||
// Specialize the case where block_len is one to avoid the second loop.
|
// Specialize the case where block_len is one to avoid the second loop.
|
||||||
if block_len == 1 {
|
if block_len == 1 {
|
||||||
|
let mut result = Vec::with_capacity(el_count);
|
||||||
for index in block_start_index {
|
for index in block_start_index {
|
||||||
let v = unsafe { vs.get_unchecked(index) };
|
let v = unsafe { vs.get_unchecked(index) };
|
||||||
result.push(f(*v))
|
result.push(f(*v))
|
||||||
}
|
}
|
||||||
|
result
|
||||||
} else {
|
} else {
|
||||||
// TODO: Use f_vec here.
|
let mut ys: Vec<U> = Vec::with_capacity(el_count);
|
||||||
for index in block_start_index {
|
let ys_to_set = ys.spare_capacity_mut();
|
||||||
for offset in 0..block_len {
|
let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) };
|
||||||
let v = unsafe { vs.get_unchecked(index + offset) };
|
let mut dst_index = 0;
|
||||||
result.push(f(*v))
|
for src_index in block_start_index {
|
||||||
}
|
let vs = &vs[src_index..src_index + block_len];
|
||||||
|
let ys = &mut ys_to_set[dst_index..dst_index + block_len];
|
||||||
|
f_vec(vs, ys);
|
||||||
|
dst_index += block_len;
|
||||||
}
|
}
|
||||||
|
// SAFETY: values are all set by f_vec.
|
||||||
|
unsafe { ys.set_len(el_count) };
|
||||||
|
ys
|
||||||
}
|
}
|
||||||
result
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -73,16 +73,24 @@ impl From<Vec<usize>> for Shape {
|
|||||||
|
|
||||||
macro_rules! extract_dims {
|
macro_rules! extract_dims {
|
||||||
($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => {
|
($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => {
|
||||||
pub fn $fn_name(&self) -> Result<$out_type> {
|
impl Shape {
|
||||||
if self.0.len() != $cnt {
|
pub fn $fn_name(&self) -> Result<$out_type> {
|
||||||
Err(Error::UnexpectedNumberOfDims {
|
if self.0.len() != $cnt {
|
||||||
expected: $cnt,
|
Err(Error::UnexpectedNumberOfDims {
|
||||||
got: self.0.len(),
|
expected: $cnt,
|
||||||
shape: self.clone(),
|
got: self.0.len(),
|
||||||
|
shape: self.clone(),
|
||||||
|
}
|
||||||
|
.bt())
|
||||||
|
} else {
|
||||||
|
Ok($dims(&self.0))
|
||||||
}
|
}
|
||||||
.bt())
|
}
|
||||||
} else {
|
}
|
||||||
Ok($dims(&self.0))
|
impl std::convert::TryInto<$out_type> for Shape {
|
||||||
|
type Error = crate::Error;
|
||||||
|
fn try_into(self) -> std::result::Result<$out_type, Self::Error> {
|
||||||
|
self.$fn_name()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -109,28 +117,6 @@ impl Shape {
|
|||||||
self.0.iter().product()
|
self.0.iter().product()
|
||||||
}
|
}
|
||||||
|
|
||||||
extract_dims!(r0, 0, |_: &Vec<usize>| (), ());
|
|
||||||
extract_dims!(r1, 1, |d: &[usize]| d[0], usize);
|
|
||||||
extract_dims!(r2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
|
|
||||||
extract_dims!(
|
|
||||||
r3,
|
|
||||||
3,
|
|
||||||
|d: &[usize]| (d[0], d[1], d[2]),
|
|
||||||
(usize, usize, usize)
|
|
||||||
);
|
|
||||||
extract_dims!(
|
|
||||||
r4,
|
|
||||||
4,
|
|
||||||
|d: &[usize]| (d[0], d[1], d[2], d[3]),
|
|
||||||
(usize, usize, usize, usize)
|
|
||||||
);
|
|
||||||
extract_dims!(
|
|
||||||
r5,
|
|
||||||
5,
|
|
||||||
|d: &[usize]| (d[0], d[1], d[2], d[3], d[4]),
|
|
||||||
(usize, usize, usize, usize, usize)
|
|
||||||
);
|
|
||||||
|
|
||||||
/// The strides given in number of elements for a contiguous n-dimensional
|
/// The strides given in number of elements for a contiguous n-dimensional
|
||||||
/// arrays using this shape.
|
/// arrays using this shape.
|
||||||
pub(crate) fn stride_contiguous(&self) -> Vec<usize> {
|
pub(crate) fn stride_contiguous(&self) -> Vec<usize> {
|
||||||
@ -342,6 +328,28 @@ impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
extract_dims!(r0, 0, |_: &Vec<usize>| (), ());
|
||||||
|
extract_dims!(r1, 1, |d: &[usize]| d[0], usize);
|
||||||
|
extract_dims!(r2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
|
||||||
|
extract_dims!(
|
||||||
|
r3,
|
||||||
|
3,
|
||||||
|
|d: &[usize]| (d[0], d[1], d[2]),
|
||||||
|
(usize, usize, usize)
|
||||||
|
);
|
||||||
|
extract_dims!(
|
||||||
|
r4,
|
||||||
|
4,
|
||||||
|
|d: &[usize]| (d[0], d[1], d[2], d[3]),
|
||||||
|
(usize, usize, usize, usize)
|
||||||
|
);
|
||||||
|
extract_dims!(
|
||||||
|
r5,
|
||||||
|
5,
|
||||||
|
|d: &[usize]| (d[0], d[1], d[2], d[3], d[4]),
|
||||||
|
(usize, usize, usize, usize, usize)
|
||||||
|
);
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
@ -315,7 +315,6 @@ impl BertSelfAttention {
|
|||||||
new_x_shape.pop();
|
new_x_shape.pop();
|
||||||
new_x_shape.push(self.num_attention_heads);
|
new_x_shape.push(self.num_attention_heads);
|
||||||
new_x_shape.push(self.attention_head_size);
|
new_x_shape.push(self.attention_head_size);
|
||||||
// Be cautious about the transposition if adding a batch dim!
|
|
||||||
let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
|
let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
|
||||||
xs.contiguous()
|
xs.contiguous()
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user