mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add binary and ternary custom ops. (#217)
This commit is contained in:
@ -1,5 +1,5 @@
|
||||
use crate::backend::BackendStorage;
|
||||
use crate::op::{self, CmpOp, CustomOp1, ReduceOp};
|
||||
use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp};
|
||||
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
|
||||
|
||||
// We do not want to implement Clone on Storage as cloning may fail because of
|
||||
@ -149,7 +149,7 @@ impl Storage {
|
||||
|
||||
pub(crate) fn custom_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> {
|
||||
match self {
|
||||
Storage::Cpu(storage) => {
|
||||
Self::Cpu(storage) => {
|
||||
let (storage, shape) = c.cpu_fwd(storage, l)?;
|
||||
Ok((Self::Cpu(storage), shape))
|
||||
}
|
||||
@ -160,6 +160,51 @@ impl Storage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn custom_op2(
|
||||
&self,
|
||||
l1: &Layout,
|
||||
t2: &Self,
|
||||
l2: &Layout,
|
||||
c: &dyn CustomOp2,
|
||||
) -> Result<(Self, Shape)> {
|
||||
self.same_device(t2, c.name())?;
|
||||
match (self, t2) {
|
||||
(Self::Cpu(s1), Self::Cpu(s2)) => {
|
||||
let (s, shape) = c.cpu_fwd(s1, l1, s2, l2)?;
|
||||
Ok((Self::Cpu(s), shape))
|
||||
}
|
||||
(Self::Cuda(s1), Self::Cuda(s2)) => {
|
||||
let (s, shape) = c.cuda_fwd(s1, l1, s2, l2)?;
|
||||
Ok((Self::Cuda(s), shape))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn custom_op3(
|
||||
&self,
|
||||
l1: &Layout,
|
||||
t2: &Self,
|
||||
l2: &Layout,
|
||||
t3: &Self,
|
||||
l3: &Layout,
|
||||
c: &dyn CustomOp3,
|
||||
) -> Result<(Self, Shape)> {
|
||||
self.same_device(t2, c.name())?;
|
||||
self.same_device(t3, c.name())?;
|
||||
match (self, t2, t3) {
|
||||
(Self::Cpu(s1), Self::Cpu(s2), Self::Cpu(s3)) => {
|
||||
let (s, shape) = c.cpu_fwd(s1, l1, s2, l2, s3, l3)?;
|
||||
Ok((Self::Cpu(s), shape))
|
||||
}
|
||||
(Self::Cuda(s1), Self::Cuda(s2), Self::Cuda(s3)) => {
|
||||
let (s, shape) = c.cuda_fwd(s1, l1, s2, l2, s3, l3)?;
|
||||
Ok((Self::Cuda(s), shape))
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn unary_impl<B: op::UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
|
||||
// TODO: Different code path for the contiguous case?
|
||||
match self {
|
||||
|
Reference in New Issue
Block a user