Propagate the changes on the cpu backend.

commit 54a6c40f27
parent 303b853098
Author: laurent
Date: 2023-06-28 14:00:49 +01:00

2 changed files with 89 additions and 81 deletions
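The two files below apply one pattern throughout the CPU backend: kernels that used to receive separate shape/stride arguments (and build a StridedIndex by hand) now receive a &Layout and dispatch on Layout::contiguous_offsets / Layout::strided_index. The following is a minimal, self-contained sketch of that calling convention for illustration only; this Layout is a simplified stand-in, not the actual candle type, and its strided_index is a plain re-implementation of row-major strided iteration.

// A minimal stand-in for the Layout used below, written for illustration: it
// carries dims, strides and a start offset, and exposes the two accessors the
// updated cpu backend relies on.
struct Layout {
    dims: Vec<usize>,
    stride: Vec<usize>,
    start_offset: usize,
}

impl Layout {
    fn elem_count(&self) -> usize {
        self.dims.iter().product()
    }

    // Row-major (C) contiguity: each stride must equal the product of the
    // dimensions to its right.
    fn is_contiguous(&self) -> bool {
        let mut expected = 1;
        for (&d, &s) in self.dims.iter().zip(self.stride.iter()).rev() {
            if s != expected {
                return false;
            }
            expected *= d;
        }
        true
    }

    // Start/stop offsets of the data when it is stored contiguously.
    fn contiguous_offsets(&self) -> Option<(usize, usize)> {
        if self.is_contiguous() {
            Some((self.start_offset, self.start_offset + self.elem_count()))
        } else {
            None
        }
    }

    // Storage index of every element, visited in row-major order.
    fn strided_index(&self) -> impl Iterator<Item = usize> + '_ {
        (0..self.elem_count()).map(move |mut i| {
            let mut offset = self.start_offset;
            for (&d, &s) in self.dims.iter().zip(self.stride.iter()).rev() {
                offset += (i % d) * s;
                i /= d;
            }
            offset
        })
    }
}

// The dispatch pattern this commit propagates through the cpu kernels:
// fast path over the contiguous slice, strided iteration otherwise.
fn unary_map<T: Copy, U, F: FnMut(T) -> U>(vs: &[T], layout: &Layout, mut f: F) -> Vec<U> {
    match layout.contiguous_offsets() {
        Some((o1, o2)) => vs[o1..o2].iter().map(|&v| f(v)).collect(),
        None => layout.strided_index().map(|i| f(vs[i])).collect(),
    }
}

fn main() {
    // A 3x2 transposed view of a 2x3 buffer: non-contiguous, so unary_map
    // falls back to strided_index.
    let layout = Layout { dims: vec![3, 2], stride: vec![1, 3], start_offset: 0 };
    let data = [1i32, 2, 3, 4, 5, 6];
    assert_eq!(unary_map(&data, &layout, |v| v * 2), vec![2, 8, 4, 10, 6, 12]);
}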


@@ -24,25 +24,25 @@ fn wcond<T: Copy>(
     f: &[T],
     layout_f: &Layout,
 ) -> Vec<T> {
-    if layout.is_contiguous() && layout_t.is_contiguous() && layout_f.is_contiguous() {
-        let elem_count = layout.shape().elem_count();
-        let offset = layout.start_offset();
-        let offset_t = layout_t.start_offset();
-        let offset_f = layout_f.start_offset();
-        let pred = &pred[offset..offset + elem_count];
-        let t = &t[offset_t..offset_t + elem_count];
-        let f = &f[offset_f..offset_f + elem_count];
-        pred.iter()
-            .zip(t.iter().zip(f.iter()))
-            .map(|(&p, (&t, &f))| if p > 0 { t } else { f })
-            .collect::<Vec<_>>()
-    } else {
-        let it_p = StridedIndex::new(layout);
-        let it_t = StridedIndex::new(layout_t);
-        let it_f = StridedIndex::new(layout_f);
-        it_p.zip(it_t.zip(it_f))
-            .map(|(i_p, (i_t, i_f))| if pred[i_p] > 0 { t[i_t] } else { f[i_f] })
-            .collect::<Vec<_>>()
-    }
+    match (
+        layout.contiguous_offsets(),
+        layout_t.contiguous_offsets(),
+        layout_f.contiguous_offsets(),
+    ) {
+        (Some((o1, o2)), Some((o_t1, o_t2)), Some((o_f1, o_f2))) => {
+            let pred = &pred[o1..o2];
+            let t = &t[o_t1..o_t2];
+            let f = &f[o_f1..o_f2];
+            pred.iter()
+                .zip(t.iter().zip(f.iter()))
+                .map(|(&p, (&t, &f))| if p > 0 { t } else { f })
+                .collect::<Vec<_>>()
+        }
+        _ => layout
+            .strided_index()
+            .zip(layout_t.strided_index().zip(layout_f.strided_index()))
+            .map(|(i_p, (i_t, i_f))| if pred[i_p] > 0 { t[i_t] } else { f[i_f] })
+            .collect::<Vec<_>>(),
+    }
 }
@@ -62,42 +62,38 @@ macro_rules! map1 {
 fn sum_impl1<T: Copy + num_traits::NumAssign>(
     src: &[T],
     dst_shape: &Shape,
-    src_dims: &[usize],
-    stride: &[usize],
+    src_layout: &Layout,
     to_dst_index: impl Fn(usize) -> usize,
 ) -> Result<Vec<T>> {
     let mut dst = vec![T::zero(); dst_shape.elem_count()];
-    for (unstr_index, src_index) in StridedIndex::new(src_dims, stride).enumerate() {
+    for (unstr_index, src_index) in src_layout.strided_index().enumerate() {
         dst[to_dst_index(unstr_index)] += src[src_index];
     }
     Ok(dst)
 }

 fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(vs: &[T], layout: &Layout, mut f: F) -> Vec<U> {
-    if shape.is_contiguous(stride) {
-        vs[..shape.elem_count()].iter().map(|&v| f(v)).collect()
-    } else {
-        StridedIndex::new(shape.dims(), stride)
-            .map(|i| f(vs[i]))
-            .collect()
+    match layout.contiguous_offsets() {
+        Some((o1, o2)) => vs[o1..o2].iter().map(|&v| f(v)).collect(),
+        None => layout.strided_index().map(|i| f(vs[i])).collect(),
     }
 }

 // This function maps over two strided index sequences.
 fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
     shape: &Shape,
-    lhs_stride: &[usize],
-    rhs_stride: &[usize],
+    lhs_layout: &Layout,
+    rhs_layout: &Layout,
     lhs: &[T],
     rhs: &[T],
     mut f: F,
 ) -> Vec<T> {
     let dims = shape.dims();
-    if shape.is_contiguous(lhs_stride) && shape.is_contiguous(rhs_stride) {
+    if lhs_layout.is_contiguous() && rhs_layout.is_contiguous() {
         (0..shape.elem_count()).map(|i| f(lhs[i], rhs[i])).collect()
     } else {
-        let lhs_index = StridedIndex::new(dims, lhs_stride);
-        let rhs_index = StridedIndex::new(dims, rhs_stride);
+        let lhs_index = lhs_layout.strided_index();
+        let rhs_index = rhs_layout.strided_index();
         lhs_index
             .zip(rhs_index)
             .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
@@ -114,7 +110,7 @@ fn take_impl1<T: Copy>(
 ) -> Result<Vec<T>> {
     // TODO: Optimize for the case where ids are contiguous.
     let mut values = Vec::with_capacity(layout.shape().elem_count() * hidden_size);
-    for index in StridedIndex::new(layout) {
+    for index in layout.strided_index() {
         let index = ids[index].try_into()?;
         if index >= vocab_size {
             return Err(Error::InvalidIndex {
@@ -135,18 +131,19 @@ fn copy_strided_src_<T: Copy + std::fmt::Display>(
     dst_offset: usize,
     src_l: &Layout,
 ) {
-    let src = &src[src_l.start_offset()..];
-    if src_l.is_contiguous() {
-        let elem_to_copy = (dst.len() - dst_offset).min(src.len());
-        dst[dst_offset..dst_offset + elem_to_copy].copy_from_slice(&src[..elem_to_copy])
-    } else {
-        let src_indexes = StridedIndex::new(src_l);
-        for (dst_index, src_index) in src_indexes.enumerate() {
-            let dst_index = dst_index + dst_offset;
-            if dst_index >= dst.len() {
-                break;
-            }
-            dst[dst_index] = src[src_index]
-        }
+    match src_l.contiguous_offsets() {
+        Some((o_dst1, o_dst2)) => {
+            let elem_to_copy = (dst.len() - dst_offset).min(o_dst2 - o_dst1);
+            dst[dst_offset..dst_offset + elem_to_copy].copy_from_slice(&src[o_dst1..o_dst2])
+        }
+        None => {
+            for (dst_index, src_index) in src_l.strided_index().enumerate() {
+                let dst_index = dst_index + dst_offset;
+                if dst_index >= dst.len() {
+                    break;
+                }
+                dst[dst_index] = src[src_index]
+            }
+        }
     }
 }
@@ -235,114 +232,114 @@ impl CpuStorage {
         D::cpu_storage_as_mut_slice(self)
     }

-    pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
+    pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
         // TODO: find a way around the quadratic number of cases below.
         match (self, dtype) {
             (Self::U32(storage), DType::BF16) => {
-                let data = unary_map(storage, shape, stride, |v| bf16::from_f32(v as f32));
+                let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
                 Ok(Self::BF16(data))
             }
             (Self::BF16(storage), DType::BF16) => {
-                let data = unary_map(storage, shape, stride, |v| v);
+                let data = unary_map(storage, layout, |v| v);
                 Ok(Self::BF16(data))
             }
             (Self::F16(storage), DType::BF16) => {
-                let data = unary_map(storage, shape, stride, |v| bf16::from_f32(v.to_f32()));
+                let data = unary_map(storage, layout, |v| bf16::from_f32(v.to_f32()));
                 Ok(Self::BF16(data))
             }
             (Self::F32(storage), DType::BF16) => {
-                let data = unary_map(storage, shape, stride, bf16::from_f32);
+                let data = unary_map(storage, layout, bf16::from_f32);
                 Ok(Self::BF16(data))
             }
             (Self::F64(storage), DType::BF16) => {
-                let data = unary_map(storage, shape, stride, bf16::from_f64);
+                let data = unary_map(storage, layout, bf16::from_f64);
                 Ok(Self::BF16(data))
             }
             (Self::U32(storage), DType::F16) => {
-                let data = unary_map(storage, shape, stride, |v| f16::from_f32(v as f32));
+                let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
                 Ok(Self::F16(data))
             }
             (Self::BF16(storage), DType::F16) => {
-                let data = unary_map(storage, shape, stride, |v| f16::from_f32(v.to_f32()));
+                let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32()));
                 Ok(Self::F16(data))
             }
             (Self::F16(storage), DType::F16) => {
-                let data = unary_map(storage, shape, stride, |v| v);
+                let data = unary_map(storage, layout, |v| v);
                 Ok(Self::F16(data))
             }
             (Self::F32(storage), DType::F16) => {
-                let data = unary_map(storage, shape, stride, f16::from_f32);
+                let data = unary_map(storage, layout, f16::from_f32);
                 Ok(Self::F16(data))
             }
             (Self::F64(storage), DType::F16) => {
-                let data = unary_map(storage, shape, stride, f16::from_f64);
+                let data = unary_map(storage, layout, f16::from_f64);
                 Ok(Self::F16(data))
             }
             (Self::U32(storage), DType::F32) => {
-                let data = unary_map(storage, shape, stride, |v| v as f32);
+                let data = unary_map(storage, layout, |v| v as f32);
                 Ok(Self::F32(data))
             }
             (Self::BF16(storage), DType::F32) => {
-                let data = unary_map(storage, shape, stride, |v| v.to_f32());
+                let data = unary_map(storage, layout, |v| v.to_f32());
                 Ok(Self::F32(data))
             }
             (Self::F16(storage), DType::F32) => {
-                let data = unary_map(storage, shape, stride, |v| v.to_f32());
+                let data = unary_map(storage, layout, |v| v.to_f32());
                 Ok(Self::F32(data))
             }
             (Self::F32(storage), DType::F32) => {
-                let data = unary_map(storage, shape, stride, |v| v);
+                let data = unary_map(storage, layout, |v| v);
                 Ok(Self::F32(data))
             }
             (Self::F64(storage), DType::F32) => {
-                let data = unary_map(storage, shape, stride, |v| v as f32);
+                let data = unary_map(storage, layout, |v| v as f32);
                 Ok(Self::F32(data))
             }
             (Self::U32(storage), DType::U32) => {
-                let data = unary_map(storage, shape, stride, |v| v);
+                let data = unary_map(storage, layout, |v| v);
                 Ok(Self::U32(data))
             }
             (Self::BF16(storage), DType::U32) => {
-                let data = unary_map(storage, shape, stride, |v| v.to_f32() as u32);
+                let data = unary_map(storage, layout, |v| v.to_f32() as u32);
                 Ok(Self::U32(data))
             }
             (Self::F16(storage), DType::U32) => {
-                let data = unary_map(storage, shape, stride, |v| v.to_f32() as u32);
+                let data = unary_map(storage, layout, |v| v.to_f32() as u32);
                 Ok(Self::U32(data))
             }
             (Self::F32(storage), DType::U32) => {
-                let data = unary_map(storage, shape, stride, |v| v as u32);
+                let data = unary_map(storage, layout, |v| v as u32);
                 Ok(Self::U32(data))
             }
             (Self::F64(storage), DType::U32) => {
-                let data = unary_map(storage, shape, stride, |v| v as u32);
+                let data = unary_map(storage, layout, |v| v as u32);
                 Ok(Self::U32(data))
             }
             (Self::U32(storage), DType::F64) => {
-                let data = unary_map(storage, shape, stride, |v| v as f64);
+                let data = unary_map(storage, layout, |v| v as f64);
                 Ok(Self::F64(data))
             }
             (Self::BF16(storage), DType::F64) => {
-                let data = unary_map(storage, shape, stride, |v| v.to_f64());
+                let data = unary_map(storage, layout, |v| v.to_f64());
                 Ok(Self::F64(data))
             }
             (Self::F16(storage), DType::F64) => {
-                let data = unary_map(storage, shape, stride, |v| v.to_f64());
+                let data = unary_map(storage, layout, |v| v.to_f64());
                 Ok(Self::F64(data))
             }
             (Self::F32(storage), DType::F64) => {
-                let data = unary_map(storage, shape, stride, |v| v as f64);
+                let data = unary_map(storage, layout, |v| v as f64);
                 Ok(Self::F64(data))
             }
             (Self::F64(storage), DType::F64) => {
-                let data = unary_map(storage, shape, stride, |v| v);
+                let data = unary_map(storage, layout, |v| v);
                 Ok(Self::F64(data))
             }
         }
     }

-    pub(crate) fn sum(&self, shape: &Shape, stride: &[usize], sum_dims: &[usize]) -> Result<Self> {
-        let src_dims = shape.dims();
+    pub(crate) fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
+        let src_dims = layout.dims();
         let mut dst_dims = src_dims.to_vec();
         for &sum_dim in sum_dims.iter() {
             dst_dims[sum_dim] = 1;
@@ -368,7 +365,7 @@ impl CpuStorage {
             dst_index
         };
         // TODO: Maybe provide an implementation with higher precision accumulators?
-        map1!(self, sum_impl1, &dst_shape, src_dims, stride, to_dst_index)
+        map1!(self, sum_impl1, &dst_shape, layout, to_dst_index)
     }

     pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> {
@@ -516,28 +513,28 @@ impl CpuStorage {
         &self,
         rhs: &Self,
         shape: &Shape,
-        lhs_stride: &[usize],
-        rhs_stride: &[usize],
+        lhs_layout: &Layout,
+        rhs_layout: &Layout,
     ) -> Result<Self> {
         match (self, rhs) {
             (Self::BF16(lhs), Self::BF16(rhs)) => {
-                let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::bf16);
+                let data = binary_map(shape, lhs_layout, rhs_layout, lhs, rhs, B::bf16);
                 Ok(Self::BF16(data))
             }
             (Self::F16(lhs), Self::F16(rhs)) => {
-                let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f16);
+                let data = binary_map(shape, lhs_layout, rhs_layout, lhs, rhs, B::f16);
                 Ok(Self::F16(data))
             }
             (Self::F32(lhs), Self::F32(rhs)) => {
-                let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f32);
+                let data = binary_map(shape, lhs_layout, rhs_layout, lhs, rhs, B::f32);
                 Ok(Self::F32(data))
             }
             (Self::F64(lhs), Self::F64(rhs)) => {
-                let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f64);
+                let data = binary_map(shape, lhs_layout, rhs_layout, lhs, rhs, B::f64);
                 Ok(Self::F64(data))
             }
             (Self::U32(lhs), Self::U32(rhs)) => {
-                let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::u32);
+                let data = binary_map(shape, lhs_layout, rhs_layout, lhs, rhs, B::u32);
                 Ok(Self::U32(data))
             }
             _ => {
@@ -555,7 +552,7 @@ impl CpuStorage {
         &self,
         dst: &mut Self,
         dst_offset: usize,
-        src_l: Layout,
+        src_l: &Layout,
     ) -> Result<()> {
         match (self, dst) {
             (Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),


@@ -39,6 +39,17 @@ impl Layout {
         self.start_offset
     }

+    /// Returns the appropriate start and stop offset if the data is stored in a C
+    /// contiguous (aka row major) way.
+    pub fn contiguous_offsets(&self) -> Option<(usize, usize)> {
+        if self.is_contiguous() {
+            let start_o = self.start_offset;
+            Some((start_o, start_o + self.shape.elem_count()))
+        } else {
+            None
+        }
+    }
+
     /// Returns true if the data is stored in a C contiguous (aka row major) way.
     pub fn is_contiguous(&self) -> bool {
         self.shape.is_contiguous(&self.stride)
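
For reference, a tiny worked example of what the new contiguous_offsets accessor returns, using a hypothetical stand-in type that keeps only what the method reads: the start offset, the element count, and the result of the contiguity check (which the real Layout derives from its shape and strides, as in the hunk above).

// Hypothetical stand-in, for illustration only.
struct ToyLayout {
    start_offset: usize,
    elem_count: usize,
    contiguous: bool,
}

impl ToyLayout {
    fn contiguous_offsets(&self) -> Option<(usize, usize)> {
        if self.contiguous {
            Some((self.start_offset, self.start_offset + self.elem_count))
        } else {
            None
        }
    }
}

fn main() {
    // A contiguous (2, 3) view whose data starts 4 elements into the buffer:
    // a kernel can slice storage[4..10] directly.
    let narrowed = ToyLayout { start_offset: 4, elem_count: 6, contiguous: true };
    assert_eq!(narrowed.contiguous_offsets(), Some((4, 10)));

    // A transposed view is not row-major contiguous, so callers fall back to
    // the strided_index path seen in the cpu backend hunks.
    let transposed = ToyLayout { start_offset: 0, elem_count: 6, contiguous: false };
    assert_eq!(transposed.contiguous_offsets(), None);
}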