mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Prepare for the custom-op extension. (#1892)
This commit is contained in:
@ -1,9 +1,7 @@
|
||||
//! Tensors are N-dimensional matrixes of elements using a single data type.
|
||||
#![allow(clippy::redundant_closure_call)]
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::op::{
|
||||
BackpropOp, BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp,
|
||||
};
|
||||
use crate::op::{BackpropOp, BinaryOp, CmpOp, Op, ReduceOp, UnaryOp};
|
||||
use crate::scalar::TensorOrScalar;
|
||||
use crate::shape::{Dim, Dims};
|
||||
use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
|
||||
@ -2277,96 +2275,6 @@ impl Tensor {
|
||||
std::ptr::eq(lhs, rhs)
|
||||
}
|
||||
|
||||
/// Applies a unary custom op without backward support
|
||||
pub fn apply_op1_no_bwd<C: CustomOp1>(&self, c: &C) -> Result<Self> {
|
||||
let (storage, shape) = self.storage().apply_op1(self.layout(), c)?;
|
||||
Ok(from_storage(storage, shape, BackpropOp::none(), false))
|
||||
}
|
||||
|
||||
/// Applies a binary custom op without backward support
|
||||
pub fn apply_op2_no_bwd<C: CustomOp2>(&self, rhs: &Self, c: &C) -> Result<Self> {
|
||||
let (storage, shape) =
|
||||
self.storage()
|
||||
.apply_op2(self.layout(), &rhs.storage(), rhs.layout(), c)?;
|
||||
Ok(from_storage(storage, shape, BackpropOp::none(), false))
|
||||
}
|
||||
|
||||
/// Applies a ternary custom op without backward support
|
||||
pub fn apply_op3_no_bwd<C: CustomOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<Self> {
|
||||
let (storage, shape) = self.storage().apply_op3(
|
||||
self.layout(),
|
||||
&t2.storage(),
|
||||
t2.layout(),
|
||||
&t3.storage(),
|
||||
t3.layout(),
|
||||
c,
|
||||
)?;
|
||||
Ok(from_storage(storage, shape, BackpropOp::none(), false))
|
||||
}
|
||||
|
||||
/// Applies a unary custom op.
|
||||
pub fn apply_op1_arc(&self, c: Arc<Box<dyn CustomOp1 + Send + Sync>>) -> Result<Self> {
|
||||
let (storage, shape) = self
|
||||
.storage()
|
||||
.apply_op1(self.layout(), c.as_ref().as_ref())?;
|
||||
let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
pub fn apply_op1<C: 'static + CustomOp1 + Send + Sync>(&self, c: C) -> Result<Self> {
|
||||
self.apply_op1_arc(Arc::new(Box::new(c)))
|
||||
}
|
||||
|
||||
/// Applies a binary custom op.
|
||||
pub fn apply_op2_arc(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
c: Arc<Box<dyn CustomOp2 + Send + Sync>>,
|
||||
) -> Result<Self> {
|
||||
let (storage, shape) = self.storage().apply_op2(
|
||||
self.layout(),
|
||||
&rhs.storage(),
|
||||
rhs.layout(),
|
||||
c.as_ref().as_ref(),
|
||||
)?;
|
||||
let op = BackpropOp::new2(self, rhs, |t1, t2| Op::CustomOp2(t1, t2, c.clone()));
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
pub fn apply_op2<C: 'static + CustomOp2 + Send + Sync>(&self, r: &Self, c: C) -> Result<Self> {
|
||||
self.apply_op2_arc(r, Arc::new(Box::new(c)))
|
||||
}
|
||||
|
||||
/// Applies a ternary custom op.
|
||||
pub fn apply_op3_arc(
|
||||
&self,
|
||||
t2: &Self,
|
||||
t3: &Self,
|
||||
c: Arc<Box<dyn CustomOp3 + Send + Sync>>,
|
||||
) -> Result<Self> {
|
||||
let (storage, shape) = self.storage().apply_op3(
|
||||
self.layout(),
|
||||
&t2.storage(),
|
||||
t2.layout(),
|
||||
&t3.storage(),
|
||||
t3.layout(),
|
||||
c.as_ref().as_ref(),
|
||||
)?;
|
||||
let op = BackpropOp::new3(self, t2, t3, |t1, t2, t3| {
|
||||
Op::CustomOp3(t1, t2, t3, c.clone())
|
||||
});
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
pub fn apply_op3<C: 'static + CustomOp3 + Send + Sync>(
|
||||
&self,
|
||||
t2: &Self,
|
||||
t3: &Self,
|
||||
c: C,
|
||||
) -> Result<Self> {
|
||||
self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
|
||||
}
|
||||
|
||||
/// Normalize a 'relative' axis value: positive values are kept, negative
|
||||
/// values means counting the dimensions from the back.
|
||||
pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
|
||||
|
Reference in New Issue
Block a user