mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Add an abstract backprop op type (#240)
* Start adding the backprop op type. * More backprop ops. * Finish the backprop op.
This commit is contained in:
@ -58,7 +58,7 @@ pub enum UnaryOp {
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) enum Op {
|
||||
pub enum Op {
|
||||
Binary(Tensor, Tensor, BinaryOp),
|
||||
Unary(Tensor, UnaryOp),
|
||||
Cmp(Tensor, CmpOp),
|
||||
@ -512,3 +512,63 @@ impl UnaryOpT for Relu {
|
||||
v
|
||||
}
|
||||
}
|
||||
|
||||
/// `BackpropOp` is a wrapper around `Option<Op>`. The main goal is to ensure that dependencies are
|
||||
/// properly checked when creating a new value
|
||||
#[derive(Clone)]
|
||||
pub struct BackpropOp(Option<Op>);
|
||||
|
||||
impl BackpropOp {
|
||||
pub(crate) fn none() -> Self {
|
||||
BackpropOp(None)
|
||||
}
|
||||
|
||||
pub(crate) fn new1(arg: &Tensor, f: impl Fn(Tensor) -> Op) -> Self {
|
||||
let op = if arg.track_op() {
|
||||
Some(f(arg.clone()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Self(op)
|
||||
}
|
||||
|
||||
pub(crate) fn new2(arg1: &Tensor, arg2: &Tensor, f: impl Fn(Tensor, Tensor) -> Op) -> Self {
|
||||
let op = if arg1.track_op() || arg2.track_op() {
|
||||
Some(f(arg1.clone(), arg2.clone()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Self(op)
|
||||
}
|
||||
|
||||
pub(crate) fn new3(
|
||||
arg1: &Tensor,
|
||||
arg2: &Tensor,
|
||||
arg3: &Tensor,
|
||||
f: impl Fn(Tensor, Tensor, Tensor) -> Op,
|
||||
) -> Self {
|
||||
let op = if arg1.track_op() || arg2.track_op() || arg3.track_op() {
|
||||
Some(f(arg1.clone(), arg2.clone(), arg3.clone()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Self(op)
|
||||
}
|
||||
|
||||
pub(crate) fn new<A: AsRef<Tensor>>(args: &[A], f: impl Fn(Vec<Tensor>) -> Op) -> Self {
|
||||
let op = if args.iter().any(|arg| arg.as_ref().track_op()) {
|
||||
let args: Vec<Tensor> = args.iter().map(|arg| arg.as_ref().clone()).collect();
|
||||
Some(f(args))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Self(op)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Deref for BackpropOp {
|
||||
type Target = Option<Op>;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,7 @@
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::op::{BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp};
|
||||
use crate::op::{
|
||||
BackpropOp, BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp,
|
||||
};
|
||||
use crate::shape::{Dim, Dims};
|
||||
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
|
||||
use std::sync::{Arc, RwLock};
|
||||
@ -33,7 +35,7 @@ pub struct Tensor_ {
|
||||
// that's tricky to encode in the current setup.
|
||||
storage: Arc<RwLock<Storage>>,
|
||||
layout: Layout,
|
||||
op: Option<Op>,
|
||||
op: BackpropOp,
|
||||
is_variable: bool,
|
||||
dtype: DType,
|
||||
device: Device,
|
||||
@ -79,11 +81,7 @@ macro_rules! unary_op {
|
||||
let storage = self
|
||||
.storage()
|
||||
.unary_impl::<crate::op::$op_name>(self.layout())?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Unary(self.clone(), UnaryOp::$op_name))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let op = BackpropOp::new1(self, |s| Op::Unary(s, UnaryOp::$op_name));
|
||||
Ok(from_storage(storage, shape.clone(), op, false))
|
||||
}
|
||||
};
|
||||
@ -98,11 +96,7 @@ macro_rules! binary_op {
|
||||
self.layout(),
|
||||
rhs.layout(),
|
||||
)?;
|
||||
let op = if self.track_op() || rhs.track_op() {
|
||||
Some(Op::Binary(self.clone(), rhs.clone(), BinaryOp::$op_name))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let op = BackpropOp::new2(self, rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name));
|
||||
Ok(from_storage(storage, shape.clone(), op, false))
|
||||
}
|
||||
};
|
||||
@ -131,7 +125,7 @@ macro_rules! broadcast_binary_op {
|
||||
fn from_storage<S: Into<Shape>>(
|
||||
storage: Storage,
|
||||
shape: S,
|
||||
op: Option<Op>,
|
||||
op: BackpropOp,
|
||||
is_variable: bool,
|
||||
) -> Tensor {
|
||||
let dtype = storage.dtype();
|
||||
@ -155,13 +149,14 @@ impl Tensor {
|
||||
device: &Device,
|
||||
is_variable: bool,
|
||||
) -> Result<Self> {
|
||||
let none = BackpropOp::none();
|
||||
if is_variable {
|
||||
let shape = shape.into();
|
||||
let storage = device.ones(&shape, dtype)?;
|
||||
Ok(from_storage(storage, shape, None, is_variable))
|
||||
Ok(from_storage(storage, shape, none, is_variable))
|
||||
} else {
|
||||
let storage = device.ones(&crate::shape::SCALAR, dtype)?;
|
||||
from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
|
||||
from_storage(storage, crate::shape::SCALAR, none, is_variable).broadcast_as(shape)
|
||||
}
|
||||
}
|
||||
|
||||
@ -199,13 +194,14 @@ impl Tensor {
|
||||
device: &Device,
|
||||
is_variable: bool,
|
||||
) -> Result<Self> {
|
||||
let none = BackpropOp::none();
|
||||
if is_variable {
|
||||
let shape = shape.into();
|
||||
let storage = device.zeros(&shape, dtype)?;
|
||||
Ok(from_storage(storage, shape, None, is_variable))
|
||||
Ok(from_storage(storage, shape, none, is_variable))
|
||||
} else {
|
||||
let storage = device.zeros(&crate::shape::SCALAR, dtype)?;
|
||||
from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
|
||||
from_storage(storage, crate::shape::SCALAR, none, is_variable).broadcast_as(shape)
|
||||
}
|
||||
}
|
||||
|
||||
@ -246,7 +242,8 @@ impl Tensor {
|
||||
) -> Result<Self> {
|
||||
let s = s.into();
|
||||
let storage = device.rand_uniform(&s, dtype, lo, up)?;
|
||||
Ok(from_storage(storage, s, None, is_variable))
|
||||
let none = BackpropOp::none();
|
||||
Ok(from_storage(storage, s, none, is_variable))
|
||||
}
|
||||
|
||||
/// Creates a new tensor initialized with values sampled uniformly between `lo` and `up`.
|
||||
@ -270,7 +267,8 @@ impl Tensor {
|
||||
) -> Result<Self> {
|
||||
let s = s.into();
|
||||
let storage = device.rand_normal(&s, dtype, mean, std)?;
|
||||
Ok(from_storage(storage, s, None, is_variable))
|
||||
let none = BackpropOp::none();
|
||||
Ok(from_storage(storage, s, none, is_variable))
|
||||
}
|
||||
|
||||
/// Creates a new tensor initialized with values sampled from a normal distribution with the
|
||||
@ -297,7 +295,8 @@ impl Tensor {
|
||||
return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
|
||||
}
|
||||
let storage = device.storage(array)?;
|
||||
Ok(from_storage(storage, shape, None, is_variable))
|
||||
let none = BackpropOp::none();
|
||||
Ok(from_storage(storage, shape, none, is_variable))
|
||||
}
|
||||
|
||||
/// Creates a new tensor on the specified device using the content and shape of the input.
|
||||
@ -352,7 +351,8 @@ impl Tensor {
|
||||
return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
|
||||
}
|
||||
let storage = device.storage_owned(data)?;
|
||||
Ok(from_storage(storage, shape, None, is_variable))
|
||||
let none = BackpropOp::none();
|
||||
Ok(from_storage(storage, shape, none, is_variable))
|
||||
}
|
||||
|
||||
/// Creates a new tensor initialized with values from the input vector. The number of elements
|
||||
@ -500,26 +500,14 @@ impl Tensor {
|
||||
/// ```
|
||||
pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
|
||||
let storage = self.storage().affine(self.layout(), mul, add)?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Affine {
|
||||
arg: self.clone(),
|
||||
mul,
|
||||
add,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let op = BackpropOp::new1(self, |arg| Op::Affine { arg, mul, add });
|
||||
Ok(from_storage(storage, self.shape(), op, false))
|
||||
}
|
||||
|
||||
/// Applies the Exponential Linear Unit (ELU) function on each element of the input tensor.
|
||||
pub fn elu(&self, alpha: f64) -> Result<Self> {
|
||||
let storage = self.storage().elu(self.layout(), alpha)?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Elu(self.clone(), alpha))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let op = BackpropOp::new1(self, |t| Op::Elu(t, alpha));
|
||||
Ok(from_storage(storage, self.shape(), op, false))
|
||||
}
|
||||
|
||||
@ -554,11 +542,7 @@ impl Tensor {
|
||||
if start == 0 && dims[dim] == len {
|
||||
Ok(self.clone())
|
||||
} else {
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Narrow(self.clone(), dim, start, len))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let op = BackpropOp::new1(self, |t| Op::Narrow(t, dim, start, len));
|
||||
let layout = self.layout().narrow(dim, start, len)?;
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
@ -602,11 +586,7 @@ impl Tensor {
|
||||
let mut storage = self.storage().unary_impl::<crate::op::Exp>(self.layout())?;
|
||||
// The resulting storage is contiguous.
|
||||
storage.divide_by_sum_over_dim(shape, dim)?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Softmax(self.clone(), dim))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let op = BackpropOp::new1(self, |arg| Op::Softmax(arg, dim));
|
||||
Ok(from_storage(storage, shape.clone(), op, false))
|
||||
}
|
||||
}
|
||||
@ -638,11 +618,7 @@ impl Tensor {
|
||||
let storage = self.storage().reduce_op(op, self.layout(), &[dim])?;
|
||||
let mut dims = self.dims().to_vec();
|
||||
dims[dim] = 1;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Reduce(self.clone(), op, dims.to_vec()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let op = BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()));
|
||||
let res = from_storage(storage, dims, op, false);
|
||||
if keepdim {
|
||||
Ok(res)
|
||||
@ -660,11 +636,7 @@ impl Tensor {
|
||||
for &sum_dim in sum_dims.iter() {
|
||||
dims[sum_dim] = 1
|
||||
}
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Reduce(self.clone(), ReduceOp::Sum, dims.to_vec()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let op = BackpropOp::new1(self, |a| Op::Reduce(a, ReduceOp::Sum, dims.to_vec()));
|
||||
let sum = from_storage(storage, dims, op, false);
|
||||
if keepdim {
|
||||
Ok(sum)
|
||||
@ -738,11 +710,7 @@ impl Tensor {
|
||||
let storage = self
|
||||
.storage()
|
||||
.cmp(op, &rhs.storage(), self.layout(), rhs.layout())?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Cmp(self.clone(), op))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let op = BackpropOp::new1(self, |a| Op::Cmp(a, op));
|
||||
Ok(from_storage(storage, shape.dims(), op, false))
|
||||
}
|
||||
|
||||
@ -807,16 +775,12 @@ impl Tensor {
|
||||
let storage =
|
||||
self.storage()
|
||||
.conv1d(self.layout(), &kernel.storage(), kernel.layout(), ¶ms)?;
|
||||
let op = if self.track_op() || kernel.track_op() {
|
||||
Some(Op::Conv1D {
|
||||
arg: self.clone(),
|
||||
kernel: kernel.clone(),
|
||||
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D {
|
||||
arg,
|
||||
kernel,
|
||||
padding,
|
||||
stride,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
});
|
||||
let out_dims = params.out_dims();
|
||||
Ok(from_storage(storage, out_dims, op, false))
|
||||
}
|
||||
@ -867,11 +831,7 @@ impl Tensor {
|
||||
self.layout(),
|
||||
rhs.layout(),
|
||||
)?;
|
||||
let op = if self.track_op() || rhs.track_op() {
|
||||
Some(Op::Matmul(self.clone(), rhs.clone()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let op = BackpropOp::new2(self, rhs, Op::Matmul);
|
||||
Ok(from_storage(storage, c_shape, op, false))
|
||||
}
|
||||
|
||||
@ -888,15 +848,7 @@ impl Tensor {
|
||||
&on_false.storage(),
|
||||
on_false.layout(),
|
||||
)?;
|
||||
let op = if self.track_op() || on_true.track_op() || on_false.track_op() {
|
||||
Some(Op::WhereCond(
|
||||
self.clone(),
|
||||
on_true.clone(),
|
||||
on_false.clone(),
|
||||
))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let op = BackpropOp::new3(self, on_true, on_false, Op::WhereCond);
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
@ -937,11 +889,7 @@ impl Tensor {
|
||||
.storage()
|
||||
.embedding(ids.layout(), &rhs.storage(), rhs.layout())?;
|
||||
let shape: Shape = (seq_len, hidden_size).into();
|
||||
let op = if ids.track_op() || rhs.track_op() {
|
||||
Some(Op::Embedding(ids.clone(), rhs.clone()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let op = BackpropOp::new2(ids, rhs, Op::Embedding);
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
@ -983,16 +931,9 @@ impl Tensor {
|
||||
source.layout(),
|
||||
dim,
|
||||
)?;
|
||||
let op = if indexes.track_op() || self.track_op() {
|
||||
Some(Op::ScatterAdd(
|
||||
self.clone(),
|
||||
indexes.clone(),
|
||||
source.clone(),
|
||||
dim,
|
||||
))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
|
||||
Op::ScatterAdd(t1, t2, t3, dim)
|
||||
});
|
||||
Ok(from_storage(storage, self.shape(), op, false))
|
||||
}
|
||||
|
||||
@ -1038,16 +979,9 @@ impl Tensor {
|
||||
source.layout(),
|
||||
dim,
|
||||
)?;
|
||||
let op = if indexes.track_op() || self.track_op() {
|
||||
Some(Op::IndexAdd(
|
||||
self.clone(),
|
||||
indexes.clone(),
|
||||
source.clone(),
|
||||
dim,
|
||||
))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
|
||||
Op::IndexAdd(t1, t2, t3, dim)
|
||||
});
|
||||
Ok(from_storage(storage, self.shape(), op, false))
|
||||
}
|
||||
|
||||
@ -1077,11 +1011,7 @@ impl Tensor {
|
||||
let storage =
|
||||
self.storage()
|
||||
.gather(self.layout(), &indexes.storage(), indexes.layout(), dim)?;
|
||||
let op = if indexes.track_op() || self.track_op() {
|
||||
Some(Op::Gather(self.clone(), indexes.clone(), dim))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let op = BackpropOp::new2(self, indexes, |t1, t2| Op::Gather(t1, t2, dim));
|
||||
Ok(from_storage(storage, indexes.shape(), op, false))
|
||||
}
|
||||
|
||||
@ -1104,11 +1034,7 @@ impl Tensor {
|
||||
)?;
|
||||
let mut dims = self.dims().to_vec();
|
||||
dims[dim] = indexes_len;
|
||||
let op = if indexes.track_op() || self.track_op() {
|
||||
Some(Op::IndexSelect(self.clone(), indexes.clone(), dim))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let op = BackpropOp::new2(self, indexes, |t1, t2| Op::IndexSelect(t1, t2, dim));
|
||||
Ok(from_storage(storage, dims, op, false))
|
||||
}
|
||||
|
||||
@ -1404,11 +1330,7 @@ impl Tensor {
|
||||
pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Tensor> {
|
||||
let dim1 = dim1.to_index(self.shape(), "transpose")?;
|
||||
let dim2 = dim2.to_index(self.shape(), "transpose")?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Transpose(self.clone(), dim1, dim2))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let op = BackpropOp::new1(self, |t| Op::Transpose(t, dim1, dim2));
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage: self.storage.clone(),
|
||||
@ -1434,11 +1356,7 @@ impl Tensor {
|
||||
/// Compared to clone, this copies the actual storage but may fail because of running out of
|
||||
/// memory.
|
||||
pub fn copy(&self) -> Result<Tensor> {
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Copy(self.clone()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let op = BackpropOp::new1(self, Op::Copy);
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage: Arc::new(RwLock::new(self.storage().try_clone(self.layout())?)),
|
||||
@ -1458,7 +1376,7 @@ impl Tensor {
|
||||
id: TensorId::new(),
|
||||
storage: self.storage.clone(),
|
||||
layout: self.layout.clone(),
|
||||
op: None,
|
||||
op: BackpropOp::none(),
|
||||
is_variable: false,
|
||||
dtype: self.dtype,
|
||||
device: self.device.clone(),
|
||||
@ -1484,11 +1402,7 @@ impl Tensor {
|
||||
}
|
||||
(Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
|
||||
};
|
||||
let op = if self.track_op() {
|
||||
Some(Op::ToDevice(self.clone()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let op = BackpropOp::new1(self, Op::ToDevice);
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage: Arc::new(RwLock::new(storage)),
|
||||
@ -1519,16 +1433,11 @@ impl Tensor {
|
||||
/// any value, the dimension `t_a` must be equal to `i_a` if `i_a` is different from 1. If
|
||||
/// `i_a` is equal to 1, any value can be used.
|
||||
pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Broadcast(self.clone()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage: self.storage.clone(),
|
||||
layout: self.layout.broadcast_as(shape)?,
|
||||
op,
|
||||
op: BackpropOp::new1(self, Op::Broadcast),
|
||||
is_variable: false,
|
||||
dtype: self.dtype,
|
||||
device: self.device.clone(),
|
||||
@ -1557,11 +1466,7 @@ impl Tensor {
|
||||
} else {
|
||||
let shape = self.shape();
|
||||
let storage = self.storage().to_dtype(self.layout(), dtype)?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::ToDType(self.clone()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let op = BackpropOp::new1(self, Op::ToDType);
|
||||
Ok(from_storage(storage, shape.clone(), op, false))
|
||||
}
|
||||
}
|
||||
@ -1576,11 +1481,7 @@ impl Tensor {
|
||||
let mut storage = self.device().zeros(shape, self.dtype())?;
|
||||
self.storage()
|
||||
.copy_strided_src(&mut storage, 0, self.layout())?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Copy(self.clone()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let op = BackpropOp::new1(self, Op::Copy);
|
||||
Ok(from_storage(storage, shape.clone(), op, false))
|
||||
}
|
||||
}
|
||||
@ -1612,11 +1513,7 @@ impl Tensor {
|
||||
}
|
||||
.bt());
|
||||
}
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Reshape(self.clone()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let op = BackpropOp::new1(self, Op::Reshape);
|
||||
if self.is_contiguous() {
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
@ -1820,12 +1717,7 @@ impl Tensor {
|
||||
offsets.push(next_offset);
|
||||
}
|
||||
let shape = Shape::from(cat_dims);
|
||||
let op = if args.iter().any(|arg| arg.as_ref().track_op()) {
|
||||
let args: Vec<Tensor> = args.iter().map(|arg| arg.as_ref().clone()).collect();
|
||||
Some(Op::Cat(args, 0))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let op = BackpropOp::new(args, |args| Op::Cat(args, 0));
|
||||
let mut storage = device.zeros(&shape, dtype)?;
|
||||
for (arg, &offset) in args.iter().zip(offsets.iter()) {
|
||||
let arg = arg.as_ref();
|
||||
@ -1865,11 +1757,7 @@ impl Tensor {
|
||||
let (storage, shape) = self
|
||||
.storage()
|
||||
.custom_op1(self.layout(), c.as_ref().as_ref())?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::CustomOp1(self.clone(), c))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
@ -1885,11 +1773,7 @@ impl Tensor {
|
||||
rhs.layout(),
|
||||
c.as_ref().as_ref(),
|
||||
)?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::CustomOp2(self.clone(), rhs.clone(), c))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let op = BackpropOp::new2(self, rhs, |t1, t2| Op::CustomOp2(t1, t2, c.clone()));
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
@ -1907,11 +1791,9 @@ impl Tensor {
|
||||
t3.layout(),
|
||||
c.as_ref().as_ref(),
|
||||
)?;
|
||||
let op = if self.track_op() {
|
||||
Some(Op::CustomOp3(self.clone(), t2.clone(), t3.clone(), c))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let op = BackpropOp::new3(self, t2, t3, |t1, t2, t3| {
|
||||
Op::CustomOp3(t1, t2, t3, c.clone())
|
||||
});
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user