mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Add cpu support for min and max. (#202)
* Add cpu support for min and max. * Add min/max all.
This commit is contained in:
@ -1,5 +1,5 @@
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::op::{BinaryOp, UnaryOp};
|
||||
use crate::op::{BinaryOp, ReduceOp, UnaryOp};
|
||||
use crate::{DType, Error, Layout, Result, Shape, WithDType};
|
||||
use half::{bf16, f16};
|
||||
|
||||
@ -97,17 +97,17 @@ struct Reduce<'a> {
|
||||
dst_shape: &'a Shape,
|
||||
reduce_dims: &'a [usize],
|
||||
reduce_dims_and_stride: Vec<(usize, usize)>,
|
||||
op: crate::op::ReduceOp,
|
||||
op: ReduceOp,
|
||||
}
|
||||
|
||||
impl<'a> Map1 for Reduce<'a> {
|
||||
impl<'a> Reduce<'a> {
|
||||
#[inline(always)]
|
||||
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
|
||||
match self.op {
|
||||
crate::op::ReduceOp::Min | crate::op::ReduceOp::Max => todo!(),
|
||||
crate::op::ReduceOp::Sum => (),
|
||||
}
|
||||
let mut dst = vec![T::zero(); self.dst_shape.elem_count()];
|
||||
fn fold_impl<T, F>(&self, src: &[T], src_l: &Layout, start_elt: T, f: F) -> Result<Vec<T>>
|
||||
where
|
||||
T: Clone + Copy,
|
||||
F: Fn(T, T) -> T,
|
||||
{
|
||||
let mut dst = vec![start_elt; self.dst_shape.elem_count()];
|
||||
match src_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => {
|
||||
let src = &src[o1..o2];
|
||||
@ -129,7 +129,7 @@ impl<'a> Map1 for Reduce<'a> {
|
||||
let mut src_i = 0;
|
||||
for dst_v in dst.iter_mut() {
|
||||
for &s in src[src_i..src_i + reduce_sz].iter() {
|
||||
*dst_v += s
|
||||
*dst_v = f(*dst_v, s)
|
||||
}
|
||||
src_i += reduce_sz
|
||||
}
|
||||
@ -143,7 +143,7 @@ impl<'a> Map1 for Reduce<'a> {
|
||||
let (pre, post) = (dst_index / stride, dst_index % stride);
|
||||
dst_index = (pre / dim) * stride + post;
|
||||
}
|
||||
dst[dst_index] += src;
|
||||
dst[dst_index] = f(dst[dst_index], src);
|
||||
}
|
||||
}
|
||||
None => {
|
||||
@ -155,7 +155,7 @@ impl<'a> Map1 for Reduce<'a> {
|
||||
let (pre, post) = (dst_index / stride, dst_index % stride);
|
||||
dst_index = (pre / dim) * stride + post;
|
||||
}
|
||||
dst[dst_index] += src[src_index];
|
||||
dst[dst_index] = f(dst[dst_index], src[src_index]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -163,6 +163,31 @@ impl<'a> Map1 for Reduce<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Map1 for Reduce<'a> {
|
||||
#[inline(always)]
|
||||
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
|
||||
match self.op {
|
||||
ReduceOp::Min => {
|
||||
let s = if src_l.shape().elem_count() != 0 {
|
||||
src[src_l.start_offset()]
|
||||
} else {
|
||||
Err(Error::EmptyTensor { op: "min" }.bt())?
|
||||
};
|
||||
self.fold_impl(src, src_l, s, |x, y| if x < y { x } else { y })
|
||||
}
|
||||
ReduceOp::Max => {
|
||||
let s = if src_l.shape().elem_count() != 0 {
|
||||
src[src_l.start_offset()]
|
||||
} else {
|
||||
Err(Error::EmptyTensor { op: "max" }.bt())?
|
||||
};
|
||||
self.fold_impl(src, src_l, s, |x, y| if x > y { x } else { y })
|
||||
}
|
||||
ReduceOp::Sum => self.fold_impl(src, src_l, T::zero(), |x, y| x + y),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(vs: &[T], layout: &Layout, mut f: F) -> Vec<U> {
|
||||
match layout.strided_blocks() {
|
||||
crate::StridedBlocks::SingleBlock { start_offset, len } => vs
|
||||
@ -1015,12 +1040,7 @@ impl BackendStorage for CpuStorage {
|
||||
}
|
||||
}
|
||||
|
||||
fn reduce_op(
|
||||
&self,
|
||||
op: crate::op::ReduceOp,
|
||||
layout: &Layout,
|
||||
reduce_dims: &[usize],
|
||||
) -> Result<Self> {
|
||||
fn reduce_op(&self, op: ReduceOp, layout: &Layout, reduce_dims: &[usize]) -> Result<Self> {
|
||||
let src_dims = layout.dims();
|
||||
let mut dst_dims = src_dims.to_vec();
|
||||
for &dim in reduce_dims.iter() {
|
||||
|
@ -79,6 +79,9 @@ pub enum Error {
|
||||
nth_shape: Shape,
|
||||
},
|
||||
|
||||
#[error("empty tensor for {op}")]
|
||||
EmptyTensor { op: &'static str },
|
||||
|
||||
// === Device Errors ===
|
||||
#[error("device mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
|
||||
DeviceMismatchBinaryOp {
|
||||
|
@ -1,6 +1,7 @@
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::op::{Op, ReduceOp};
|
||||
use crate::shape::{Dim, Dims};
|
||||
use crate::{op::Op, storage::Storage, DType, Device, Error, Layout, Result, Shape};
|
||||
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
/// Unique identifier for tensors.
|
||||
@ -629,9 +630,9 @@ impl Tensor {
|
||||
|
||||
fn max_impl<D: Dims>(&self, max_dims: D, keepdim: bool) -> Result<Self> {
|
||||
let max_dims = max_dims.to_indexes(self.shape(), "max")?;
|
||||
let storage =
|
||||
self.storage()
|
||||
.reduce_op(crate::op::ReduceOp::Max, self.layout(), &max_dims)?;
|
||||
let storage = self
|
||||
.storage()
|
||||
.reduce_op(ReduceOp::Max, self.layout(), &max_dims)?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Max(self.clone(), max_dims.to_vec()))
|
||||
} else {
|
||||
@ -651,9 +652,9 @@ impl Tensor {
|
||||
|
||||
fn min_impl<D: Dims>(&self, min_dims: D, keepdim: bool) -> Result<Self> {
|
||||
let min_dims = min_dims.to_indexes(self.shape(), "min")?;
|
||||
let storage =
|
||||
self.storage()
|
||||
.reduce_op(crate::op::ReduceOp::Min, self.layout(), &min_dims)?;
|
||||
let storage = self
|
||||
.storage()
|
||||
.reduce_op(ReduceOp::Min, self.layout(), &min_dims)?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Min(self.clone(), min_dims.to_vec()))
|
||||
} else {
|
||||
@ -673,9 +674,9 @@ impl Tensor {
|
||||
|
||||
fn sum_impl<D: Dims>(&self, sum_dims: D, keepdim: bool) -> Result<Self> {
|
||||
let sum_dims = sum_dims.to_indexes(self.shape(), "sum")?;
|
||||
let storage =
|
||||
self.storage()
|
||||
.reduce_op(crate::op::ReduceOp::Sum, self.layout(), &sum_dims)?;
|
||||
let storage = self
|
||||
.storage()
|
||||
.reduce_op(ReduceOp::Sum, self.layout(), &sum_dims)?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Sum(self.clone(), sum_dims.to_vec()))
|
||||
} else {
|
||||
@ -729,6 +730,11 @@ impl Tensor {
|
||||
self.max_impl(max_dims, false)
|
||||
}
|
||||
|
||||
pub fn max_all(&self) -> Result<Tensor> {
|
||||
let dims: Vec<_> = (0..self.rank()).collect();
|
||||
self.max(dims)
|
||||
}
|
||||
|
||||
pub fn min_keepdim<D: Dims>(&self, min_dims: D) -> Result<Self> {
|
||||
self.min_impl(min_dims, true)
|
||||
}
|
||||
@ -737,6 +743,11 @@ impl Tensor {
|
||||
self.min_impl(min_dims, false)
|
||||
}
|
||||
|
||||
pub fn min_all(&self) -> Result<Tensor> {
|
||||
let dims: Vec<_> = (0..self.rank()).collect();
|
||||
self.min(dims)
|
||||
}
|
||||
|
||||
/// Applies a 1D convolution over the input tensor.
|
||||
pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
|
||||
let (c_out, c_in_k, k_size) = kernel.shape().r3()?;
|
||||
|
Reference in New Issue
Block a user