Mirror of https://github.com/huggingface/candle.git, synced 2025-06-18 19:47:12 +00:00.
Add cpu support for min and max. (#202)
* Add cpu support for min and max.
* Add min/max all.
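The hunks below are the `Tensor` API side of the change. For orientation, here is a minimal usage sketch of the new surface; the `candle_core` import path, the sample values, and the printing call are illustrative assumptions, and `vec![1]` is used for the dims argument because the new `max_all` itself passes a `Vec` to `max`:

    use candle_core::{Device, Tensor};

    fn demo() -> candle_core::Result<()> {
        let t = Tensor::new(&[[1f32, 5., 3.], [4., 2., 6.]], &Device::Cpu)?;
        let row_max = t.max(vec![1])?; // max along dim 1: one value per row
        let global_max = t.max_all()?; // maximum over the whole tensor
        let global_min = t.min_all()?; // minimum over the whole tensor
        println!("{row_max:?} {global_max:?} {global_min:?}");
        Ok(())
    }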
@@ -1,6 +1,7 @@
 use crate::backend::{BackendDevice, BackendStorage};
+use crate::op::{Op, ReduceOp};
 use crate::shape::{Dim, Dims};
-use crate::{op::Op, storage::Storage, DType, Device, Error, Layout, Result, Shape};
+use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
 use std::sync::{Arc, RwLock};
 
 /// Unique identifier for tensors.
@@ -629,9 +630,9 @@ impl Tensor {
 
     fn max_impl<D: Dims>(&self, max_dims: D, keepdim: bool) -> Result<Self> {
         let max_dims = max_dims.to_indexes(self.shape(), "max")?;
-        let storage =
-            self.storage()
-                .reduce_op(crate::op::ReduceOp::Max, self.layout(), &max_dims)?;
+        let storage = self
+            .storage()
+            .reduce_op(ReduceOp::Max, self.layout(), &max_dims)?;
         let op = if self.track_op() {
             Some(Op::Max(self.clone(), max_dims.to_vec()))
         } else {
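The CPU kernels that execute `ReduceOp::Max` and `ReduceOp::Min` sit behind the `reduce_op` call on the storage and are not shown in this diff. Purely as an illustration of the operation being dispatched, and not Candle's actual kernel (which also has to handle strided layouts and every supported dtype), a max reduction over the last dimension of a contiguous buffer can be sketched as:

    // Illustration only: max over the last dim of a contiguous 2D buffer.
    fn max_last_dim(data: &[f32], rows: usize, cols: usize) -> Vec<f32> {
        (0..rows)
            .map(|r| {
                data[r * cols..(r + 1) * cols]
                    .iter()
                    .copied()
                    .fold(f32::NEG_INFINITY, f32::max)
            })
            .collect()
    }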
@@ -651,9 +652,9 @@ impl Tensor {
 
     fn min_impl<D: Dims>(&self, min_dims: D, keepdim: bool) -> Result<Self> {
         let min_dims = min_dims.to_indexes(self.shape(), "min")?;
-        let storage =
-            self.storage()
-                .reduce_op(crate::op::ReduceOp::Min, self.layout(), &min_dims)?;
+        let storage = self
+            .storage()
+            .reduce_op(ReduceOp::Min, self.layout(), &min_dims)?;
         let op = if self.track_op() {
             Some(Op::Min(self.clone(), min_dims.to_vec()))
         } else {
@@ -673,9 +674,9 @@ impl Tensor {
 
     fn sum_impl<D: Dims>(&self, sum_dims: D, keepdim: bool) -> Result<Self> {
        let sum_dims = sum_dims.to_indexes(self.shape(), "sum")?;
-        let storage =
-            self.storage()
-                .reduce_op(crate::op::ReduceOp::Sum, self.layout(), &sum_dims)?;
+        let storage = self
+            .storage()
+            .reduce_op(ReduceOp::Sum, self.layout(), &sum_dims)?;
         let op = if self.track_op() {
             Some(Op::Sum(self.clone(), sum_dims.to_vec()))
         } else {
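`sum_impl` is reshaped to match the two new reductions; all three share the same structure, with the `keepdim` flag deciding whether the reduced dimensions are dropped or kept with size 1. A small sketch of the expected difference, assuming `max` and `max_keepdim` mirror the existing `sum`/`sum_keepdim` semantics (the function name is hypothetical):

    fn keepdim_demo(t: &candle_core::Tensor) -> candle_core::Result<()> {
        // For a (2, 3) input:
        let dropped = t.max(vec![1])?;      // shape (2,): reduced dim removed
        let kept = t.max_keepdim(vec![1])?; // shape (2, 1): reduced dim kept at size 1
        println!("{:?} {:?}", dropped.shape(), kept.shape());
        Ok(())
    }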
@@ -729,6 +730,11 @@ impl Tensor {
         self.max_impl(max_dims, false)
     }
+
+    pub fn max_all(&self) -> Result<Tensor> {
+        let dims: Vec<_> = (0..self.rank()).collect();
+        self.max(dims)
+    }
 
     pub fn min_keepdim<D: Dims>(&self, min_dims: D) -> Result<Self> {
         self.min_impl(min_dims, true)
     }
@@ -737,6 +743,11 @@ impl Tensor {
         self.min_impl(min_dims, false)
     }
+
+    pub fn min_all(&self) -> Result<Tensor> {
+        let dims: Vec<_> = (0..self.rank()).collect();
+        self.min(dims)
+    }
 
     /// Applies a 1D convolution over the input tensor.
     pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
         let (c_out, c_in_k, k_size) = kernel.shape().r3()?;
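As the two hunks above show, `min_all` and `max_all` are thin wrappers that reduce over every dimension by enumerating `0..rank`. Restated from the caller's side (a sketch with a hypothetical name, not an additional API), the following should be equivalent to calling `max_all` directly:

    fn max_all_by_hand(t: &candle_core::Tensor) -> candle_core::Result<candle_core::Tensor> {
        // The same dim enumeration max_all performs internally.
        let dims: Vec<usize> = (0..t.rank()).collect();
        t.max(dims)
    }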