mirror of https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Prepare for the custom-op extension. (#1892)
candle-core/src/custom_op.rs | 244 lines (new file)

candle-core/src/custom_op.rs (new file)
@@ -0,0 +1,244 @@
use crate::op::{BackpropOp, Op};
use crate::tensor::from_storage;
use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
use std::sync::Arc;

/// Unary ops that can be defined in user-land.
pub trait CustomOp1 {
    // Box<dyn> does not support const yet, so use a function to get the name.
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)>;

    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cuda_fwd(&self, _storage: &CudaStorage, _layout: &Layout) -> Result<(CudaStorage, Shape)> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn metal_fwd(
        &self,
        _storage: &MetalStorage,
        _layout: &Layout,
    ) -> Result<(MetalStorage, Shape)> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }

    /// This function takes as argument the argument `arg` used in the forward pass, the result
    /// produced by the forward operation `res` and the gradient of the result `grad_res`.
    /// The function should return the gradient of the argument.
    fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Option<Tensor>> {
        Err(crate::Error::BackwardNotSupported { op: self.name() })
    }
}

pub trait CustomOp2 {
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
    ) -> Result<(CpuStorage, Shape)>;

    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cuda_fwd(
        &self,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
    ) -> Result<(CudaStorage, Shape)> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn metal_fwd(
        &self,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
    ) -> Result<(MetalStorage, Shape)> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }

    fn bwd(
        &self,
        _arg1: &Tensor,
        _arg2: &Tensor,
        _res: &Tensor,
        _grad_res: &Tensor,
    ) -> Result<(Option<Tensor>, Option<Tensor>)> {
        Err(crate::Error::BackwardNotSupported { op: self.name() })
    }
}

pub trait CustomOp3 {
    fn name(&self) -> &'static str;

    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
        s3: &CpuStorage,
        l3: &Layout,
    ) -> Result<(CpuStorage, Shape)>;

    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn cuda_fwd(
        &self,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
        _: &CudaStorage,
        _: &Layout,
    ) -> Result<(CudaStorage, Shape)> {
        Err(crate::Error::Cuda(
            format!("no cuda implementation for {}", self.name()).into(),
        ))
    }

    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
    /// offsets etc so the associated layout should be used to access it.
    fn metal_fwd(
        &self,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
        _: &MetalStorage,
        _: &Layout,
    ) -> Result<(MetalStorage, Shape)> {
        Err(crate::Error::Metal(
            format!("no metal implementation for {}", self.name()).into(),
        ))
    }

    fn bwd(
        &self,
        _arg1: &Tensor,
        _arg2: &Tensor,
        _arg3: &Tensor,
        _res: &Tensor,
        _grad_res: &Tensor,
    ) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)> {
        Err(crate::Error::BackwardNotSupported { op: self.name() })
    }
}

impl Tensor {
    /// Applies a unary custom op without backward support
    pub fn apply_op1_no_bwd<C: CustomOp1>(&self, c: &C) -> Result<Self> {
        let (storage, shape) = self.storage().apply_op1(self.layout(), c)?;
        Ok(from_storage(storage, shape, BackpropOp::none(), false))
    }

    /// Applies a binary custom op without backward support
    pub fn apply_op2_no_bwd<C: CustomOp2>(&self, rhs: &Self, c: &C) -> Result<Self> {
        let (storage, shape) =
            self.storage()
                .apply_op2(self.layout(), &rhs.storage(), rhs.layout(), c)?;
        Ok(from_storage(storage, shape, BackpropOp::none(), false))
    }

    /// Applies a ternary custom op without backward support
    pub fn apply_op3_no_bwd<C: CustomOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<Self> {
        let (storage, shape) = self.storage().apply_op3(
            self.layout(),
            &t2.storage(),
            t2.layout(),
            &t3.storage(),
            t3.layout(),
            c,
        )?;
        Ok(from_storage(storage, shape, BackpropOp::none(), false))
    }

    /// Applies a unary custom op.
    pub fn apply_op1_arc(&self, c: Arc<Box<dyn CustomOp1 + Send + Sync>>) -> Result<Self> {
        let (storage, shape) = self
            .storage()
            .apply_op1(self.layout(), c.as_ref().as_ref())?;
        let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
        Ok(from_storage(storage, shape, op, false))
    }

    pub fn apply_op1<C: 'static + CustomOp1 + Send + Sync>(&self, c: C) -> Result<Self> {
        self.apply_op1_arc(Arc::new(Box::new(c)))
    }

    /// Applies a binary custom op.
    pub fn apply_op2_arc(
        &self,
        rhs: &Self,
        c: Arc<Box<dyn CustomOp2 + Send + Sync>>,
    ) -> Result<Self> {
        let (storage, shape) = self.storage().apply_op2(
            self.layout(),
            &rhs.storage(),
            rhs.layout(),
            c.as_ref().as_ref(),
        )?;
        let op = BackpropOp::new2(self, rhs, |t1, t2| Op::CustomOp2(t1, t2, c.clone()));
        Ok(from_storage(storage, shape, op, false))
    }

    pub fn apply_op2<C: 'static + CustomOp2 + Send + Sync>(&self, r: &Self, c: C) -> Result<Self> {
        self.apply_op2_arc(r, Arc::new(Box::new(c)))
    }

    /// Applies a ternary custom op.
    pub fn apply_op3_arc(
        &self,
        t2: &Self,
        t3: &Self,
        c: Arc<Box<dyn CustomOp3 + Send + Sync>>,
    ) -> Result<Self> {
        let (storage, shape) = self.storage().apply_op3(
            self.layout(),
            &t2.storage(),
            t2.layout(),
            &t3.storage(),
            t3.layout(),
            c.as_ref().as_ref(),
        )?;
        let op = BackpropOp::new3(self, t2, t3, |t1, t2, t3| {
            Op::CustomOp3(t1, t2, t3, c.clone())
        });
        Ok(from_storage(storage, shape, op, false))
    }

    pub fn apply_op3<C: 'static + CustomOp3 + Send + Sync>(
        &self,
        t2: &Self,
        t3: &Self,
        c: C,
    ) -> Result<Self> {
        self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
    }
}
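
The file above only defines the extension points; nothing in the commit shows them in use. As a hedged illustration (hypothetical code, not part of the commit), a user crate could implement `CustomOp1` roughly as follows, assuming the crate is consumed as `candle_core` and relying on the existing `CpuStorage::as_slice` and `Layout::contiguous_offsets` helpers. Only the CPU forward pass and the backward pass are provided, so the op keeps the default CUDA/Metal errors.

// Hypothetical user-land op: multiply every f32 element by two.
use candle_core::{bail, CpuStorage, CustomOp1, Layout, Result, Shape, Tensor};

struct Double;

impl CustomOp1 for Double {
    fn name(&self) -> &'static str {
        "double"
    }

    // CPU forward pass: the layout describes which part of the storage the
    // tensor actually covers (offset and strides), so index through it.
    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        let src = storage.as_slice::<f32>()?;
        let (start, end) = match layout.contiguous_offsets() {
            Some(offsets) => offsets,
            None => bail!("the double op only supports contiguous inputs"),
        };
        let dst: Vec<f32> = src[start..end].iter().map(|v| v * 2.0).collect();
        Ok((CpuStorage::F32(dst), layout.shape().clone()))
    }

    // d(2 * x)/dx = 2, so the gradient of the argument is 2 * grad_res.
    fn bwd(&self, _arg: &Tensor, _res: &Tensor, grad_res: &Tensor) -> Result<Option<Tensor>> {
        Ok(Some(grad_res.affine(2.0, 0.0)?))
    }
}

fn main() -> Result<()> {
    let t = Tensor::new(&[1f32, 2.0, 3.0], &candle_core::Device::Cpu)?;
    // `apply_op1` records the op for backprop (using `bwd` above);
    // `apply_op1_no_bwd` would skip gradient tracking entirely.
    let doubled = t.apply_op1(Double)?;
    println!("{doubled}");
    Ok(())
}

Overriding `cuda_fwd` or `metal_fwd` in the same way would opt the op into the GPU backends; `apply_op1_no_bwd` takes the op by reference and has no `Send + Sync + 'static` bound, which is handy for ops that should never take part in a backward pass.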

candle-core/src/lib.rs
@@ -45,6 +45,7 @@ pub mod cpu_backend;
 pub mod cuda_backend;
 #[cfg(feature = "cudnn")]
 pub mod cudnn;
+mod custom_op;
 mod device;
 pub mod display;
 mod dtype;
@@ -73,12 +74,12 @@ pub mod utils;
 mod variable;
 
 pub use cpu_backend::CpuStorage;
+pub use custom_op::{CustomOp1, CustomOp2, CustomOp3};
 pub use device::{Device, DeviceLocation, NdArray};
 pub use dtype::{DType, FloatDType, IntDType, WithDType};
 pub use error::{Error, Result};
 pub use indexer::IndexOp;
 pub use layout::Layout;
-pub use op::{CustomOp1, CustomOp2, CustomOp3};
 pub use shape::{Shape, D};
 pub use storage::Storage;
 pub use strided_index::{StridedBlocks, StridedIndex};
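
From a downstream crate's point of view the swapped re-export keeps the public path stable: the three traits were already reachable at the crate root and still are, only the module that defines them changes. A minimal, hypothetical user-side import (assuming the crate is consumed as `candle_core`):

use candle_core::{CustomOp1, CustomOp2, CustomOp3}; // works both before and after this commit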

candle-core/src/op.rs
@@ -1,5 +1,5 @@
 #![allow(clippy::redundant_closure_call)]
-use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
+use crate::Tensor;
 use half::{bf16, f16};
 use num_traits::float::Float;
 
@@ -161,168 +161,23 @@ pub enum Op {
     Permute(Tensor, Vec<usize>),
     Elu(Tensor, f64),
     Powf(Tensor, f64),
-    CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1 + Send + Sync>>),
+    CustomOp1(
+        Tensor,
+        std::sync::Arc<Box<dyn crate::CustomOp1 + Send + Sync>>,
+    ),
     CustomOp2(
         Tensor,
         Tensor,
-        std::sync::Arc<Box<dyn CustomOp2 + Send + Sync>>,
+        std::sync::Arc<Box<dyn crate::CustomOp2 + Send + Sync>>,
     ),
     CustomOp3(
         Tensor,
         Tensor,
         Tensor,
-        std::sync::Arc<Box<dyn CustomOp3 + Send + Sync>>,
+        std::sync::Arc<Box<dyn crate::CustomOp3 + Send + Sync>>,
     ),
 }
 
-/// Unary ops that can be defined in user-land.
-pub trait CustomOp1 {
-    // Box<dyn> does not support const yet, so use a function to get the name.
-    fn name(&self) -> &'static str;
-
-    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
-    /// offsets etc so the associated layout should be used to access it.
-    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)>;
-
-    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
-    /// offsets etc so the associated layout should be used to access it.
-    fn cuda_fwd(&self, _storage: &CudaStorage, _layout: &Layout) -> Result<(CudaStorage, Shape)> {
-        Err(crate::Error::Cuda(
-            format!("no cuda implementation for {}", self.name()).into(),
-        ))
-    }
-
-    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
-    /// offsets etc so the associated layout should be used to access it.
-    fn metal_fwd(
-        &self,
-        _storage: &MetalStorage,
-        _layout: &Layout,
-    ) -> Result<(MetalStorage, Shape)> {
-        Err(crate::Error::Metal(
-            format!("no metal implementation for {}", self.name()).into(),
-        ))
-    }
-
-    /// This function takes as argument the argument `arg` used in the forward pass, the result
-    /// produced by the forward operation `res` and the gradient of the result `grad_res`.
-    /// The function should return the gradient of the argument.
-    fn bwd(&self, _arg: &Tensor, _res: &Tensor, _grad_res: &Tensor) -> Result<Option<Tensor>> {
-        Err(crate::Error::BackwardNotSupported { op: self.name() })
-    }
-}
-
-pub trait CustomOp2 {
-    fn name(&self) -> &'static str;
-
-    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
-    /// offsets etc so the associated layout should be used to access it.
-    fn cpu_fwd(
-        &self,
-        s1: &CpuStorage,
-        l1: &Layout,
-        s2: &CpuStorage,
-        l2: &Layout,
-    ) -> Result<(CpuStorage, Shape)>;
-
-    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
-    /// offsets etc so the associated layout should be used to access it.
-    fn cuda_fwd(
-        &self,
-        _: &CudaStorage,
-        _: &Layout,
-        _: &CudaStorage,
-        _: &Layout,
-    ) -> Result<(CudaStorage, Shape)> {
-        Err(crate::Error::Cuda(
-            format!("no cuda implementation for {}", self.name()).into(),
-        ))
-    }
-
-    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
-    /// offsets etc so the associated layout should be used to access it.
-    fn metal_fwd(
-        &self,
-        _: &MetalStorage,
-        _: &Layout,
-        _: &MetalStorage,
-        _: &Layout,
-    ) -> Result<(MetalStorage, Shape)> {
-        Err(crate::Error::Metal(
-            format!("no metal implementation for {}", self.name()).into(),
-        ))
-    }
-
-    fn bwd(
-        &self,
-        _arg1: &Tensor,
-        _arg2: &Tensor,
-        _res: &Tensor,
-        _grad_res: &Tensor,
-    ) -> Result<(Option<Tensor>, Option<Tensor>)> {
-        Err(crate::Error::BackwardNotSupported { op: self.name() })
-    }
-}
-
-pub trait CustomOp3 {
-    fn name(&self) -> &'static str;
-
-    /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
-    /// offsets etc so the associated layout should be used to access it.
-    fn cpu_fwd(
-        &self,
-        s1: &CpuStorage,
-        l1: &Layout,
-        s2: &CpuStorage,
-        l2: &Layout,
-        s3: &CpuStorage,
-        l3: &Layout,
-    ) -> Result<(CpuStorage, Shape)>;
-
-    /// The forward pass, as run on a gpu device. Note that the storage can use arbitrary strides,
-    /// offsets etc so the associated layout should be used to access it.
-    fn cuda_fwd(
-        &self,
-        _: &CudaStorage,
-        _: &Layout,
-        _: &CudaStorage,
-        _: &Layout,
-        _: &CudaStorage,
-        _: &Layout,
-    ) -> Result<(CudaStorage, Shape)> {
-        Err(crate::Error::Cuda(
-            format!("no cuda implementation for {}", self.name()).into(),
-        ))
-    }
-
-    /// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
-    /// offsets etc so the associated layout should be used to access it.
-    fn metal_fwd(
-        &self,
-        _: &MetalStorage,
-        _: &Layout,
-        _: &MetalStorage,
-        _: &Layout,
-        _: &MetalStorage,
-        _: &Layout,
-    ) -> Result<(MetalStorage, Shape)> {
-        Err(crate::Error::Metal(
-            format!("no metal implementation for {}", self.name()).into(),
-        ))
-    }
-
-    fn bwd(
-        &self,
-        _arg1: &Tensor,
-        _arg2: &Tensor,
-        _arg3: &Tensor,
-        _res: &Tensor,
-        _grad_res: &Tensor,
-    ) -> Result<(Option<Tensor>, Option<Tensor>, Option<Tensor>)> {
-        Err(crate::Error::BackwardNotSupported { op: self.name() })
-    }
-}
-
 pub trait UnaryOpT {
     const NAME: &'static str;
     const KERNEL: &'static str;

candle-core/src/storage.rs
@@ -1,6 +1,7 @@
 use crate::backend::BackendStorage;
-use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp};
+use crate::op::{self, CmpOp, ReduceOp};
 use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
+use crate::{CustomOp1, CustomOp2, CustomOp3};
 
 // We do not want to implement Clone on Storage as cloning may fail because of
 // out of memory. Instead try_clone should be used.

candle-core/src/tensor.rs
@@ -1,9 +1,7 @@
 //! Tensors are N-dimensional matrixes of elements using a single data type.
 #![allow(clippy::redundant_closure_call)]
 use crate::backend::{BackendDevice, BackendStorage};
-use crate::op::{
-    BackpropOp, BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp,
-};
+use crate::op::{BackpropOp, BinaryOp, CmpOp, Op, ReduceOp, UnaryOp};
 use crate::scalar::TensorOrScalar;
 use crate::shape::{Dim, Dims};
 use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
@@ -2277,96 +2275,6 @@ impl Tensor {
         std::ptr::eq(lhs, rhs)
     }
 
-    /// Applies a unary custom op without backward support
-    pub fn apply_op1_no_bwd<C: CustomOp1>(&self, c: &C) -> Result<Self> {
-        let (storage, shape) = self.storage().apply_op1(self.layout(), c)?;
-        Ok(from_storage(storage, shape, BackpropOp::none(), false))
-    }
-
-    /// Applies a binary custom op without backward support
-    pub fn apply_op2_no_bwd<C: CustomOp2>(&self, rhs: &Self, c: &C) -> Result<Self> {
-        let (storage, shape) =
-            self.storage()
-                .apply_op2(self.layout(), &rhs.storage(), rhs.layout(), c)?;
-        Ok(from_storage(storage, shape, BackpropOp::none(), false))
-    }
-
-    /// Applies a ternary custom op without backward support
-    pub fn apply_op3_no_bwd<C: CustomOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<Self> {
-        let (storage, shape) = self.storage().apply_op3(
-            self.layout(),
-            &t2.storage(),
-            t2.layout(),
-            &t3.storage(),
-            t3.layout(),
-            c,
-        )?;
-        Ok(from_storage(storage, shape, BackpropOp::none(), false))
-    }
-
-    /// Applies a unary custom op.
-    pub fn apply_op1_arc(&self, c: Arc<Box<dyn CustomOp1 + Send + Sync>>) -> Result<Self> {
-        let (storage, shape) = self
-            .storage()
-            .apply_op1(self.layout(), c.as_ref().as_ref())?;
-        let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
-        Ok(from_storage(storage, shape, op, false))
-    }
-
-    pub fn apply_op1<C: 'static + CustomOp1 + Send + Sync>(&self, c: C) -> Result<Self> {
-        self.apply_op1_arc(Arc::new(Box::new(c)))
-    }
-
-    /// Applies a binary custom op.
-    pub fn apply_op2_arc(
-        &self,
-        rhs: &Self,
-        c: Arc<Box<dyn CustomOp2 + Send + Sync>>,
-    ) -> Result<Self> {
-        let (storage, shape) = self.storage().apply_op2(
-            self.layout(),
-            &rhs.storage(),
-            rhs.layout(),
-            c.as_ref().as_ref(),
-        )?;
-        let op = BackpropOp::new2(self, rhs, |t1, t2| Op::CustomOp2(t1, t2, c.clone()));
-        Ok(from_storage(storage, shape, op, false))
-    }
-
-    pub fn apply_op2<C: 'static + CustomOp2 + Send + Sync>(&self, r: &Self, c: C) -> Result<Self> {
-        self.apply_op2_arc(r, Arc::new(Box::new(c)))
-    }
-
-    /// Applies a ternary custom op.
-    pub fn apply_op3_arc(
-        &self,
-        t2: &Self,
-        t3: &Self,
-        c: Arc<Box<dyn CustomOp3 + Send + Sync>>,
-    ) -> Result<Self> {
-        let (storage, shape) = self.storage().apply_op3(
-            self.layout(),
-            &t2.storage(),
-            t2.layout(),
-            &t3.storage(),
-            t3.layout(),
-            c.as_ref().as_ref(),
-        )?;
-        let op = BackpropOp::new3(self, t2, t3, |t1, t2, t3| {
-            Op::CustomOp3(t1, t2, t3, c.clone())
-        });
-        Ok(from_storage(storage, shape, op, false))
-    }
-
-    pub fn apply_op3<C: 'static + CustomOp3 + Send + Sync>(
-        &self,
-        t2: &Self,
-        t3: &Self,
-        c: C,
-    ) -> Result<Self> {
-        self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
-    }
-
     /// Normalize a 'relative' axis value: positive values are kept, negative
     /// values means counting the dimensions from the back.
     pub fn normalize_axis(&self, axis: i64) -> Result<usize> {