mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Refactor the reduce ops in order to introduce argmin/argmax. (#212)
* Refactor the reduce ops in order to introduce argmin/argmax. * Clippy fixes. * Use the newly introduced argmax. * Fix the strided case. * Handle the non-contiguous case.
This commit is contained in:
@ -304,6 +304,12 @@ impl Tensor {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&arg_grad)?
|
||||
}
|
||||
Op::Reduce(_, ReduceOp::ArgMin, _) => {
|
||||
Err(Error::BackwardNotSupported { op: "argmin" })?
|
||||
}
|
||||
Op::Reduce(_, ReduceOp::ArgMax, _) => {
|
||||
Err(Error::BackwardNotSupported { op: "argmax" })?
|
||||
}
|
||||
Op::Softmax(_arg, _) => Err(Error::BackwardNotSupported { op: "softmax" })?,
|
||||
Op::Reshape(arg) => {
|
||||
let arg_grad = grad.reshape(arg.dims())?;
|
||||
|
@ -33,6 +33,26 @@ trait Map1 {
|
||||
}
|
||||
}
|
||||
|
||||
trait Map1Any {
|
||||
fn f<T: WithDType, W: Fn(Vec<T>) -> CpuStorage>(
|
||||
&self,
|
||||
vs: &[T],
|
||||
layout: &Layout,
|
||||
wrap: W,
|
||||
) -> Result<CpuStorage>;
|
||||
|
||||
fn map(&self, vs: &CpuStorage, layout: &Layout) -> Result<CpuStorage> {
|
||||
match vs {
|
||||
CpuStorage::U8(vs) => Ok(self.f(vs, layout, CpuStorage::U8)?),
|
||||
CpuStorage::U32(vs) => Ok(self.f(vs, layout, CpuStorage::U32)?),
|
||||
CpuStorage::BF16(vs) => Ok(self.f(vs, layout, CpuStorage::BF16)?),
|
||||
CpuStorage::F16(vs) => Ok(self.f(vs, layout, CpuStorage::F16)?),
|
||||
CpuStorage::F32(vs) => Ok(self.f(vs, layout, CpuStorage::F32)?),
|
||||
CpuStorage::F64(vs) => Ok(self.f(vs, layout, CpuStorage::F64)?),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type C = CpuStorage;
|
||||
trait Map2 {
|
||||
const OP: &'static str;
|
||||
@ -144,11 +164,118 @@ impl<'a> Map2 for WCond<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
struct ReduceIndex {
|
||||
reduce_dim_index: usize,
|
||||
use_min: bool,
|
||||
return_index: bool,
|
||||
}
|
||||
|
||||
impl ReduceIndex {
|
||||
// The value gets replaced if f(s[current_acc], s[i]) returns true.
|
||||
#[inline(always)]
|
||||
fn fold_impl<T, U, F, G>(&self, src: &[T], src_l: &Layout, f: F, g: G) -> Result<Vec<U>>
|
||||
where
|
||||
T: Clone + Copy,
|
||||
U: Clone + Copy,
|
||||
F: Fn(T, T) -> bool,
|
||||
G: Fn(T, usize) -> U,
|
||||
{
|
||||
let reduce_dim_size = src_l.dims()[self.reduce_dim_index];
|
||||
let reduce_dim_stride = src_l.stride()[self.reduce_dim_index];
|
||||
let dst_len = src_l.shape().elem_count() / reduce_dim_size;
|
||||
let mut dst: Vec<U> = Vec::with_capacity(dst_len);
|
||||
let dst_to_set = dst.spare_capacity_mut();
|
||||
let dst_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(dst_to_set) };
|
||||
match src_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => {
|
||||
let src = &src[o1..o2];
|
||||
if reduce_dim_stride == 1 {
|
||||
for (start_src_i, dst_v) in dst_to_set.iter_mut().enumerate() {
|
||||
let start_src_i = start_src_i * reduce_dim_size;
|
||||
let src = &src[start_src_i..start_src_i + reduce_dim_size];
|
||||
let mut acc = 0;
|
||||
let mut val = src[0];
|
||||
for (src_i, &s) in src.iter().enumerate() {
|
||||
if f(val, s) {
|
||||
acc = src_i;
|
||||
val = s
|
||||
}
|
||||
}
|
||||
*dst_v = g(val, acc)
|
||||
}
|
||||
} else {
|
||||
for (start_src_i, dst_v) in dst_to_set.iter_mut().enumerate() {
|
||||
let (p, q) = (
|
||||
start_src_i / reduce_dim_stride,
|
||||
start_src_i % reduce_dim_stride,
|
||||
);
|
||||
// start_src_i = p * reduce_dim_stride + q
|
||||
let start_src_i = p * reduce_dim_stride * reduce_dim_size + q;
|
||||
let src = &src[start_src_i..];
|
||||
let mut acc = 0;
|
||||
let mut val = src[0];
|
||||
for src_i in 0..reduce_dim_size {
|
||||
let s = src[src_i * reduce_dim_stride];
|
||||
if f(val, s) {
|
||||
acc = src_i;
|
||||
val = s
|
||||
}
|
||||
}
|
||||
*dst_v = g(val, acc)
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let l = src_l.narrow(self.reduce_dim_index, 0, 1)?;
|
||||
for (unstr_index, src_index) in l.strided_index().enumerate() {
|
||||
let src = &src[src_index..];
|
||||
let mut acc = 0;
|
||||
let mut val = src[0];
|
||||
for src_i in 0..reduce_dim_size {
|
||||
let s = src[src_i * reduce_dim_stride];
|
||||
if f(val, s) {
|
||||
acc = src_i;
|
||||
val = s
|
||||
}
|
||||
}
|
||||
dst[unstr_index] = g(val, acc)
|
||||
}
|
||||
}
|
||||
}
|
||||
unsafe { dst.set_len(dst_len) };
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
|
||||
impl Map1Any for ReduceIndex {
|
||||
#[inline(always)]
|
||||
fn f<T: WithDType, W: Fn(Vec<T>) -> CpuStorage>(
|
||||
&self,
|
||||
src: &[T],
|
||||
src_l: &Layout,
|
||||
wrap: W,
|
||||
) -> Result<CpuStorage> {
|
||||
if src_l.shape().elem_count() == 0 {
|
||||
Err(Error::EmptyTensor { op: "reduce" }.bt())?
|
||||
}
|
||||
let dst = match (self.return_index, self.use_min) {
|
||||
(false, true) => wrap(self.fold_impl(src, src_l, |x, y| x > y, |v, _i| v)?),
|
||||
(false, false) => wrap(self.fold_impl(src, src_l, |x, y| x < y, |v, _i| v)?),
|
||||
(true, true) => {
|
||||
CpuStorage::U32(self.fold_impl(src, src_l, |x, y| x > y, |_v, i| i as u32)?)
|
||||
}
|
||||
(true, false) => {
|
||||
CpuStorage::U32(self.fold_impl(src, src_l, |x, y| x < y, |_v, i| i as u32)?)
|
||||
}
|
||||
};
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
|
||||
struct Reduce<'a> {
|
||||
dst_shape: &'a Shape,
|
||||
reduce_dims: &'a [usize],
|
||||
reduce_dims_and_stride: Vec<(usize, usize)>,
|
||||
op: ReduceOp,
|
||||
}
|
||||
|
||||
impl<'a> Reduce<'a> {
|
||||
@ -217,25 +344,7 @@ impl<'a> Reduce<'a> {
|
||||
impl<'a> Map1 for Reduce<'a> {
|
||||
#[inline(always)]
|
||||
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
|
||||
match self.op {
|
||||
ReduceOp::Min => {
|
||||
let s = if src_l.shape().elem_count() != 0 {
|
||||
src[src_l.start_offset()]
|
||||
} else {
|
||||
Err(Error::EmptyTensor { op: "min" }.bt())?
|
||||
};
|
||||
self.fold_impl(src, src_l, s, |x, y| if x < y { x } else { y })
|
||||
}
|
||||
ReduceOp::Max => {
|
||||
let s = if src_l.shape().elem_count() != 0 {
|
||||
src[src_l.start_offset()]
|
||||
} else {
|
||||
Err(Error::EmptyTensor { op: "max" }.bt())?
|
||||
};
|
||||
self.fold_impl(src, src_l, s, |x, y| if x > y { x } else { y })
|
||||
}
|
||||
ReduceOp::Sum => self.fold_impl(src, src_l, T::zero(), |x, y| x + y),
|
||||
}
|
||||
self.fold_impl(src, src_l, T::zero(), |x, y| x + y)
|
||||
}
|
||||
}
|
||||
|
||||
@ -1144,6 +1253,8 @@ impl BackendStorage for CpuStorage {
|
||||
}
|
||||
|
||||
fn reduce_op(&self, op: ReduceOp, layout: &Layout, reduce_dims: &[usize]) -> Result<Self> {
|
||||
match op {
|
||||
ReduceOp::Sum => {
|
||||
let src_dims = layout.dims();
|
||||
let mut dst_dims = src_dims.to_vec();
|
||||
for &dim in reduce_dims.iter() {
|
||||
@ -1162,10 +1273,40 @@ impl BackendStorage for CpuStorage {
|
||||
dst_shape: &dst_shape,
|
||||
reduce_dims: &reduce_dims,
|
||||
reduce_dims_and_stride,
|
||||
op,
|
||||
}
|
||||
.map(self, layout)
|
||||
}
|
||||
ReduceOp::Min | ReduceOp::ArgMin | ReduceOp::Max | ReduceOp::ArgMax => {
|
||||
let reduce_dim_index = match reduce_dims {
|
||||
[reduce_dim_index] => *reduce_dim_index,
|
||||
_ => {
|
||||
let op = match op {
|
||||
ReduceOp::Min => "min",
|
||||
ReduceOp::ArgMin => "argmin",
|
||||
ReduceOp::Max => "max",
|
||||
ReduceOp::ArgMax => "argmax",
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let dims = reduce_dims.to_vec();
|
||||
Err(Error::OnlySingleDimension { op, dims })?
|
||||
}
|
||||
};
|
||||
let (use_min, return_index) = match op {
|
||||
ReduceOp::Min => (true, false),
|
||||
ReduceOp::ArgMin => (true, true),
|
||||
ReduceOp::Max => (false, false),
|
||||
ReduceOp::ArgMax => (false, true),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
ReduceIndex {
|
||||
reduce_dim_index,
|
||||
use_min,
|
||||
return_index,
|
||||
}
|
||||
.map(self, layout)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
|
||||
Cmp(op).map(self, lhs_l, rhs, rhs_l)
|
||||
|
@ -562,6 +562,8 @@ impl<'a> Map1 for FastReduce<'a> {
|
||||
ReduceOp::Sum => "fast_sum",
|
||||
ReduceOp::Min => "fast_min",
|
||||
ReduceOp::Max => "fast_max",
|
||||
ReduceOp::ArgMin => "fast_argmin",
|
||||
ReduceOp::ArgMax => "fast_argmax",
|
||||
};
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::REDUCE)?;
|
||||
// SAFETY: filled in by the follow up kernel.
|
||||
|
@ -79,6 +79,9 @@ pub enum Error {
|
||||
nth_shape: Shape,
|
||||
},
|
||||
|
||||
#[error("{op} can only be performed on a single dimension")]
|
||||
OnlySingleDimension { op: &'static str, dims: Vec<usize> },
|
||||
|
||||
#[error("empty tensor for {op}")]
|
||||
EmptyTensor { op: &'static str },
|
||||
|
||||
|
@ -17,6 +17,20 @@ pub enum ReduceOp {
|
||||
Sum,
|
||||
Min,
|
||||
Max,
|
||||
ArgMin,
|
||||
ArgMax,
|
||||
}
|
||||
|
||||
impl ReduceOp {
|
||||
pub(crate) fn name(&self) -> &'static str {
|
||||
match self {
|
||||
Self::ArgMax => "argmax",
|
||||
Self::ArgMin => "argmin",
|
||||
Self::Min => "min",
|
||||
Self::Max => "max",
|
||||
Self::Sum => "sum",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// These ops return the same type as their input type.
|
||||
|
@ -628,47 +628,21 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
fn max_impl<D: Dims>(&self, max_dims: D, keepdim: bool) -> Result<Self> {
|
||||
let max_dims = max_dims.to_indexes(self.shape(), "max")?;
|
||||
let storage = self
|
||||
.storage()
|
||||
.reduce_op(ReduceOp::Max, self.layout(), &max_dims)?;
|
||||
fn reduce_impl<D: Dim>(&self, dim: D, keepdim: bool, op: ReduceOp) -> Result<Self> {
|
||||
let dim = dim.to_index(self.shape(), op.name())?;
|
||||
let storage = self.storage().reduce_op(op, self.layout(), &[dim])?;
|
||||
let mut dims = self.dims().to_vec();
|
||||
for &max_dim in max_dims.iter() {
|
||||
dims[max_dim] = 1
|
||||
}
|
||||
dims[dim] = 1;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Reduce(self.clone(), ReduceOp::Max, dims.to_vec()))
|
||||
Some(Op::Reduce(self.clone(), op, dims.to_vec()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let max = from_storage(storage, dims, op, false);
|
||||
let res = from_storage(storage, dims, op, false);
|
||||
if keepdim {
|
||||
Ok(max)
|
||||
Ok(res)
|
||||
} else {
|
||||
max.squeeze_dims(&max_dims)
|
||||
}
|
||||
}
|
||||
|
||||
fn min_impl<D: Dims>(&self, min_dims: D, keepdim: bool) -> Result<Self> {
|
||||
let min_dims = min_dims.to_indexes(self.shape(), "min")?;
|
||||
let storage = self
|
||||
.storage()
|
||||
.reduce_op(ReduceOp::Min, self.layout(), &min_dims)?;
|
||||
let mut dims = self.dims().to_vec();
|
||||
for &min_dim in min_dims.iter() {
|
||||
dims[min_dim] = 1
|
||||
}
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Reduce(self.clone(), ReduceOp::Min, dims.to_vec()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let min = from_storage(storage, dims, op, false);
|
||||
if keepdim {
|
||||
Ok(min)
|
||||
} else {
|
||||
min.squeeze_dims(&min_dims)
|
||||
res.squeeze_dims(&[dim])
|
||||
}
|
||||
}
|
||||
|
||||
@ -722,30 +696,36 @@ impl Tensor {
|
||||
self.sum_impl(sum_dims, false)
|
||||
}
|
||||
|
||||
pub fn max_keepdim<D: Dims>(&self, max_dims: D) -> Result<Self> {
|
||||
self.max_impl(max_dims, true)
|
||||
pub fn max_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
|
||||
self.reduce_impl(dim, true, ReduceOp::Max)
|
||||
}
|
||||
|
||||
pub fn max<D: Dims>(&self, max_dims: D) -> Result<Self> {
|
||||
self.max_impl(max_dims, false)
|
||||
pub fn max<D: Dim>(&self, dim: D) -> Result<Self> {
|
||||
self.reduce_impl(dim, false, ReduceOp::Max)
|
||||
}
|
||||
|
||||
pub fn max_all(&self) -> Result<Tensor> {
|
||||
let dims: Vec<_> = (0..self.rank()).collect();
|
||||
self.max(dims)
|
||||
pub fn min_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
|
||||
self.reduce_impl(dim, true, ReduceOp::Min)
|
||||
}
|
||||
|
||||
pub fn min_keepdim<D: Dims>(&self, min_dims: D) -> Result<Self> {
|
||||
self.min_impl(min_dims, true)
|
||||
pub fn min<D: Dim>(&self, dim: D) -> Result<Self> {
|
||||
self.reduce_impl(dim, false, ReduceOp::Min)
|
||||
}
|
||||
|
||||
pub fn min<D: Dims>(&self, min_dims: D) -> Result<Self> {
|
||||
self.min_impl(min_dims, false)
|
||||
pub fn argmax_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
|
||||
self.reduce_impl(dim, true, ReduceOp::ArgMax)
|
||||
}
|
||||
|
||||
pub fn min_all(&self) -> Result<Tensor> {
|
||||
let dims: Vec<_> = (0..self.rank()).collect();
|
||||
self.min(dims)
|
||||
pub fn argmax<D: Dim>(&self, dim: D) -> Result<Self> {
|
||||
self.reduce_impl(dim, false, ReduceOp::ArgMax)
|
||||
}
|
||||
|
||||
pub fn argmin_keepdim<D: Dim>(&self, dim: D) -> Result<Self> {
|
||||
self.reduce_impl(dim, true, ReduceOp::ArgMin)
|
||||
}
|
||||
|
||||
pub fn argmin<D: Dim>(&self, dim: D) -> Result<Self> {
|
||||
self.reduce_impl(dim, false, ReduceOp::ArgMin)
|
||||
}
|
||||
|
||||
pub fn cmp(&self, rhs: &Self, op: CmpOp) -> Result<Self> {
|
||||
|
@ -42,7 +42,7 @@ pub fn main() -> Result<()> {
|
||||
let bs = Var::zeros(LABELS, DType::F32, &dev)?;
|
||||
let sgd = candle_nn::SGD::new(&[&ws, &bs], 1.0);
|
||||
let test_images = m.test_images;
|
||||
let test_labels = m.test_labels.to_vec1::<u8>()?;
|
||||
let test_labels = m.test_labels.to_dtype(DType::U32)?;
|
||||
for epoch in 1..200 {
|
||||
let logits = train_images.matmul(&ws)?.broadcast_add(&bs)?;
|
||||
let log_sm = log_softmax(&logits, D::Minus1)?;
|
||||
@ -52,28 +52,13 @@ pub fn main() -> Result<()> {
|
||||
sgd.backward_step(&loss)?;
|
||||
|
||||
let test_logits = test_images.matmul(&ws)?.broadcast_add(&bs)?;
|
||||
/* TODO: Add argmax so that the following can be computed within candle.
|
||||
let test_accuracy = test_logits
|
||||
.argmax(Some(-1), false)
|
||||
.eq_tensor(&test_labels)
|
||||
.to_kind(Kind::Float)
|
||||
.mean(Kind::Float)
|
||||
.double_value(&[]);
|
||||
*/
|
||||
let test_logits = test_logits.to_vec2::<f32>()?;
|
||||
let sum_ok = test_logits
|
||||
.iter()
|
||||
.zip(test_labels.iter())
|
||||
.map(|(logits, label)| {
|
||||
let arg_max = logits
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, v1), (_, v2)| v1.total_cmp(v2))
|
||||
.map(|(idx, _)| idx);
|
||||
f64::from(arg_max == Some(*label as usize))
|
||||
})
|
||||
.sum::<f64>();
|
||||
let test_accuracy = sum_ok / test_labels.len() as f64;
|
||||
.argmax(D::Minus1)?
|
||||
.eq(&test_labels)?
|
||||
.to_dtype(DType::F32)?
|
||||
.sum_all()?
|
||||
.to_scalar::<f32>()?;
|
||||
let test_accuracy = sum_ok / test_labels.shape().r1()? as f32;
|
||||
println!(
|
||||
"{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
|
||||
loss.to_scalar::<f32>()?,
|
||||
|
Reference in New Issue
Block a user