Add cpu support for min and max. (#202)

* Add cpu support for min and max.

* Add min/max all.
This commit is contained in:
Laurent Mazare
2023-07-19 18:11:44 +02:00
committed by GitHub
parent e6584476c4
commit ad12e20f6b
3 changed files with 62 additions and 28 deletions

View File

@ -1,5 +1,5 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOp, UnaryOp};
use crate::op::{BinaryOp, ReduceOp, UnaryOp};
use crate::{DType, Error, Layout, Result, Shape, WithDType};
use half::{bf16, f16};
@ -97,17 +97,17 @@ struct Reduce<'a> {
dst_shape: &'a Shape,
reduce_dims: &'a [usize],
reduce_dims_and_stride: Vec<(usize, usize)>,
op: crate::op::ReduceOp,
op: ReduceOp,
}
impl<'a> Map1 for Reduce<'a> {
impl<'a> Reduce<'a> {
#[inline(always)]
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
match self.op {
crate::op::ReduceOp::Min | crate::op::ReduceOp::Max => todo!(),
crate::op::ReduceOp::Sum => (),
}
let mut dst = vec![T::zero(); self.dst_shape.elem_count()];
fn fold_impl<T, F>(&self, src: &[T], src_l: &Layout, start_elt: T, f: F) -> Result<Vec<T>>
where
T: Clone + Copy,
F: Fn(T, T) -> T,
{
let mut dst = vec![start_elt; self.dst_shape.elem_count()];
match src_l.contiguous_offsets() {
Some((o1, o2)) => {
let src = &src[o1..o2];
@ -129,7 +129,7 @@ impl<'a> Map1 for Reduce<'a> {
let mut src_i = 0;
for dst_v in dst.iter_mut() {
for &s in src[src_i..src_i + reduce_sz].iter() {
*dst_v += s
*dst_v = f(*dst_v, s)
}
src_i += reduce_sz
}
@ -143,7 +143,7 @@ impl<'a> Map1 for Reduce<'a> {
let (pre, post) = (dst_index / stride, dst_index % stride);
dst_index = (pre / dim) * stride + post;
}
dst[dst_index] += src;
dst[dst_index] = f(dst[dst_index], src);
}
}
None => {
@ -155,7 +155,7 @@ impl<'a> Map1 for Reduce<'a> {
let (pre, post) = (dst_index / stride, dst_index % stride);
dst_index = (pre / dim) * stride + post;
}
dst[dst_index] += src[src_index];
dst[dst_index] = f(dst[dst_index], src[src_index]);
}
}
}
@ -163,6 +163,31 @@ impl<'a> Map1 for Reduce<'a> {
}
}
impl<'a> Map1 for Reduce<'a> {
#[inline(always)]
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
match self.op {
ReduceOp::Min => {
let s = if src_l.shape().elem_count() != 0 {
src[src_l.start_offset()]
} else {
Err(Error::EmptyTensor { op: "min" }.bt())?
};
self.fold_impl(src, src_l, s, |x, y| if x < y { x } else { y })
}
ReduceOp::Max => {
let s = if src_l.shape().elem_count() != 0 {
src[src_l.start_offset()]
} else {
Err(Error::EmptyTensor { op: "max" }.bt())?
};
self.fold_impl(src, src_l, s, |x, y| if x > y { x } else { y })
}
ReduceOp::Sum => self.fold_impl(src, src_l, T::zero(), |x, y| x + y),
}
}
}
fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(vs: &[T], layout: &Layout, mut f: F) -> Vec<U> {
match layout.strided_blocks() {
crate::StridedBlocks::SingleBlock { start_offset, len } => vs
@ -1015,12 +1040,7 @@ impl BackendStorage for CpuStorage {
}
}
fn reduce_op(
&self,
op: crate::op::ReduceOp,
layout: &Layout,
reduce_dims: &[usize],
) -> Result<Self> {
fn reduce_op(&self, op: ReduceOp, layout: &Layout, reduce_dims: &[usize]) -> Result<Self> {
let src_dims = layout.dims();
let mut dst_dims = src_dims.to_vec();
for &dim in reduce_dims.iter() {