Add an abstract backprop op type (#240)

* Start adding the backprop op type.

* More backprop ops.

* Finish the backprop op.
Laurent Mazare authored on 2023-07-25 14:07:40 +01:00; committed by GitHub
parent be9c26180c
commit c97d51243c
2 changed files with 120 additions and 178 deletions

@@ -58,7 +58,7 @@ pub enum UnaryOp {
}
#[derive(Clone)]
pub(crate) enum Op {
pub enum Op {
Binary(Tensor, Tensor, BinaryOp),
Unary(Tensor, UnaryOp),
Cmp(Tensor, CmpOp),
@@ -512,3 +512,63 @@ impl UnaryOpT for Relu {
v
}
}
/// `BackpropOp` is a wrapper around `Option<Op>`. The main goal is to ensure that dependencies are
/// properly checked when creating a new value.
#[derive(Clone)]
pub struct BackpropOp(Option<Op>);

impl BackpropOp {
    pub(crate) fn none() -> Self {
        BackpropOp(None)
    }

    pub(crate) fn new1(arg: &Tensor, f: impl Fn(Tensor) -> Op) -> Self {
        let op = if arg.track_op() {
            Some(f(arg.clone()))
        } else {
            None
        };
        Self(op)
    }

    pub(crate) fn new2(arg1: &Tensor, arg2: &Tensor, f: impl Fn(Tensor, Tensor) -> Op) -> Self {
        let op = if arg1.track_op() || arg2.track_op() {
            Some(f(arg1.clone(), arg2.clone()))
        } else {
            None
        };
        Self(op)
    }

    pub(crate) fn new3(
        arg1: &Tensor,
        arg2: &Tensor,
        arg3: &Tensor,
        f: impl Fn(Tensor, Tensor, Tensor) -> Op,
    ) -> Self {
        let op = if arg1.track_op() || arg2.track_op() || arg3.track_op() {
            Some(f(arg1.clone(), arg2.clone(), arg3.clone()))
        } else {
            None
        };
        Self(op)
    }

    pub(crate) fn new<A: AsRef<Tensor>>(args: &[A], f: impl Fn(Vec<Tensor>) -> Op) -> Self {
        let op = if args.iter().any(|arg| arg.as_ref().track_op()) {
            let args: Vec<Tensor> = args.iter().map(|arg| arg.as_ref().clone()).collect();
            Some(f(args))
        } else {
            None
        };
        Self(op)
    }
}

impl std::ops::Deref for BackpropOp {
    type Target = Option<Op>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
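
The helpers above centralize the `if track_op() { Some(op) } else { None }` pattern that the second file below previously repeated at every op constructor. The following stand-alone sketch is illustrative only: `MiniTensor`, `MiniOp`, and `MiniBackpropOp` are simplified stand-ins, not candle's actual types, but they show how the wrapper keeps the tracking check in one place.

// Illustrative sketch only: the `Mini*` names are hypothetical stand-ins, not candle's API.
#[derive(Clone, Debug)]
struct MiniTensor {
    is_variable: bool,
}

impl MiniTensor {
    // Only tensors that participate in gradient tracking need their op recorded.
    fn track_op(&self) -> bool {
        self.is_variable
    }
}

#[derive(Debug)]
enum MiniOp {
    Add(MiniTensor, MiniTensor),
    Neg(MiniTensor),
}

// Wrapping `Option<MiniOp>` keeps the "record only if some input is tracked" check in one place.
#[derive(Debug)]
struct MiniBackpropOp(Option<MiniOp>);

impl MiniBackpropOp {
    fn new1(arg: &MiniTensor, f: impl Fn(MiniTensor) -> MiniOp) -> Self {
        Self(arg.track_op().then(|| f(arg.clone())))
    }

    fn new2(a: &MiniTensor, b: &MiniTensor, f: impl Fn(MiniTensor, MiniTensor) -> MiniOp) -> Self {
        Self((a.track_op() || b.track_op()).then(|| f(a.clone(), b.clone())))
    }
}

fn main() {
    let x = MiniTensor { is_variable: true };
    let y = MiniTensor { is_variable: false };

    // One-liners at the call site instead of the repeated if/else blocks this commit replaces.
    let add_op = MiniBackpropOp::new2(&x, &y, |a, b| MiniOp::Add(a, b));
    let neg_op = MiniBackpropOp::new1(&y, MiniOp::Neg);

    assert!(add_op.0.is_some()); // x is tracked, so the op is recorded
    assert!(neg_op.0.is_none()); // y alone is untracked, so nothing is recorded
    println!("{add_op:?}\n{neg_op:?}");
}

In the actual commit, the `Deref` impl to `Option<Op>` presumably keeps read access unchanged for code that previously consumed the raw `op: Option<Op>` field.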

@@ -1,5 +1,7 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp};
use crate::op::{
BackpropOp, BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp,
};
use crate::shape::{Dim, Dims};
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::{Arc, RwLock};
@@ -33,7 +35,7 @@ pub struct Tensor_ {
// that's tricky to encode in the current setup.
storage: Arc<RwLock<Storage>>,
layout: Layout,
op: Option<Op>,
op: BackpropOp,
is_variable: bool,
dtype: DType,
device: Device,
@@ -79,11 +81,7 @@ macro_rules! unary_op {
let storage = self
.storage()
.unary_impl::<crate::op::$op_name>(self.layout())?;
let op = if self.track_op() {
Some(Op::Unary(self.clone(), UnaryOp::$op_name))
} else {
None
};
let op = BackpropOp::new1(self, |s| Op::Unary(s, UnaryOp::$op_name));
Ok(from_storage(storage, shape.clone(), op, false))
}
};
@@ -98,11 +96,7 @@ macro_rules! binary_op {
self.layout(),
rhs.layout(),
)?;
let op = if self.track_op() || rhs.track_op() {
Some(Op::Binary(self.clone(), rhs.clone(), BinaryOp::$op_name))
} else {
None
};
let op = BackpropOp::new2(self, rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name));
Ok(from_storage(storage, shape.clone(), op, false))
}
};
@@ -131,7 +125,7 @@ macro_rules! broadcast_binary_op {
fn from_storage<S: Into<Shape>>(
storage: Storage,
shape: S,
op: Option<Op>,
op: BackpropOp,
is_variable: bool,
) -> Tensor {
let dtype = storage.dtype();
@@ -155,13 +149,14 @@ impl Tensor {
device: &Device,
is_variable: bool,
) -> Result<Self> {
let none = BackpropOp::none();
if is_variable {
let shape = shape.into();
let storage = device.ones(&shape, dtype)?;
Ok(from_storage(storage, shape, None, is_variable))
Ok(from_storage(storage, shape, none, is_variable))
} else {
let storage = device.ones(&crate::shape::SCALAR, dtype)?;
from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
from_storage(storage, crate::shape::SCALAR, none, is_variable).broadcast_as(shape)
}
}
@@ -199,13 +194,14 @@ impl Tensor {
device: &Device,
is_variable: bool,
) -> Result<Self> {
let none = BackpropOp::none();
if is_variable {
let shape = shape.into();
let storage = device.zeros(&shape, dtype)?;
Ok(from_storage(storage, shape, None, is_variable))
Ok(from_storage(storage, shape, none, is_variable))
} else {
let storage = device.zeros(&crate::shape::SCALAR, dtype)?;
from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
from_storage(storage, crate::shape::SCALAR, none, is_variable).broadcast_as(shape)
}
}
@@ -246,7 +242,8 @@ impl Tensor {
) -> Result<Self> {
let s = s.into();
let storage = device.rand_uniform(&s, dtype, lo, up)?;
Ok(from_storage(storage, s, None, is_variable))
let none = BackpropOp::none();
Ok(from_storage(storage, s, none, is_variable))
}
/// Creates a new tensor initialized with values sampled uniformly between `lo` and `up`.
@@ -270,7 +267,8 @@ impl Tensor {
) -> Result<Self> {
let s = s.into();
let storage = device.rand_normal(&s, dtype, mean, std)?;
Ok(from_storage(storage, s, None, is_variable))
let none = BackpropOp::none();
Ok(from_storage(storage, s, none, is_variable))
}
/// Creates a new tensor initialized with values sampled from a normal distribution with the
@@ -297,7 +295,8 @@ impl Tensor {
return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
}
let storage = device.storage(array)?;
Ok(from_storage(storage, shape, None, is_variable))
let none = BackpropOp::none();
Ok(from_storage(storage, shape, none, is_variable))
}
/// Creates a new tensor on the specified device using the content and shape of the input.
@@ -352,7 +351,8 @@ impl Tensor {
return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
}
let storage = device.storage_owned(data)?;
Ok(from_storage(storage, shape, None, is_variable))
let none = BackpropOp::none();
Ok(from_storage(storage, shape, none, is_variable))
}
/// Creates a new tensor initialized with values from the input vector. The number of elements
@@ -500,26 +500,14 @@ impl Tensor {
/// ```
pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
let storage = self.storage().affine(self.layout(), mul, add)?;
let op = if self.track_op() {
Some(Op::Affine {
arg: self.clone(),
mul,
add,
})
} else {
None
};
let op = BackpropOp::new1(self, |arg| Op::Affine { arg, mul, add });
Ok(from_storage(storage, self.shape(), op, false))
}
/// Applies the Exponential Linear Unit (ELU) function on each element of the input tensor.
pub fn elu(&self, alpha: f64) -> Result<Self> {
let storage = self.storage().elu(self.layout(), alpha)?;
let op = if self.track_op() {
Some(Op::Elu(self.clone(), alpha))
} else {
None
};
let op = BackpropOp::new1(self, |t| Op::Elu(t, alpha));
Ok(from_storage(storage, self.shape(), op, false))
}
@@ -554,11 +542,7 @@ impl Tensor {
if start == 0 && dims[dim] == len {
Ok(self.clone())
} else {
let op = if self.track_op() {
Some(Op::Narrow(self.clone(), dim, start, len))
} else {
None
};
let op = BackpropOp::new1(self, |t| Op::Narrow(t, dim, start, len));
let layout = self.layout().narrow(dim, start, len)?;
let tensor_ = Tensor_ {
id: TensorId::new(),
@@ -602,11 +586,7 @@ impl Tensor {
let mut storage = self.storage().unary_impl::<crate::op::Exp>(self.layout())?;
// The resulting storage is contiguous.
storage.divide_by_sum_over_dim(shape, dim)?;
let op = if self.track_op() {
Some(Op::Softmax(self.clone(), dim))
} else {
None
};
let op = BackpropOp::new1(self, |arg| Op::Softmax(arg, dim));
Ok(from_storage(storage, shape.clone(), op, false))
}
}
@@ -638,11 +618,7 @@ impl Tensor {
let storage = self.storage().reduce_op(op, self.layout(), &[dim])?;
let mut dims = self.dims().to_vec();
dims[dim] = 1;
let op = if self.track_op() {
Some(Op::Reduce(self.clone(), op, dims.to_vec()))
} else {
None
};
let op = BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()));
let res = from_storage(storage, dims, op, false);
if keepdim {
Ok(res)
@@ -660,11 +636,7 @@ impl Tensor {
for &sum_dim in sum_dims.iter() {
dims[sum_dim] = 1
}
let op = if self.track_op() {
Some(Op::Reduce(self.clone(), ReduceOp::Sum, dims.to_vec()))
} else {
None
};
let op = BackpropOp::new1(self, |a| Op::Reduce(a, ReduceOp::Sum, dims.to_vec()));
let sum = from_storage(storage, dims, op, false);
if keepdim {
Ok(sum)
@@ -738,11 +710,7 @@ impl Tensor {
let storage = self
.storage()
.cmp(op, &rhs.storage(), self.layout(), rhs.layout())?;
let op = if self.track_op() {
Some(Op::Cmp(self.clone(), op))
} else {
None
};
let op = BackpropOp::new1(self, |a| Op::Cmp(a, op));
Ok(from_storage(storage, shape.dims(), op, false))
}
@@ -807,16 +775,12 @@ impl Tensor {
let storage =
self.storage()
.conv1d(self.layout(), &kernel.storage(), kernel.layout(), &params)?;
let op = if self.track_op() || kernel.track_op() {
Some(Op::Conv1D {
arg: self.clone(),
kernel: kernel.clone(),
padding,
stride,
})
} else {
None
};
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D {
arg,
kernel,
padding,
stride,
});
let out_dims = params.out_dims();
Ok(from_storage(storage, out_dims, op, false))
}
@@ -867,11 +831,7 @@ impl Tensor {
self.layout(),
rhs.layout(),
)?;
let op = if self.track_op() || rhs.track_op() {
Some(Op::Matmul(self.clone(), rhs.clone()))
} else {
None
};
let op = BackpropOp::new2(self, rhs, Op::Matmul);
Ok(from_storage(storage, c_shape, op, false))
}
@@ -888,15 +848,7 @@ impl Tensor {
&on_false.storage(),
on_false.layout(),
)?;
let op = if self.track_op() || on_true.track_op() || on_false.track_op() {
Some(Op::WhereCond(
self.clone(),
on_true.clone(),
on_false.clone(),
))
} else {
None
};
let op = BackpropOp::new3(self, on_true, on_false, Op::WhereCond);
Ok(from_storage(storage, shape, op, false))
}
@@ -937,11 +889,7 @@ impl Tensor {
.storage()
.embedding(ids.layout(), &rhs.storage(), rhs.layout())?;
let shape: Shape = (seq_len, hidden_size).into();
let op = if ids.track_op() || rhs.track_op() {
Some(Op::Embedding(ids.clone(), rhs.clone()))
} else {
None
};
let op = BackpropOp::new2(ids, rhs, Op::Embedding);
Ok(from_storage(storage, shape, op, false))
}
@@ -983,16 +931,9 @@ impl Tensor {
source.layout(),
dim,
)?;
let op = if indexes.track_op() || self.track_op() {
Some(Op::ScatterAdd(
self.clone(),
indexes.clone(),
source.clone(),
dim,
))
} else {
None
};
let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
Op::ScatterAdd(t1, t2, t3, dim)
});
Ok(from_storage(storage, self.shape(), op, false))
}
@@ -1038,16 +979,9 @@ impl Tensor {
source.layout(),
dim,
)?;
let op = if indexes.track_op() || self.track_op() {
Some(Op::IndexAdd(
self.clone(),
indexes.clone(),
source.clone(),
dim,
))
} else {
None
};
let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
Op::IndexAdd(t1, t2, t3, dim)
});
Ok(from_storage(storage, self.shape(), op, false))
}
@@ -1077,11 +1011,7 @@ impl Tensor {
let storage =
self.storage()
.gather(self.layout(), &indexes.storage(), indexes.layout(), dim)?;
let op = if indexes.track_op() || self.track_op() {
Some(Op::Gather(self.clone(), indexes.clone(), dim))
} else {
None
};
let op = BackpropOp::new2(self, indexes, |t1, t2| Op::Gather(t1, t2, dim));
Ok(from_storage(storage, indexes.shape(), op, false))
}
@@ -1104,11 +1034,7 @@ impl Tensor {
)?;
let mut dims = self.dims().to_vec();
dims[dim] = indexes_len;
let op = if indexes.track_op() || self.track_op() {
Some(Op::IndexSelect(self.clone(), indexes.clone(), dim))
} else {
None
};
let op = BackpropOp::new2(self, indexes, |t1, t2| Op::IndexSelect(t1, t2, dim));
Ok(from_storage(storage, dims, op, false))
}
@@ -1404,11 +1330,7 @@ impl Tensor {
pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Tensor> {
let dim1 = dim1.to_index(self.shape(), "transpose")?;
let dim2 = dim2.to_index(self.shape(), "transpose")?;
let op = if self.track_op() {
Some(Op::Transpose(self.clone(), dim1, dim2))
} else {
None
};
let op = BackpropOp::new1(self, |t| Op::Transpose(t, dim1, dim2));
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.clone(),
@@ -1434,11 +1356,7 @@ impl Tensor {
/// Compared to clone, this copies the actual storage but may fail because of running out of
/// memory.
pub fn copy(&self) -> Result<Tensor> {
let op = if self.track_op() {
Some(Op::Copy(self.clone()))
} else {
None
};
let op = BackpropOp::new1(self, Op::Copy);
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: Arc::new(RwLock::new(self.storage().try_clone(self.layout())?)),
@@ -1458,7 +1376,7 @@ impl Tensor {
id: TensorId::new(),
storage: self.storage.clone(),
layout: self.layout.clone(),
op: None,
op: BackpropOp::none(),
is_variable: false,
dtype: self.dtype,
device: self.device.clone(),
@@ -1484,11 +1402,7 @@ impl Tensor {
}
(Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
};
let op = if self.track_op() {
Some(Op::ToDevice(self.clone()))
} else {
None
};
let op = BackpropOp::new1(self, Op::ToDevice);
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: Arc::new(RwLock::new(storage)),
@@ -1519,16 +1433,11 @@ impl Tensor {
/// any value, the dimension `t_a` must be equal to `i_a` if `i_a` is different from 1. If
/// `i_a` is equal to 1, any value can be used.
pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
let op = if self.track_op() {
Some(Op::Broadcast(self.clone()))
} else {
None
};
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.clone(),
layout: self.layout.broadcast_as(shape)?,
op,
op: BackpropOp::new1(self, Op::Broadcast),
is_variable: false,
dtype: self.dtype,
device: self.device.clone(),
@@ -1557,11 +1466,7 @@ impl Tensor {
} else {
let shape = self.shape();
let storage = self.storage().to_dtype(self.layout(), dtype)?;
let op = if self.track_op() {
Some(Op::ToDType(self.clone()))
} else {
None
};
let op = BackpropOp::new1(self, Op::ToDType);
Ok(from_storage(storage, shape.clone(), op, false))
}
}
@@ -1576,11 +1481,7 @@ impl Tensor {
let mut storage = self.device().zeros(shape, self.dtype())?;
self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?;
let op = if self.track_op() {
Some(Op::Copy(self.clone()))
} else {
None
};
let op = BackpropOp::new1(self, Op::Copy);
Ok(from_storage(storage, shape.clone(), op, false))
}
}
@@ -1612,11 +1513,7 @@ impl Tensor {
}
.bt());
}
let op = if self.track_op() {
Some(Op::Reshape(self.clone()))
} else {
None
};
let op = BackpropOp::new1(self, Op::Reshape);
if self.is_contiguous() {
let tensor_ = Tensor_ {
id: TensorId::new(),
@@ -1820,12 +1717,7 @@ impl Tensor {
offsets.push(next_offset);
}
let shape = Shape::from(cat_dims);
let op = if args.iter().any(|arg| arg.as_ref().track_op()) {
let args: Vec<Tensor> = args.iter().map(|arg| arg.as_ref().clone()).collect();
Some(Op::Cat(args, 0))
} else {
None
};
let op = BackpropOp::new(args, |args| Op::Cat(args, 0));
let mut storage = device.zeros(&shape, dtype)?;
for (arg, &offset) in args.iter().zip(offsets.iter()) {
let arg = arg.as_ref();
@@ -1865,11 +1757,7 @@ impl Tensor {
let (storage, shape) = self
.storage()
.custom_op1(self.layout(), c.as_ref().as_ref())?;
let op = if self.track_op() {
Some(Op::CustomOp1(self.clone(), c))
} else {
None
};
let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
Ok(from_storage(storage, shape, op, false))
}
@@ -1885,11 +1773,7 @@ impl Tensor {
rhs.layout(),
c.as_ref().as_ref(),
)?;
let op = if self.track_op() {
Some(Op::CustomOp2(self.clone(), rhs.clone(), c))
} else {
None
};
let op = BackpropOp::new2(self, rhs, |t1, t2| Op::CustomOp2(t1, t2, c.clone()));
Ok(from_storage(storage, shape, op, false))
}
@@ -1907,11 +1791,9 @@ impl Tensor {
t3.layout(),
c.as_ref().as_ref(),
)?;
let op = if self.track_op() {
Some(Op::CustomOp3(self.clone(), t2.clone(), t3.clone(), c))
} else {
None
};
let op = BackpropOp::new3(self, t2, t3, |t1, t2, t3| {
Op::CustomOp3(t1, t2, t3, c.clone())
});
Ok(from_storage(storage, shape, op, false))
}