mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Custom ops with a single argument (#214)
* Add the CustomOp1 trait. * Add an example of custom op. * Polish the custom op example. * Add some backward pass test for custom ops.
This commit is contained in:
@ -1,5 +1,5 @@
|
||||
use crate::backend::BackendStorage;
|
||||
use crate::op::{self, CmpOp, ReduceOp};
|
||||
use crate::op::{self, CmpOp, CustomOp1, ReduceOp};
|
||||
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
|
||||
|
||||
// We do not want to implement Clone on Storage as cloning may fail because of
|
||||
@ -147,6 +147,19 @@ impl Storage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn custom_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> {
|
||||
match self {
|
||||
Storage::Cpu(storage) => {
|
||||
let (storage, shape) = c.cpu_fwd(storage, l)?;
|
||||
Ok((Self::Cpu(storage), shape))
|
||||
}
|
||||
Self::Cuda(storage) => {
|
||||
let (storage, shape) = c.cuda_fwd(storage, l)?;
|
||||
Ok((Self::Cuda(storage), shape))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn unary_impl<B: op::UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
|
||||
// TODO: Different code path for the contiguous case?
|
||||
match self {
|
||||
|
Reference in New Issue
Block a user