diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index f1118ee7..925ca112 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -1,5 +1,5 @@
 use crate::backend::{BackendDevice, BackendStorage};
-use crate::op::{BinaryOp, UnaryOp};
+use crate::op::{BinaryOp, ReduceOp, UnaryOp};
 use crate::{DType, Error, Layout, Result, Shape, WithDType};
 use half::{bf16, f16};
 
@@ -97,17 +97,17 @@ struct Reduce<'a> {
     dst_shape: &'a Shape,
     reduce_dims: &'a [usize],
     reduce_dims_and_stride: Vec<(usize, usize)>,
-    op: crate::op::ReduceOp,
+    op: ReduceOp,
 }
 
-impl<'a> Map1 for Reduce<'a> {
+impl<'a> Reduce<'a> {
     #[inline(always)]
-    fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
-        match self.op {
-            crate::op::ReduceOp::Min | crate::op::ReduceOp::Max => todo!(),
-            crate::op::ReduceOp::Sum => (),
-        }
-        let mut dst = vec![T::zero(); self.dst_shape.elem_count()];
+    fn fold_impl<T, F>(&self, src: &[T], src_l: &Layout, start_elt: T, f: F) -> Result<Vec<T>>
+    where
+        T: Clone + Copy,
+        F: Fn(T, T) -> T,
+    {
+        let mut dst = vec![start_elt; self.dst_shape.elem_count()];
         match src_l.contiguous_offsets() {
             Some((o1, o2)) => {
                 let src = &src[o1..o2];
@@ -129,7 +129,7 @@ impl<'a> Map1 for Reduce<'a> {
                    let mut src_i = 0;
                    for dst_v in dst.iter_mut() {
                        for &s in src[src_i..src_i + reduce_sz].iter() {
-                            *dst_v += s
+                            *dst_v = f(*dst_v, s)
                        }
                        src_i += reduce_sz
                    }
@@ -143,7 +143,7 @@ impl<'a> Map1 for Reduce<'a> {
                        let (pre, post) = (dst_index / stride, dst_index % stride);
                        dst_index = (pre / dim) * stride + post;
                    }
-                    dst[dst_index] += src;
+                    dst[dst_index] = f(dst[dst_index], src);
                }
            }
            None => {
@@ -155,7 +155,7 @@ impl<'a> Map1 for Reduce<'a> {
                        let (pre, post) = (dst_index / stride, dst_index % stride);
                        dst_index = (pre / dim) * stride + post;
                    }
-                    dst[dst_index] += src[src_index];
+                    dst[dst_index] = f(dst[dst_index], src[src_index]);
                }
            }
        }
@@ -163,6 +163,31 @@ impl<'a> Map1 for Reduce<'a> {
     }
 }
 
+impl<'a> Map1 for Reduce<'a> {
+    #[inline(always)]
+    fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
+        match self.op {
+            ReduceOp::Min => {
+                let s = if src_l.shape().elem_count() != 0 {
+                    src[src_l.start_offset()]
+                } else {
+                    Err(Error::EmptyTensor { op: "min" }.bt())?
+                };
+                self.fold_impl(src, src_l, s, |x, y| if x < y { x } else { y })
+            }
+            ReduceOp::Max => {
+                let s = if src_l.shape().elem_count() != 0 {
+                    src[src_l.start_offset()]
+                } else {
+                    Err(Error::EmptyTensor { op: "max" }.bt())?
+                };
+                self.fold_impl(src, src_l, s, |x, y| if x > y { x } else { y })
+            }
+            ReduceOp::Sum => self.fold_impl(src, src_l, T::zero(), |x, y| x + y),
+        }
+    }
+}
+
 fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(vs: &[T], layout: &Layout, mut f: F) -> Vec<U> {
     match layout.strided_blocks() {
         crate::StridedBlocks::SingleBlock { start_offset, len } => vs
@@ -1015,12 +1040,7 @@ impl BackendStorage for CpuStorage {
         }
     }
 
-    fn reduce_op(
-        &self,
-        op: crate::op::ReduceOp,
-        layout: &Layout,
-        reduce_dims: &[usize],
-    ) -> Result<Self> {
+    fn reduce_op(&self, op: ReduceOp, layout: &Layout, reduce_dims: &[usize]) -> Result<Self> {
         let src_dims = layout.dims();
         let mut dst_dims = src_dims.to_vec();
         for &dim in reduce_dims.iter() {
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index e354b239..4ec639db 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -79,6 +79,9 @@ pub enum Error {
         nth_shape: Shape,
     },
 
+    #[error("empty tensor for {op}")]
+    EmptyTensor { op: &'static str },
+
     // === Device Errors ===
     #[error("device mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
     DeviceMismatchBinaryOp {
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 32c8acd6..276a522e 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1,6 +1,7 @@
 use crate::backend::{BackendDevice, BackendStorage};
+use crate::op::{Op, ReduceOp};
 use crate::shape::{Dim, Dims};
-use crate::{op::Op, storage::Storage, DType, Device, Error, Layout, Result, Shape};
+use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
 use std::sync::{Arc, RwLock};
 
 /// Unique identifier for tensors.
@@ -629,9 +630,9 @@ impl Tensor {
 
     fn max_impl<D: Dims>(&self, max_dims: D, keepdim: bool) -> Result<Self> {
         let max_dims = max_dims.to_indexes(self.shape(), "max")?;
-        let storage =
-            self.storage()
-                .reduce_op(crate::op::ReduceOp::Max, self.layout(), &max_dims)?;
+        let storage = self
+            .storage()
+            .reduce_op(ReduceOp::Max, self.layout(), &max_dims)?;
         let op = if self.track_op() {
             Some(Op::Max(self.clone(), max_dims.to_vec()))
         } else {
@@ -651,9 +652,9 @@ impl Tensor {
 
     fn min_impl<D: Dims>(&self, min_dims: D, keepdim: bool) -> Result<Self> {
         let min_dims = min_dims.to_indexes(self.shape(), "min")?;
-        let storage =
-            self.storage()
-                .reduce_op(crate::op::ReduceOp::Min, self.layout(), &min_dims)?;
+        let storage = self
+            .storage()
+            .reduce_op(ReduceOp::Min, self.layout(), &min_dims)?;
         let op = if self.track_op() {
             Some(Op::Min(self.clone(), min_dims.to_vec()))
         } else {
@@ -673,9 +674,9 @@ impl Tensor {
 
     fn sum_impl<D: Dims>(&self, sum_dims: D, keepdim: bool) -> Result<Self> {
         let sum_dims = sum_dims.to_indexes(self.shape(), "sum")?;
-        let storage =
-            self.storage()
-                .reduce_op(crate::op::ReduceOp::Sum, self.layout(), &sum_dims)?;
+        let storage = self
+            .storage()
+            .reduce_op(ReduceOp::Sum, self.layout(), &sum_dims)?;
         let op = if self.track_op() {
             Some(Op::Sum(self.clone(), sum_dims.to_vec()))
         } else {
@@ -729,6 +730,11 @@ impl Tensor {
         self.max_impl(max_dims, false)
     }
 
+    pub fn max_all(&self) -> Result<Self> {
+        let dims: Vec<_> = (0..self.rank()).collect();
+        self.max(dims)
+    }
+
     pub fn min_keepdim<D: Dims>(&self, min_dims: D) -> Result<Self> {
         self.min_impl(min_dims, true)
     }
@@ -737,6 +743,11 @@ impl Tensor {
         self.min_impl(min_dims, false)
     }
 
+    pub fn min_all(&self) -> Result<Self> {
+        let dims: Vec<_> = (0..self.rank()).collect();
+        self.min(dims)
+    }
+
     /// Applies a 1D convolution over the input tensor.
     pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
         let (c_out, c_in_k, k_size) = kernel.shape().r3()?;
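
Note on the refactor: Sum, Min and Max now share a single traversal. fold_impl threads a seed value and a combine function over the reduced elements, so each ReduceOp arm only supplies its seed (T::zero() for Sum, the first source element for Min/Max) and its combiner. A standalone sketch of the idea follows; fold_reduce is a hypothetical name and not part of the candle API:

// Standalone sketch of the shared-fold pattern behind fold_impl.
// fold_reduce is a hypothetical name, not part of the candle API.
fn fold_reduce<T: Copy>(src: &[T], seed: T, f: impl Fn(T, T) -> T) -> T {
    src.iter().fold(seed, |acc, &x| f(acc, x))
}

fn main() {
    let v = [3i32, 1, 4, 1, 5];
    // Each reduce op differs only in its seed and combine function,
    // mirroring the three ReduceOp arms in the patch.
    let sum = fold_reduce(&v, 0, |x, y| x + y);
    let min = fold_reduce(&v, v[0], |x, y| if x < y { x } else { y });
    let max = fold_reduce(&v, v[0], |x, y| if x > y { x } else { y });
    assert_eq!((sum, min, max), (14, 1, 5));
}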
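On the user-facing side, min_all and max_all reduce over every dimension, and reducing an empty tensor with Min/Max now returns Error::EmptyTensor instead of hitting the old todo!() panic. A quick usage sketch, assuming the crate is imported as candle_core (adjust the crate name and constructors to match the workspace):

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::new(&[[1f32, 5., 3.], [4., 2., 6.]], &Device::Cpu)?;

    // Reduce over an explicit set of dimensions, as before.
    let per_row_max = t.max(vec![1])?; // shape [2], values [5., 6.]

    // New convenience wrappers: reduce over every dimension at once.
    let max = t.max_all()?; // scalar-shaped tensor holding 6.
    let min = t.min_all()?; // scalar-shaped tensor holding 1.
    println!("{per_row_max:?} {max:?} {min:?}");

    // Min/max of an empty tensor now surfaces Error::EmptyTensor.
    let empty = Tensor::zeros(0, DType::F32, &Device::Cpu)?;
    assert!(empty.min_all().is_err());
    Ok(())
}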