Add an abstract backprop op type (#240)

* Start adding the backprop op type.

* More backprop ops.

* Finish the backprop op.
Author: Laurent Mazare
Date: 2023-07-25 14:07:40 +01:00
Committed by: GitHub
Parent: be9c26180c
Commit: c97d51243c

2 changed files with 120 additions and 178 deletions

File 1 of 2

@@ -58,7 +58,7 @@ pub enum UnaryOp {
 }
 
 #[derive(Clone)]
-pub(crate) enum Op {
+pub enum Op {
     Binary(Tensor, Tensor, BinaryOp),
     Unary(Tensor, UnaryOp),
     Cmp(Tensor, CmpOp),
@@ -512,3 +512,63 @@ impl UnaryOpT for Relu {
         v
     }
 }
+
+/// `BackpropOp` is a wrapper around `Option<Op>`. The main goal is to ensure that dependencies
+/// are properly checked when creating a new value.
+#[derive(Clone)]
+pub struct BackpropOp(Option<Op>);
+
+impl BackpropOp {
+    pub(crate) fn none() -> Self {
+        BackpropOp(None)
+    }
+
+    pub(crate) fn new1(arg: &Tensor, f: impl Fn(Tensor) -> Op) -> Self {
+        let op = if arg.track_op() {
+            Some(f(arg.clone()))
+        } else {
+            None
+        };
+        Self(op)
+    }
+
+    pub(crate) fn new2(arg1: &Tensor, arg2: &Tensor, f: impl Fn(Tensor, Tensor) -> Op) -> Self {
+        let op = if arg1.track_op() || arg2.track_op() {
+            Some(f(arg1.clone(), arg2.clone()))
+        } else {
+            None
+        };
+        Self(op)
+    }
+
+    pub(crate) fn new3(
+        arg1: &Tensor,
+        arg2: &Tensor,
+        arg3: &Tensor,
+        f: impl Fn(Tensor, Tensor, Tensor) -> Op,
+    ) -> Self {
+        let op = if arg1.track_op() || arg2.track_op() || arg3.track_op() {
+            Some(f(arg1.clone(), arg2.clone(), arg3.clone()))
+        } else {
+            None
+        };
+        Self(op)
+    }
+
+    pub(crate) fn new<A: AsRef<Tensor>>(args: &[A], f: impl Fn(Vec<Tensor>) -> Op) -> Self {
+        let op = if args.iter().any(|arg| arg.as_ref().track_op()) {
+            let args: Vec<Tensor> = args.iter().map(|arg| arg.as_ref().clone()).collect();
+            Some(f(args))
+        } else {
+            None
+        };
+        Self(op)
+    }
+}
+
+impl std::ops::Deref for BackpropOp {
+    type Target = Option<Op>;
+    fn deref(&self) -> &Self::Target {
+        &self.0
+    }
+}
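
To make the intent of the wrapper and its `Deref` impl concrete, here is a small self-contained sketch of the same pattern using simplified stand-in types (`Node` instead of `Tensor`, a toy `Op`, and a `requires_grad` flag instead of `track_op`). It illustrates the design only; it is not the actual candle API.

// Standalone sketch of the BackpropOp pattern, with hypothetical stand-in types.
use std::ops::Deref;
use std::rc::Rc;

#[derive(Clone, Debug)]
enum Op {
    Neg(Node),
}

// `Node` stands in for `Tensor`: a cheap, reference-counted handle.
#[derive(Clone, Debug)]
struct Node(Rc<NodeInner>);

#[derive(Debug)]
struct NodeInner {
    requires_grad: bool,
    op: BackpropOp,
}

#[derive(Clone, Debug)]
struct BackpropOp(Option<Op>);

impl BackpropOp {
    fn none() -> Self {
        BackpropOp(None)
    }

    // Record the parent only when it participates in gradient tracking,
    // mirroring the `track_op` check performed by the constructors above.
    fn new1(arg: &Node, f: impl Fn(Node) -> Op) -> Self {
        if arg.0.requires_grad {
            BackpropOp(Some(f(arg.clone())))
        } else {
            BackpropOp(None)
        }
    }
}

// Deref to `Option<Op>` so callers keep the familiar `is_some()` /
// `as_ref()` / pattern-matching idioms on the wrapper.
impl Deref for BackpropOp {
    type Target = Option<Op>;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

fn main() {
    let x = Node(Rc::new(NodeInner { requires_grad: true, op: BackpropOp::none() }));
    let y = Node(Rc::new(NodeInner {
        requires_grad: true,
        op: BackpropOp::new1(&x, Op::Neg),
    }));
    // Thanks to Deref, the wrapper reads like a plain Option in backprop code.
    if let Some(op) = y.0.op.deref() {
        println!("y was produced by {op:?}");
    }
    assert!(x.0.op.is_none());
}

Passing `Op::Neg` directly as the constructor works because tuple-variant constructors implement the `Fn` traits, which is also what lets the diff below write calls such as `BackpropOp::new2(self, rhs, Op::Matmul)` without an explicit closure.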

File 2 of 2

@@ -1,5 +1,7 @@
 use crate::backend::{BackendDevice, BackendStorage};
-use crate::op::{BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp};
+use crate::op::{
+    BackpropOp, BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp,
+};
 use crate::shape::{Dim, Dims};
 use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
 use std::sync::{Arc, RwLock};
@@ -33,7 +35,7 @@ pub struct Tensor_ {
     // that's tricky to encode in the current setup.
     storage: Arc<RwLock<Storage>>,
     layout: Layout,
-    op: Option<Op>,
+    op: BackpropOp,
     is_variable: bool,
     dtype: DType,
     device: Device,
@@ -79,11 +81,7 @@ macro_rules! unary_op {
             let storage = self
                 .storage()
                 .unary_impl::<crate::op::$op_name>(self.layout())?;
-            let op = if self.track_op() {
-                Some(Op::Unary(self.clone(), UnaryOp::$op_name))
-            } else {
-                None
-            };
+            let op = BackpropOp::new1(self, |s| Op::Unary(s, UnaryOp::$op_name));
             Ok(from_storage(storage, shape.clone(), op, false))
         }
     };
@@ -98,11 +96,7 @@ macro_rules! binary_op {
                 self.layout(),
                 rhs.layout(),
             )?;
-            let op = if self.track_op() || rhs.track_op() {
-                Some(Op::Binary(self.clone(), rhs.clone(), BinaryOp::$op_name))
-            } else {
-                None
-            };
+            let op = BackpropOp::new2(self, rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name));
             Ok(from_storage(storage, shape.clone(), op, false))
         }
     };
@@ -131,7 +125,7 @@ macro_rules! broadcast_binary_op {
 fn from_storage<S: Into<Shape>>(
     storage: Storage,
     shape: S,
-    op: Option<Op>,
+    op: BackpropOp,
     is_variable: bool,
 ) -> Tensor {
     let dtype = storage.dtype();
@@ -155,13 +149,14 @@ impl Tensor {
         device: &Device,
         is_variable: bool,
     ) -> Result<Self> {
+        let none = BackpropOp::none();
         if is_variable {
             let shape = shape.into();
             let storage = device.ones(&shape, dtype)?;
-            Ok(from_storage(storage, shape, None, is_variable))
+            Ok(from_storage(storage, shape, none, is_variable))
         } else {
             let storage = device.ones(&crate::shape::SCALAR, dtype)?;
-            from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
+            from_storage(storage, crate::shape::SCALAR, none, is_variable).broadcast_as(shape)
         }
     }
@@ -199,13 +194,14 @@ impl Tensor {
         device: &Device,
         is_variable: bool,
     ) -> Result<Self> {
+        let none = BackpropOp::none();
         if is_variable {
             let shape = shape.into();
             let storage = device.zeros(&shape, dtype)?;
-            Ok(from_storage(storage, shape, None, is_variable))
+            Ok(from_storage(storage, shape, none, is_variable))
         } else {
             let storage = device.zeros(&crate::shape::SCALAR, dtype)?;
-            from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
+            from_storage(storage, crate::shape::SCALAR, none, is_variable).broadcast_as(shape)
         }
     }
@@ -246,7 +242,8 @@ impl Tensor {
     ) -> Result<Self> {
         let s = s.into();
         let storage = device.rand_uniform(&s, dtype, lo, up)?;
-        Ok(from_storage(storage, s, None, is_variable))
+        let none = BackpropOp::none();
+        Ok(from_storage(storage, s, none, is_variable))
     }
 
     /// Creates a new tensor initialized with values sampled uniformly between `lo` and `up`.
@@ -270,7 +267,8 @@ impl Tensor {
     ) -> Result<Self> {
         let s = s.into();
         let storage = device.rand_normal(&s, dtype, mean, std)?;
-        Ok(from_storage(storage, s, None, is_variable))
+        let none = BackpropOp::none();
+        Ok(from_storage(storage, s, none, is_variable))
     }
 
     /// Creates a new tensor initialized with values sampled from a normal distribution with the
@@ -297,7 +295,8 @@ impl Tensor {
             return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
         }
         let storage = device.storage(array)?;
-        Ok(from_storage(storage, shape, None, is_variable))
+        let none = BackpropOp::none();
+        Ok(from_storage(storage, shape, none, is_variable))
     }
 
     /// Creates a new tensor on the specified device using the content and shape of the input.
@@ -352,7 +351,8 @@ impl Tensor {
             return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
         }
         let storage = device.storage_owned(data)?;
-        Ok(from_storage(storage, shape, None, is_variable))
+        let none = BackpropOp::none();
+        Ok(from_storage(storage, shape, none, is_variable))
     }
 
     /// Creates a new tensor initialized with values from the input vector. The number of elements
@@ -500,26 +500,14 @@ impl Tensor {
     /// ```
     pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
         let storage = self.storage().affine(self.layout(), mul, add)?;
-        let op = if self.track_op() {
-            Some(Op::Affine {
-                arg: self.clone(),
-                mul,
-                add,
-            })
-        } else {
-            None
-        };
+        let op = BackpropOp::new1(self, |arg| Op::Affine { arg, mul, add });
         Ok(from_storage(storage, self.shape(), op, false))
     }
 
     /// Applies the Exponential Linear Unit (ELU) function on each element of the input tensor.
     pub fn elu(&self, alpha: f64) -> Result<Self> {
         let storage = self.storage().elu(self.layout(), alpha)?;
-        let op = if self.track_op() {
-            Some(Op::Elu(self.clone(), alpha))
-        } else {
-            None
-        };
+        let op = BackpropOp::new1(self, |t| Op::Elu(t, alpha));
         Ok(from_storage(storage, self.shape(), op, false))
     }
@@ -554,11 +542,7 @@ impl Tensor {
         if start == 0 && dims[dim] == len {
             Ok(self.clone())
         } else {
-            let op = if self.track_op() {
-                Some(Op::Narrow(self.clone(), dim, start, len))
-            } else {
-                None
-            };
+            let op = BackpropOp::new1(self, |t| Op::Narrow(t, dim, start, len));
             let layout = self.layout().narrow(dim, start, len)?;
             let tensor_ = Tensor_ {
                 id: TensorId::new(),
@@ -602,11 +586,7 @@ impl Tensor {
             let mut storage = self.storage().unary_impl::<crate::op::Exp>(self.layout())?;
             // The resulting storage is contiguous.
             storage.divide_by_sum_over_dim(shape, dim)?;
-            let op = if self.track_op() {
-                Some(Op::Softmax(self.clone(), dim))
-            } else {
-                None
-            };
+            let op = BackpropOp::new1(self, |arg| Op::Softmax(arg, dim));
             Ok(from_storage(storage, shape.clone(), op, false))
         }
     }
@@ -638,11 +618,7 @@ impl Tensor {
         let storage = self.storage().reduce_op(op, self.layout(), &[dim])?;
         let mut dims = self.dims().to_vec();
         dims[dim] = 1;
-        let op = if self.track_op() {
-            Some(Op::Reduce(self.clone(), op, dims.to_vec()))
-        } else {
-            None
-        };
+        let op = BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()));
         let res = from_storage(storage, dims, op, false);
         if keepdim {
             Ok(res)
@@ -660,11 +636,7 @@ impl Tensor {
         for &sum_dim in sum_dims.iter() {
             dims[sum_dim] = 1
         }
-        let op = if self.track_op() {
-            Some(Op::Reduce(self.clone(), ReduceOp::Sum, dims.to_vec()))
-        } else {
-            None
-        };
+        let op = BackpropOp::new1(self, |a| Op::Reduce(a, ReduceOp::Sum, dims.to_vec()));
         let sum = from_storage(storage, dims, op, false);
         if keepdim {
             Ok(sum)
@@ -738,11 +710,7 @@ impl Tensor {
         let storage = self
             .storage()
             .cmp(op, &rhs.storage(), self.layout(), rhs.layout())?;
-        let op = if self.track_op() {
-            Some(Op::Cmp(self.clone(), op))
-        } else {
-            None
-        };
+        let op = BackpropOp::new1(self, |a| Op::Cmp(a, op));
         Ok(from_storage(storage, shape.dims(), op, false))
     }
@@ -807,16 +775,12 @@ impl Tensor {
         let storage =
             self.storage()
                 .conv1d(self.layout(), &kernel.storage(), kernel.layout(), &params)?;
-        let op = if self.track_op() || kernel.track_op() {
-            Some(Op::Conv1D {
-                arg: self.clone(),
-                kernel: kernel.clone(),
-                padding,
-                stride,
-            })
-        } else {
-            None
-        };
+        let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D {
+            arg,
+            kernel,
+            padding,
+            stride,
+        });
         let out_dims = params.out_dims();
         Ok(from_storage(storage, out_dims, op, false))
     }
@@ -867,11 +831,7 @@ impl Tensor {
             self.layout(),
             rhs.layout(),
         )?;
-        let op = if self.track_op() || rhs.track_op() {
-            Some(Op::Matmul(self.clone(), rhs.clone()))
-        } else {
-            None
-        };
+        let op = BackpropOp::new2(self, rhs, Op::Matmul);
         Ok(from_storage(storage, c_shape, op, false))
     }
@@ -888,15 +848,7 @@ impl Tensor {
             &on_false.storage(),
             on_false.layout(),
         )?;
-        let op = if self.track_op() || on_true.track_op() || on_false.track_op() {
-            Some(Op::WhereCond(
-                self.clone(),
-                on_true.clone(),
-                on_false.clone(),
-            ))
-        } else {
-            None
-        };
+        let op = BackpropOp::new3(self, on_true, on_false, Op::WhereCond);
         Ok(from_storage(storage, shape, op, false))
     }
@@ -937,11 +889,7 @@ impl Tensor {
             .storage()
             .embedding(ids.layout(), &rhs.storage(), rhs.layout())?;
         let shape: Shape = (seq_len, hidden_size).into();
-        let op = if ids.track_op() || rhs.track_op() {
-            Some(Op::Embedding(ids.clone(), rhs.clone()))
-        } else {
-            None
-        };
+        let op = BackpropOp::new2(ids, rhs, Op::Embedding);
         Ok(from_storage(storage, shape, op, false))
     }
@@ -983,16 +931,9 @@ impl Tensor {
             source.layout(),
             dim,
         )?;
-        let op = if indexes.track_op() || self.track_op() {
-            Some(Op::ScatterAdd(
-                self.clone(),
-                indexes.clone(),
-                source.clone(),
-                dim,
-            ))
-        } else {
-            None
-        };
+        let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
+            Op::ScatterAdd(t1, t2, t3, dim)
+        });
         Ok(from_storage(storage, self.shape(), op, false))
     }
@@ -1038,16 +979,9 @@ impl Tensor {
             source.layout(),
             dim,
         )?;
-        let op = if indexes.track_op() || self.track_op() {
-            Some(Op::IndexAdd(
-                self.clone(),
-                indexes.clone(),
-                source.clone(),
-                dim,
-            ))
-        } else {
-            None
-        };
+        let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
+            Op::IndexAdd(t1, t2, t3, dim)
+        });
         Ok(from_storage(storage, self.shape(), op, false))
     }
@@ -1077,11 +1011,7 @@ impl Tensor {
         let storage =
             self.storage()
                 .gather(self.layout(), &indexes.storage(), indexes.layout(), dim)?;
-        let op = if indexes.track_op() || self.track_op() {
-            Some(Op::Gather(self.clone(), indexes.clone(), dim))
-        } else {
-            None
-        };
+        let op = BackpropOp::new2(self, indexes, |t1, t2| Op::Gather(t1, t2, dim));
        Ok(from_storage(storage, indexes.shape(), op, false))
     }
@@ -1104,11 +1034,7 @@ impl Tensor {
         )?;
         let mut dims = self.dims().to_vec();
         dims[dim] = indexes_len;
-        let op = if indexes.track_op() || self.track_op() {
-            Some(Op::IndexSelect(self.clone(), indexes.clone(), dim))
-        } else {
-            None
-        };
+        let op = BackpropOp::new2(self, indexes, |t1, t2| Op::IndexSelect(t1, t2, dim));
         Ok(from_storage(storage, dims, op, false))
     }
@@ -1404,11 +1330,7 @@ impl Tensor {
     pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Tensor> {
         let dim1 = dim1.to_index(self.shape(), "transpose")?;
         let dim2 = dim2.to_index(self.shape(), "transpose")?;
-        let op = if self.track_op() {
-            Some(Op::Transpose(self.clone(), dim1, dim2))
-        } else {
-            None
-        };
+        let op = BackpropOp::new1(self, |t| Op::Transpose(t, dim1, dim2));
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage: self.storage.clone(),
@@ -1434,11 +1356,7 @@ impl Tensor {
     /// Compared to clone, this copies the actual storage but may fail because of running out of
     /// memory.
     pub fn copy(&self) -> Result<Tensor> {
-        let op = if self.track_op() {
-            Some(Op::Copy(self.clone()))
-        } else {
-            None
-        };
+        let op = BackpropOp::new1(self, Op::Copy);
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage: Arc::new(RwLock::new(self.storage().try_clone(self.layout())?)),
@@ -1458,7 +1376,7 @@ impl Tensor {
             id: TensorId::new(),
             storage: self.storage.clone(),
             layout: self.layout.clone(),
-            op: None,
+            op: BackpropOp::none(),
             is_variable: false,
             dtype: self.dtype,
             device: self.device.clone(),
@@ -1484,11 +1402,7 @@ impl Tensor {
             }
             (Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
         };
-        let op = if self.track_op() {
-            Some(Op::ToDevice(self.clone()))
-        } else {
-            None
-        };
+        let op = BackpropOp::new1(self, Op::ToDevice);
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage: Arc::new(RwLock::new(storage)),
@@ -1519,16 +1433,11 @@ impl Tensor {
     /// any value, the dimension `t_a` must be equal to `i_a` if `i_a` is different from 1. If
     /// `i_a` is equal to 1, any value can be used.
     pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
-        let op = if self.track_op() {
-            Some(Op::Broadcast(self.clone()))
-        } else {
-            None
-        };
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage: self.storage.clone(),
             layout: self.layout.broadcast_as(shape)?,
-            op,
+            op: BackpropOp::new1(self, Op::Broadcast),
             is_variable: false,
             dtype: self.dtype,
             device: self.device.clone(),
@@ -1557,11 +1466,7 @@ impl Tensor {
         } else {
             let shape = self.shape();
             let storage = self.storage().to_dtype(self.layout(), dtype)?;
-            let op = if self.track_op() {
-                Some(Op::ToDType(self.clone()))
-            } else {
-                None
-            };
+            let op = BackpropOp::new1(self, Op::ToDType);
             Ok(from_storage(storage, shape.clone(), op, false))
         }
     }
@@ -1576,11 +1481,7 @@ impl Tensor {
             let mut storage = self.device().zeros(shape, self.dtype())?;
             self.storage()
                 .copy_strided_src(&mut storage, 0, self.layout())?;
-            let op = if self.track_op() {
-                Some(Op::Copy(self.clone()))
-            } else {
-                None
-            };
+            let op = BackpropOp::new1(self, Op::Copy);
             Ok(from_storage(storage, shape.clone(), op, false))
         }
     }
@@ -1612,11 +1513,7 @@ impl Tensor {
             }
             .bt());
         }
-        let op = if self.track_op() {
-            Some(Op::Reshape(self.clone()))
-        } else {
-            None
-        };
+        let op = BackpropOp::new1(self, Op::Reshape);
         if self.is_contiguous() {
             let tensor_ = Tensor_ {
                 id: TensorId::new(),
@@ -1820,12 +1717,7 @@ impl Tensor {
             offsets.push(next_offset);
         }
         let shape = Shape::from(cat_dims);
-        let op = if args.iter().any(|arg| arg.as_ref().track_op()) {
-            let args: Vec<Tensor> = args.iter().map(|arg| arg.as_ref().clone()).collect();
-            Some(Op::Cat(args, 0))
-        } else {
-            None
-        };
+        let op = BackpropOp::new(args, |args| Op::Cat(args, 0));
         let mut storage = device.zeros(&shape, dtype)?;
         for (arg, &offset) in args.iter().zip(offsets.iter()) {
             let arg = arg.as_ref();
@@ -1865,11 +1757,7 @@ impl Tensor {
         let (storage, shape) = self
             .storage()
             .custom_op1(self.layout(), c.as_ref().as_ref())?;
-        let op = if self.track_op() {
-            Some(Op::CustomOp1(self.clone(), c))
-        } else {
-            None
-        };
+        let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
         Ok(from_storage(storage, shape, op, false))
     }
@@ -1885,11 +1773,7 @@ impl Tensor {
             rhs.layout(),
             c.as_ref().as_ref(),
         )?;
-        let op = if self.track_op() {
-            Some(Op::CustomOp2(self.clone(), rhs.clone(), c))
-        } else {
-            None
-        };
+        let op = BackpropOp::new2(self, rhs, |t1, t2| Op::CustomOp2(t1, t2, c.clone()));
         Ok(from_storage(storage, shape, op, false))
     }
@@ -1907,11 +1791,9 @@ impl Tensor {
             t3.layout(),
             c.as_ref().as_ref(),
         )?;
-        let op = if self.track_op() {
-            Some(Op::CustomOp3(self.clone(), t2.clone(), t3.clone(), c))
-        } else {
-            None
-        };
+        let op = BackpropOp::new3(self, t2, t3, |t1, t2, t3| {
+            Op::CustomOp3(t1, t2, t3, c.clone())
+        });
         Ok(from_storage(storage, shape, op, false))
     }