mirror of https://github.com/huggingface/candle.git

Add binary and ternary custom ops. (#217)
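This change extends user-defined ops from unary (`CustomOp1`) to binary (`CustomOp2`) and ternary (`CustomOp3`) variants: the new ops get their own `Op` variants, are walked and accumulated during backprop, are dispatched per device in `Storage`, and are exposed through `Tensor::custom_op2`/`custom_op3` (plus `_arc` forms). As part of this, every custom `bwd` now returns `Option<Tensor>` gradients so an op can skip arguments that do not need one.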
@ -38,7 +38,7 @@ impl Tensor {
            nodes
        } else if let Some(op) = node.op() {
            match op {
                Op::WhereCond(t1, t2, t3) => {
                Op::CustomOp3(t1, t2, t3, _) | Op::WhereCond(t1, t2, t3) => {
                    let (tg, nodes) = walk(t1, nodes, already_seen);
                    track_grad |= tg;
                    let (tg, nodes) = walk(t2, nodes, already_seen);
@ -52,6 +52,7 @@ impl Tensor {
                    kernel: rhs,
                    ..
                }
                | Op::CustomOp2(lhs, rhs, _)
                | Op::Binary(lhs, rhs, _)
                | Op::IndexSelect(lhs, rhs, _)
                | Op::Embedding(lhs, rhs)
@ -321,10 +322,38 @@ impl Tensor {
                Op::Unary(_, UnaryOp::Relu) => Err(Error::BackwardNotSupported { op: "relu" })?,
                Op::Elu(..) => Err(Error::BackwardNotSupported { op: "elu" })?,
                Op::CustomOp1(arg, c) => {
                    let arg_grad = c.bwd(arg, node, &grad)?;
                    if let Some(arg_grad) = c.bwd(arg, node, &grad)? {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                }
                Op::CustomOp2(arg1, arg2, c) => {
                    let (arg_grad1, arg_grad2) = c.bwd(arg1, arg2, node, &grad)?;
                    if let Some(arg_grad1) = arg_grad1 {
                        let sum_grad = grads.or_insert(arg1)?;
                        *sum_grad = sum_grad.add(&arg_grad1)?
                    }
                    if let Some(arg_grad2) = arg_grad2 {
                        let sum_grad = grads.or_insert(arg2)?;
                        *sum_grad = sum_grad.add(&arg_grad2)?
                    }
                }
                Op::CustomOp3(arg1, arg2, arg3, c) => {
                    let (arg_grad1, arg_grad2, arg_grad3) =
                        c.bwd(arg1, arg2, arg3, node, &grad)?;
                    if let Some(arg_grad1) = arg_grad1 {
                        let sum_grad = grads.or_insert(arg1)?;
                        *sum_grad = sum_grad.add(&arg_grad1)?
                    }
                    if let Some(arg_grad2) = arg_grad2 {
                        let sum_grad = grads.or_insert(arg2)?;
                        *sum_grad = sum_grad.add(&arg_grad2)?
                    }
                    if let Some(arg_grad3) = arg_grad3 {
                        let sum_grad = grads.or_insert(arg3)?;
                        *sum_grad = sum_grad.add(&arg_grad3)?
                    }
                }
                Op::Unary(arg, UnaryOp::Sqr) => {
                    let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
                    let sum_grad = grads.or_insert(arg)?;
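The pattern in the hunks above: `walk` registers the inputs of `CustomOp2`/`CustomOp3` nodes so gradient tracking propagates through them, and the new match arms then accumulate whichever per-argument gradients the op's `bwd` chooses to return; a `None` entry simply skips that input. A concrete sketch of such a `bwd` follows the trait definitions below.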
@ -94,6 +94,8 @@ pub(crate) enum Op {
    Transpose(Tensor, usize, usize),
    Elu(Tensor, f64),
    CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1>>),
    CustomOp2(Tensor, Tensor, std::sync::Arc<Box<dyn CustomOp2>>),
    CustomOp3(Tensor, Tensor, Tensor, std::sync::Arc<Box<dyn CustomOp3>>),
}

/// Unary ops that can be defined in user-land.
@ -116,7 +118,88 @@ pub trait CustomOp1: Send + Sync {
    /// This function takes as argument the argument `arg` used in the forward pass, the result
    /// produced by the forward operation `res` and the gradient of the result `grad_res`.
    /// The function should return the gradient of the argument.
    fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Tensor> {
    fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Option<Tensor>> {
        Err(crate::Error::BackwardNotSupported { op: self.name() })
    }
}

pub trait CustomOp2: Send + Sync {
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
    ) -> Result<(CpuStorage, Shape)>;

    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cuda_fwd(
        &self,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
    ) -> Result<(CudaStorage, Shape)> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    fn bwd(
        &self,
        _arg1: &Tensor,
        _arg2: &Tensor,
        _res: &Tensor,
        _grad_res: &Tensor,
    ) -> Result<(Option<Tensor>, Option<Tensor>)> {
        Err(crate::Error::BackwardNotSupported { op: self.name() })
    }
}

pub trait CustomOp3: Send + Sync {
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
        s3: &CpuStorage,
        l3: &Layout,
    ) -> Result<(CpuStorage, Shape)>;

    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cuda_fwd(
        &self,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
    ) -> Result<(CudaStorage, Shape)> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    fn bwd(
        &self,
        _arg1: &Tensor,
        _arg2: &Tensor,
        _arg3: &Tensor,
        _res: &Tensor,
        _grad_res: &Tensor,
    ) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)> {
        Err(crate::Error::BackwardNotSupported { op: self.name() })
    }
}
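The new traits mirror `CustomOp1`: only `cpu_fwd` is required, `cuda_fwd` and `bwd` default to returning errors, and `bwd` hands back one `Option<Tensor>` per input. As a rough, hedged illustration (not part of this commit), an elementwise-multiply `CustomOp2` for contiguous `f32` CPU storage of identical shape could look like the sketch below; the `candle_core` import path, the `ElemMul` name, and the `Layout::start_offset`/`Shape::elem_count` helpers are assumptions about the surrounding crate rather than anything this diff defines.

```rust
// Hypothetical user-land op computing `lhs * rhs` elementwise (f32 only).
use candle_core::{CpuStorage, CustomOp2, Layout, Result, Shape, Tensor};

struct ElemMul;

impl CustomOp2 for ElemMul {
    fn name(&self) -> &'static str {
        "elem-mul"
    }

    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        // Sketch only: assume contiguous f32 inputs with matching shapes; a
        // real op would dispatch on dtype and honor arbitrary strides via the
        // layouts, as the trait documentation above asks for.
        match (s1, s2) {
            (CpuStorage::F32(v1), CpuStorage::F32(v2)) => {
                let n = l1.shape().elem_count();
                let v1 = &v1[l1.start_offset()..l1.start_offset() + n];
                let v2 = &v2[l2.start_offset()..l2.start_offset() + n];
                let out: Vec<f32> = v1.iter().zip(v2.iter()).map(|(a, b)| a * b).collect();
                Ok((CpuStorage::F32(out), l1.shape().clone()))
            }
            _ => unimplemented!("elem-mul sketch only supports f32"),
        }
    }

    fn bwd(
        &self,
        arg1: &Tensor,
        arg2: &Tensor,
        _res: &Tensor,
        grad_res: &Tensor,
    ) -> Result<(Option<Tensor>, Option<Tensor>)> {
        // d(a*b)/da = b and d(a*b)/db = a; returning None for an argument
        // would make backward() skip that accumulation entirely.
        Ok((Some(grad_res.mul(arg2)?), Some(grad_res.mul(arg1)?)))
    }
}
```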
@ -1,5 +1,5 @@
use crate::backend::BackendStorage;
use crate::op::{self, CmpOp, CustomOp1, ReduceOp};
use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp};
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};

// We do not want to implement Clone on Storage as cloning may fail because of
@ -149,7 +149,7 @@ impl Storage {

    pub(crate) fn custom_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> {
        match self {
            Storage::Cpu(storage) => {
            Self::Cpu(storage) => {
                let (storage, shape) = c.cpu_fwd(storage, l)?;
                Ok((Self::Cpu(storage), shape))
            }
@ -160,6 +160,51 @@ impl Storage {
        }
    }

    pub(crate) fn custom_op2(
        &self,
        l1: &Layout,
        t2: &Self,
        l2: &Layout,
        c: &dyn CustomOp2,
    ) -> Result<(Self, Shape)> {
        self.same_device(t2, c.name())?;
        match (self, t2) {
            (Self::Cpu(s1), Self::Cpu(s2)) => {
                let (s, shape) = c.cpu_fwd(s1, l1, s2, l2)?;
                Ok((Self::Cpu(s), shape))
            }
            (Self::Cuda(s1), Self::Cuda(s2)) => {
                let (s, shape) = c.cuda_fwd(s1, l1, s2, l2)?;
                Ok((Self::Cuda(s), shape))
            }
            _ => unreachable!(),
        }
    }

    pub(crate) fn custom_op3(
        &self,
        l1: &Layout,
        t2: &Self,
        l2: &Layout,
        t3: &Self,
        l3: &Layout,
        c: &dyn CustomOp3,
    ) -> Result<(Self, Shape)> {
        self.same_device(t2, c.name())?;
        self.same_device(t3, c.name())?;
        match (self, t2, t3) {
            (Self::Cpu(s1), Self::Cpu(s2), Self::Cpu(s3)) => {
                let (s, shape) = c.cpu_fwd(s1, l1, s2, l2, s3, l3)?;
                Ok((Self::Cpu(s), shape))
            }
            (Self::Cuda(s1), Self::Cuda(s2), Self::Cuda(s3)) => {
                let (s, shape) = c.cuda_fwd(s1, l1, s2, l2, s3, l3)?;
                Ok((Self::Cuda(s), shape))
            }
            _ => unreachable!(),
        }
    }

    pub(crate) fn unary_impl<B: op::UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
        // TODO: Different code path for the contiguous case?
        match self {
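Note that both dispatchers call `same_device` before matching, so a CPU/CUDA mix returns early via `?` and the `_ => unreachable!()` arm is never hit for valid inputs.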
@ -1,5 +1,5 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOp, CmpOp, CustomOp1, Op, ReduceOp, UnaryOp};
use crate::op::{BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp};
use crate::shape::{Dim, Dims};
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::{Arc, RwLock};
@ -1705,6 +1705,48 @@ impl Tensor {
    pub fn custom_op1<C: 'static + CustomOp1>(&self, c: C) -> Result<Self> {
        self.custom_op1_arc(Arc::new(Box::new(c)))
    }

    /// Applies a binary custom op.
    pub fn custom_op2_arc(&self, rhs: &Self, c: Arc<Box<dyn CustomOp2>>) -> Result<Self> {
        let (storage, shape) = self.storage().custom_op2(
            self.layout(),
            &rhs.storage(),
            rhs.layout(),
            c.as_ref().as_ref(),
        )?;
        let op = if self.track_op() {
            Some(Op::CustomOp2(self.clone(), rhs.clone(), c))
        } else {
            None
        };
        Ok(from_storage(storage, shape, op, false))
    }

    pub fn custom_op2<C: 'static + CustomOp2>(&self, r: &Self, c: C) -> Result<Self> {
        self.custom_op2_arc(r, Arc::new(Box::new(c)))
    }

    /// Applies a ternary custom op.
    pub fn custom_op3_arc(&self, t2: &Self, t3: &Self, c: Arc<Box<dyn CustomOp3>>) -> Result<Self> {
        let (storage, shape) = self.storage().custom_op3(
            self.layout(),
            &t2.storage(),
            t2.layout(),
            &t3.storage(),
            t3.layout(),
            c.as_ref().as_ref(),
        )?;
        let op = if self.track_op() {
            Some(Op::CustomOp3(self.clone(), t2.clone(), t3.clone(), c))
        } else {
            None
        };
        Ok(from_storage(storage, shape, op, false))
    }

    pub fn custom_op3<C: 'static + CustomOp3>(&self, t2: &Self, t3: &Self, c: C) -> Result<Self> {
        self.custom_op3_arc(t2, t3, Arc::new(Box::new(c)))
    }
}

macro_rules! bin_trait {
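A rough usage sketch (again assuming a user-land op such as the hypothetical `ElemMul` above and a `candle_core` import path): `custom_op2` boxes and `Arc`-wraps the op on every call, while `custom_op2_arc` lets one allocation be shared across invocations; `custom_op3`/`custom_op3_arc` work the same way with a third tensor.

```rust
use std::sync::Arc;

use candle_core::{CustomOp2, Result, Tensor};

// Convenience form: hand the op to the tensor method by value and let it do
// the Arc<Box<dyn CustomOp2>> wrapping internally.
fn apply_once<C: CustomOp2 + 'static>(a: &Tensor, b: &Tensor, op: C) -> Result<Tensor> {
    a.custom_op2(b, op)
}

// Shared form: wrap the op once and reuse the same Arc for several calls.
fn apply_twice(a: &Tensor, b: &Tensor, op: Arc<Box<dyn CustomOp2>>) -> Result<(Tensor, Tensor)> {
    let y1 = a.custom_op2_arc(b, op.clone())?;
    let y2 = y1.custom_op2_arc(b, op)?;
    Ok((y1, y2))
}
```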
@ -94,10 +94,10 @@ impl CustomOp1 for EluWithBackward {
        self.0.cpu_fwd(s, l)
    }

    fn bwd(&self, arg: &Tensor, _res: &Tensor, grad_res: &Tensor) -> Result<Tensor> {
    fn bwd(&self, arg: &Tensor, _res: &Tensor, grad_res: &Tensor) -> Result<Option<Tensor>> {
        let alpha = self.0.alpha;
        let bwd = arg.custom_op1(EluBackward { alpha })?;
        grad_res.mul(&bwd)
        Ok(Some(grad_res.mul(&bwd)?))
    }
}