Prepare for the custom-op extension. (#1892)

This commit is contained in:
Laurent Mazare
2024-03-21 07:02:20 +01:00
committed by GitHub
parent af7f8b87d3
commit 74b7f59261
5 changed files with 256 additions and 247 deletions

View File

@ -0,0 +1,244 @@
use crate::op::{BackpropOp, Op};
use crate::tensor::from_storage;
use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
use std::sync::Arc;
/// Unary ops that can be defined in user-land.
pub trait CustomOp1 {
// Box<dyn> does not support const yet, so use a function to get the name.
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(&self, _storage: &CudaStorage, _layout: &Layout) -> Result<(CudaStorage, Shape)> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_storage: &MetalStorage,
_layout: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
/// This function takes as argument the argument `arg` used in the forward pass, the result
/// produced by the forward operation `res` and the gradient of the result `grad_res`.
/// The function should return the gradient of the argument.
fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Option<Tensor>> {
Err(crate::Error::BackwardNotSupported { op: self.name() })
}
}
pub trait CustomOp2 {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(
&self,
s1: &CpuStorage,
l1: &Layout,
s2: &CpuStorage,
l2: &Layout,
) -> Result<(CpuStorage, Shape)>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(
&self,
_: &CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
) -> Result<(CudaStorage, Shape)> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
fn bwd(
&self,
_arg1: &Tensor,
_arg2: &Tensor,
_res: &Tensor,
_grad_res: &Tensor,
) -> Result<(Option<Tensor>, Option<Tensor>)> {
Err(crate::Error::BackwardNotSupported { op: self.name() })
}
}
pub trait CustomOp3 {
fn name(&self) -> &'static str;
/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cpu_fwd(
&self,
s1: &CpuStorage,
l1: &Layout,
s2: &CpuStorage,
l2: &Layout,
s3: &CpuStorage,
l3: &Layout,
) -> Result<(CpuStorage, Shape)>;
/// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn cuda_fwd(
&self,
_: &CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
_: &CudaStorage,
_: &Layout,
) -> Result<(CudaStorage, Shape)> {
Err(crate::Error::Cuda(
format!("no cuda implementation for {}", self.name()).into(),
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
fn bwd(
&self,
_arg1: &Tensor,
_arg2: &Tensor,
_arg3: &Tensor,
_res: &Tensor,
_grad_res: &Tensor,
) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)> {
Err(crate::Error::BackwardNotSupported { op: self.name() })
}
}
impl Tensor {
/// Applies a unary custom op without backward support
pub fn apply_op1_no_bwd<C: CustomOp1>(&self, c: &C) -> Result<Self> {
let (storage, shape) = self.storage().apply_op1(self.layout(), c)?;
Ok(from_storage(storage, shape, BackpropOp::none(), false))
}
/// Applies a binary custom op without backward support
pub fn apply_op2_no_bwd<C: CustomOp2>(&self, rhs: &Self, c: &C) -> Result<Self> {
let (storage, shape) =
self.storage()
.apply_op2(self.layout(), &rhs.storage(), rhs.layout(), c)?;
Ok(from_storage(storage, shape, BackpropOp::none(), false))
}
/// Applies a ternary custom op without backward support
pub fn apply_op3_no_bwd<C: CustomOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<Self> {
let (storage, shape) = self.storage().apply_op3(
self.layout(),
&t2.storage(),
t2.layout(),
&t3.storage(),
t3.layout(),
c,
)?;
Ok(from_storage(storage, shape, BackpropOp::none(), false))
}
/// Applies a unary custom op.
pub fn apply_op1_arc(&self, c: Arc<Box<dyn CustomOp1 + Send + Sync>>) -> Result<Self> {
let (storage, shape) = self
.storage()
.apply_op1(self.layout(), c.as_ref().as_ref())?;
let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
Ok(from_storage(storage, shape, op, false))
}
pub fn apply_op1<C: 'static + CustomOp1 + Send + Sync>(&self, c: C) -> Result<Self> {
self.apply_op1_arc(Arc::new(Box::new(c)))
}
/// Applies a binary custom op.
pub fn apply_op2_arc(
&self,
rhs: &Self,
c: Arc<Box<dyn CustomOp2 + Send + Sync>>,
) -> Result<Self> {
let (storage, shape) = self.storage().apply_op2(
self.layout(),
&rhs.storage(),
rhs.layout(),
c.as_ref().as_ref(),
)?;
let op = BackpropOp::new2(self, rhs, |t1, t2| Op::CustomOp2(t1, t2, c.clone()));
Ok(from_storage(storage, shape, op, false))
}
pub fn apply_op2<C: 'static + CustomOp2 + Send + Sync>(&self, r: &Self, c: C) -> Result<Self> {
self.apply_op2_arc(r, Arc::new(Box::new(c)))
}
/// Applies a ternary custom op.
pub fn apply_op3_arc(
&self,
t2: &Self,
t3: &Self,
c: Arc<Box<dyn CustomOp3 + Send + Sync>>,
) -> Result<Self> {
let (storage, shape) = self.storage().apply_op3(
self.layout(),
&t2.storage(),
t2.layout(),
&t3.storage(),
t3.layout(),
c.as_ref().as_ref(),
)?;
let op = BackpropOp::new3(self, t2, t3, |t1, t2, t3| {
Op::CustomOp3(t1, t2, t3, c.clone())
});
Ok(from_storage(storage, shape, op, false))
}
pub fn apply_op3<C: 'static + CustomOp3 + Send + Sync>(
&self,
t2: &Self,
t3: &Self,
c: C,
) -> Result<Self> {
self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
}
}