mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Add the comparison operations. (#207)
* Add the comparison operations. * Add the helper functions on the tensor side. * More cmp operations. * Cpu implementation for the comparison operations.
This commit is contained in:
@ -1,4 +1,5 @@
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::op::{CmpOp, ReduceOp};
|
||||
use crate::{CpuStorage, DType, Layout, Result, Shape, WithDType};
|
||||
use candle_kernels as kernels;
|
||||
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
|
||||
@ -515,7 +516,7 @@ impl<'a> Map1 for Sum<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
struct FastReduce<'a>(&'a [usize], crate::op::ReduceOp);
|
||||
/// Reduction descriptor: field 0 is a borrowed slice of `usize` values (callers
/// in this file pass `sum_dims`), field 1 is the `ReduceOp` to apply.
struct FastReduce<'a>(&'a [usize], ReduceOp);
|
||||
impl<'a> Map1 for FastReduce<'a> {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
@ -558,9 +559,9 @@ impl<'a> Map1 for FastReduce<'a> {
|
||||
.w()?;
|
||||
let src = &src.slice(layout.start_offset()..);
|
||||
let name = match self.1 {
|
||||
crate::op::ReduceOp::Sum => "fast_sum",
|
||||
crate::op::ReduceOp::Min => "fast_min",
|
||||
crate::op::ReduceOp::Max => "fast_max",
|
||||
ReduceOp::Sum => "fast_sum",
|
||||
ReduceOp::Min => "fast_min",
|
||||
ReduceOp::Max => "fast_max",
|
||||
};
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::REDUCE)?;
|
||||
// SAFETY: filled in by the follow up kernel.
|
||||
@ -961,17 +962,16 @@ impl BackendStorage for CudaStorage {
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
fn reduce_op(
|
||||
&self,
|
||||
op: crate::op::ReduceOp,
|
||||
layout: &Layout,
|
||||
sum_dims: &[usize],
|
||||
) -> Result<Self> {
|
||||
/// Applies the reduction `op` to this storage and returns the result as a new
/// storage value.
///
/// The work is delegated to `FastReduce(sum_dims, op)`, which is mapped over
/// `self.slice` using `layout`. The returned storage carries a clone of this
/// storage's device handle.
///
/// # Errors
/// Propagates any error returned by the `FastReduce` mapping.
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
|
||||
// Clone the device handle so it can be moved into the returned storage.
let device = self.device().clone();
|
||||
let slice = FastReduce(sum_dims, op).map(&self.slice, &device, layout)?;
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
/// Comparison operation — not implemented here yet.
///
/// Every argument is ignored.
///
/// # Errors
/// Always returns `CudaError::InternalError("TODO: implement cmp")`.
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
|
||||
Err(CudaError::InternalError("TODO: implement cmp").into())
|
||||
}
|
||||
|
||||
/// Placeholder for `divide_by_sum_over_dim` — not implemented here yet.
///
/// Both arguments are ignored and `self` is left untouched.
///
/// # Errors
/// Always returns
/// `CudaError::InternalError("TODO: implement divide_by_sum_over_dim")`.
fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
|
||||
Err(CudaError::InternalError("TODO: implement divide_by_sum_over_dim").into())
|
||||
}
|
||||
|
Reference in New Issue
Block a user