Add an abstract backprop op type (#240)
* Start adding the backprop op type.
* More backprop ops.
* Finish the backprop op.
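The diff below replaces the `if self.track_op() { Some(...) } else { None }` boilerplate that every tensor operation repeated with a small set of constructors on a new `BackpropOp` wrapper. A minimal, self-contained sketch of the before/after shape of a call site, using stub `Tensor` and `Op` types for illustration (not the real candle definitions):

```rust
// Minimal sketch of the pattern introduced by this PR, with stub types standing in
// for the real candle `Tensor` and `Op`. `BackpropOp` wraps `Option<Op>` and owns
// the `track_op()` check, so call sites no longer repeat the if/else boilerplate.

#[derive(Clone, Debug)]
struct Tensor {
    requires_grad: bool, // stand-in for the real tensor internals
}

impl Tensor {
    fn track_op(&self) -> bool {
        self.requires_grad
    }
}

#[derive(Clone, Debug)]
enum Op {
    Copy(Tensor),
    Transpose(Tensor, usize, usize),
}

#[derive(Clone, Debug)]
struct BackpropOp(Option<Op>);

impl BackpropOp {
    fn none() -> Self {
        BackpropOp(None)
    }

    fn new1(arg: &Tensor, f: impl Fn(Tensor) -> Op) -> Self {
        let op = if arg.track_op() {
            Some(f(arg.clone()))
        } else {
            None
        };
        Self(op)
    }
}

fn main() {
    let t = Tensor { requires_grad: true };

    // Before this change, every call site spelled the check out by hand:
    let _old_style: Option<Op> = if t.track_op() {
        Some(Op::Transpose(t.clone(), 0, 1))
    } else {
        None
    };

    // After, the check lives in one place; closures or variant constructors are passed in:
    let _new_style = BackpropOp::new1(&t, |t| Op::Transpose(t, 0, 1));
    let _copied = BackpropOp::new1(&t, Op::Copy);
    let _no_grad = BackpropOp::none();
}
```

The real constructors (`none`, `new1`, `new2`, `new3`, `new`) and the `Deref` impl are in the op.rs hunks below.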
candle-core/src/op.rs

@@ -58,7 +58,7 @@ pub enum UnaryOp {
 }

 #[derive(Clone)]
-pub(crate) enum Op {
+pub enum Op {
     Binary(Tensor, Tensor, BinaryOp),
     Unary(Tensor, UnaryOp),
     Cmp(Tensor, CmpOp),
@@ -512,3 +512,63 @@ impl UnaryOpT for Relu {
         v
     }
 }
+
+/// `BackpropOp` is a wrapper around `Option<Op>`. The main goal is to ensure that dependencies are
+/// properly checked when creating a new value.
+#[derive(Clone)]
+pub struct BackpropOp(Option<Op>);
+
+impl BackpropOp {
+    pub(crate) fn none() -> Self {
+        BackpropOp(None)
+    }
+
+    pub(crate) fn new1(arg: &Tensor, f: impl Fn(Tensor) -> Op) -> Self {
+        let op = if arg.track_op() {
+            Some(f(arg.clone()))
+        } else {
+            None
+        };
+        Self(op)
+    }
+
+    pub(crate) fn new2(arg1: &Tensor, arg2: &Tensor, f: impl Fn(Tensor, Tensor) -> Op) -> Self {
+        let op = if arg1.track_op() || arg2.track_op() {
+            Some(f(arg1.clone(), arg2.clone()))
+        } else {
+            None
+        };
+        Self(op)
+    }
+
+    pub(crate) fn new3(
+        arg1: &Tensor,
+        arg2: &Tensor,
+        arg3: &Tensor,
+        f: impl Fn(Tensor, Tensor, Tensor) -> Op,
+    ) -> Self {
+        let op = if arg1.track_op() || arg2.track_op() || arg3.track_op() {
+            Some(f(arg1.clone(), arg2.clone(), arg3.clone()))
+        } else {
+            None
+        };
+        Self(op)
+    }
+
+    pub(crate) fn new<A: AsRef<Tensor>>(args: &[A], f: impl Fn(Vec<Tensor>) -> Op) -> Self {
+        let op = if args.iter().any(|arg| arg.as_ref().track_op()) {
+            let args: Vec<Tensor> = args.iter().map(|arg| arg.as_ref().clone()).collect();
+            Some(f(args))
+        } else {
+            None
+        };
+        Self(op)
+    }
+}
+
+impl std::ops::Deref for BackpropOp {
+    type Target = Option<Op>;
+    fn deref(&self) -> &Self::Target {
+        &self.0
+    }
+}
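Since `BackpropOp` implements `Deref<Target = Option<Op>>`, code that reads a tensor's op through the usual `Option` API should keep working after the field type changes in tensor.rs below. A small standalone sketch of what the `Deref` impl buys, with a stub `Op` type (illustration only, not the real candle definitions):

```rust
// Sketch: with Deref<Target = Option<Op>>, callers keep using the Option API
// (as_ref, is_some, pattern matching) through auto-deref.
use std::ops::Deref;

#[derive(Debug)]
enum Op {
    Copy,
}

struct BackpropOp(Option<Op>);

impl Deref for BackpropOp {
    type Target = Option<Op>;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

fn main() {
    let op = BackpropOp(Some(Op::Copy));
    assert!(op.is_some());
    if let Some(inner) = op.as_ref() {
        println!("tracked op: {inner:?}");
    }
}
```

This is presumably what lets the rest of the crate treat the wrapper as a drop-in replacement for the old `Option<Op>` field.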
candle-core/src/tensor.rs

@@ -1,5 +1,7 @@
 use crate::backend::{BackendDevice, BackendStorage};
-use crate::op::{BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp};
+use crate::op::{
+    BackpropOp, BinaryOp, CmpOp, CustomOp1, CustomOp2, CustomOp3, Op, ReduceOp, UnaryOp,
+};
 use crate::shape::{Dim, Dims};
 use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
 use std::sync::{Arc, RwLock};
@@ -33,7 +35,7 @@ pub struct Tensor_ {
     // that's tricky to encode in the current setup.
     storage: Arc<RwLock<Storage>>,
     layout: Layout,
-    op: Option<Op>,
+    op: BackpropOp,
     is_variable: bool,
     dtype: DType,
     device: Device,
@@ -79,11 +81,7 @@ macro_rules! unary_op {
             let storage = self
                 .storage()
                 .unary_impl::<crate::op::$op_name>(self.layout())?;
-            let op = if self.track_op() {
-                Some(Op::Unary(self.clone(), UnaryOp::$op_name))
-            } else {
-                None
-            };
+            let op = BackpropOp::new1(self, |s| Op::Unary(s, UnaryOp::$op_name));
             Ok(from_storage(storage, shape.clone(), op, false))
         }
     };
@@ -98,11 +96,7 @@ macro_rules! binary_op {
                 self.layout(),
                 rhs.layout(),
            )?;
-            let op = if self.track_op() || rhs.track_op() {
-                Some(Op::Binary(self.clone(), rhs.clone(), BinaryOp::$op_name))
-            } else {
-                None
-            };
+            let op = BackpropOp::new2(self, rhs, |t1, t2| Op::Binary(t1, t2, BinaryOp::$op_name));
             Ok(from_storage(storage, shape.clone(), op, false))
         }
     };
@@ -131,7 +125,7 @@ macro_rules! broadcast_binary_op {
 fn from_storage<S: Into<Shape>>(
     storage: Storage,
     shape: S,
-    op: Option<Op>,
+    op: BackpropOp,
     is_variable: bool,
 ) -> Tensor {
     let dtype = storage.dtype();
@@ -155,13 +149,14 @@ impl Tensor {
         device: &Device,
         is_variable: bool,
     ) -> Result<Self> {
+        let none = BackpropOp::none();
         if is_variable {
             let shape = shape.into();
             let storage = device.ones(&shape, dtype)?;
-            Ok(from_storage(storage, shape, None, is_variable))
+            Ok(from_storage(storage, shape, none, is_variable))
         } else {
             let storage = device.ones(&crate::shape::SCALAR, dtype)?;
-            from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
+            from_storage(storage, crate::shape::SCALAR, none, is_variable).broadcast_as(shape)
         }
     }

@@ -199,13 +194,14 @@ impl Tensor {
         device: &Device,
         is_variable: bool,
     ) -> Result<Self> {
+        let none = BackpropOp::none();
         if is_variable {
             let shape = shape.into();
             let storage = device.zeros(&shape, dtype)?;
-            Ok(from_storage(storage, shape, None, is_variable))
+            Ok(from_storage(storage, shape, none, is_variable))
         } else {
             let storage = device.zeros(&crate::shape::SCALAR, dtype)?;
-            from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
+            from_storage(storage, crate::shape::SCALAR, none, is_variable).broadcast_as(shape)
         }
     }

@@ -246,7 +242,8 @@ impl Tensor {
     ) -> Result<Self> {
         let s = s.into();
         let storage = device.rand_uniform(&s, dtype, lo, up)?;
-        Ok(from_storage(storage, s, None, is_variable))
+        let none = BackpropOp::none();
+        Ok(from_storage(storage, s, none, is_variable))
     }

     /// Creates a new tensor initialized with values sampled uniformly between `lo` and `up`.
@@ -270,7 +267,8 @@ impl Tensor {
     ) -> Result<Self> {
         let s = s.into();
         let storage = device.rand_normal(&s, dtype, mean, std)?;
-        Ok(from_storage(storage, s, None, is_variable))
+        let none = BackpropOp::none();
+        Ok(from_storage(storage, s, none, is_variable))
     }

     /// Creates a new tensor initialized with values sampled from a normal distribution with the
@@ -297,7 +295,8 @@ impl Tensor {
             return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
         }
         let storage = device.storage(array)?;
-        Ok(from_storage(storage, shape, None, is_variable))
+        let none = BackpropOp::none();
+        Ok(from_storage(storage, shape, none, is_variable))
     }

     /// Creates a new tensor on the specified device using the content and shape of the input.
@@ -352,7 +351,8 @@ impl Tensor {
             return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
         }
         let storage = device.storage_owned(data)?;
-        Ok(from_storage(storage, shape, None, is_variable))
+        let none = BackpropOp::none();
+        Ok(from_storage(storage, shape, none, is_variable))
     }

     /// Creates a new tensor initialized with values from the input vector. The number of elements
@@ -500,26 +500,14 @@ impl Tensor {
     /// ```
     pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
         let storage = self.storage().affine(self.layout(), mul, add)?;
-        let op = if self.track_op() {
-            Some(Op::Affine {
-                arg: self.clone(),
-                mul,
-                add,
-            })
-        } else {
-            None
-        };
+        let op = BackpropOp::new1(self, |arg| Op::Affine { arg, mul, add });
         Ok(from_storage(storage, self.shape(), op, false))
     }

     /// Applies the Exponential Linear Unit (ELU) function on each element of the input tensor.
     pub fn elu(&self, alpha: f64) -> Result<Self> {
         let storage = self.storage().elu(self.layout(), alpha)?;
-        let op = if self.track_op() {
-            Some(Op::Elu(self.clone(), alpha))
-        } else {
-            None
-        };
+        let op = BackpropOp::new1(self, |t| Op::Elu(t, alpha));
         Ok(from_storage(storage, self.shape(), op, false))
     }

@@ -554,11 +542,7 @@ impl Tensor {
         if start == 0 && dims[dim] == len {
             Ok(self.clone())
         } else {
-            let op = if self.track_op() {
-                Some(Op::Narrow(self.clone(), dim, start, len))
-            } else {
-                None
-            };
+            let op = BackpropOp::new1(self, |t| Op::Narrow(t, dim, start, len));
             let layout = self.layout().narrow(dim, start, len)?;
             let tensor_ = Tensor_ {
                 id: TensorId::new(),
@@ -602,11 +586,7 @@ impl Tensor {
             let mut storage = self.storage().unary_impl::<crate::op::Exp>(self.layout())?;
             // The resulting storage is contiguous.
             storage.divide_by_sum_over_dim(shape, dim)?;
-            let op = if self.track_op() {
-                Some(Op::Softmax(self.clone(), dim))
-            } else {
-                None
-            };
+            let op = BackpropOp::new1(self, |arg| Op::Softmax(arg, dim));
             Ok(from_storage(storage, shape.clone(), op, false))
         }
     }
@@ -638,11 +618,7 @@ impl Tensor {
         let storage = self.storage().reduce_op(op, self.layout(), &[dim])?;
         let mut dims = self.dims().to_vec();
         dims[dim] = 1;
-        let op = if self.track_op() {
-            Some(Op::Reduce(self.clone(), op, dims.to_vec()))
-        } else {
-            None
-        };
+        let op = BackpropOp::new1(self, |arg| Op::Reduce(arg, op, dims.to_vec()));
         let res = from_storage(storage, dims, op, false);
         if keepdim {
             Ok(res)
@@ -660,11 +636,7 @@ impl Tensor {
         for &sum_dim in sum_dims.iter() {
             dims[sum_dim] = 1
         }
-        let op = if self.track_op() {
-            Some(Op::Reduce(self.clone(), ReduceOp::Sum, dims.to_vec()))
-        } else {
-            None
-        };
+        let op = BackpropOp::new1(self, |a| Op::Reduce(a, ReduceOp::Sum, dims.to_vec()));
         let sum = from_storage(storage, dims, op, false);
         if keepdim {
             Ok(sum)
@@ -738,11 +710,7 @@ impl Tensor {
         let storage = self
             .storage()
             .cmp(op, &rhs.storage(), self.layout(), rhs.layout())?;
-        let op = if self.track_op() {
-            Some(Op::Cmp(self.clone(), op))
-        } else {
-            None
-        };
+        let op = BackpropOp::new1(self, |a| Op::Cmp(a, op));
         Ok(from_storage(storage, shape.dims(), op, false))
     }

@@ -807,16 +775,12 @@ impl Tensor {
         let storage =
             self.storage()
                 .conv1d(self.layout(), &kernel.storage(), kernel.layout(), &params)?;
-        let op = if self.track_op() || kernel.track_op() {
-            Some(Op::Conv1D {
-                arg: self.clone(),
-                kernel: kernel.clone(),
-                padding,
-                stride,
-            })
-        } else {
-            None
-        };
+        let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D {
+            arg,
+            kernel,
+            padding,
+            stride,
+        });
         let out_dims = params.out_dims();
         Ok(from_storage(storage, out_dims, op, false))
     }
@@ -867,11 +831,7 @@ impl Tensor {
             self.layout(),
             rhs.layout(),
         )?;
-        let op = if self.track_op() || rhs.track_op() {
-            Some(Op::Matmul(self.clone(), rhs.clone()))
-        } else {
-            None
-        };
+        let op = BackpropOp::new2(self, rhs, Op::Matmul);
         Ok(from_storage(storage, c_shape, op, false))
     }

@@ -888,15 +848,7 @@ impl Tensor {
             &on_false.storage(),
             on_false.layout(),
         )?;
-        let op = if self.track_op() || on_true.track_op() || on_false.track_op() {
-            Some(Op::WhereCond(
-                self.clone(),
-                on_true.clone(),
-                on_false.clone(),
-            ))
-        } else {
-            None
-        };
+        let op = BackpropOp::new3(self, on_true, on_false, Op::WhereCond);
         Ok(from_storage(storage, shape, op, false))
     }

@@ -937,11 +889,7 @@ impl Tensor {
             .storage()
             .embedding(ids.layout(), &rhs.storage(), rhs.layout())?;
         let shape: Shape = (seq_len, hidden_size).into();
-        let op = if ids.track_op() || rhs.track_op() {
-            Some(Op::Embedding(ids.clone(), rhs.clone()))
-        } else {
-            None
-        };
+        let op = BackpropOp::new2(ids, rhs, Op::Embedding);
         Ok(from_storage(storage, shape, op, false))
     }

@@ -983,16 +931,9 @@ impl Tensor {
             source.layout(),
             dim,
         )?;
-        let op = if indexes.track_op() || self.track_op() {
-            Some(Op::ScatterAdd(
-                self.clone(),
-                indexes.clone(),
-                source.clone(),
-                dim,
-            ))
-        } else {
-            None
-        };
+        let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
+            Op::ScatterAdd(t1, t2, t3, dim)
+        });
         Ok(from_storage(storage, self.shape(), op, false))
     }

@@ -1038,16 +979,9 @@ impl Tensor {
             source.layout(),
             dim,
         )?;
-        let op = if indexes.track_op() || self.track_op() {
-            Some(Op::IndexAdd(
-                self.clone(),
-                indexes.clone(),
-                source.clone(),
-                dim,
-            ))
-        } else {
-            None
-        };
+        let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
+            Op::IndexAdd(t1, t2, t3, dim)
+        });
         Ok(from_storage(storage, self.shape(), op, false))
     }

@@ -1077,11 +1011,7 @@ impl Tensor {
         let storage =
             self.storage()
                 .gather(self.layout(), &indexes.storage(), indexes.layout(), dim)?;
-        let op = if indexes.track_op() || self.track_op() {
-            Some(Op::Gather(self.clone(), indexes.clone(), dim))
-        } else {
-            None
-        };
+        let op = BackpropOp::new2(self, indexes, |t1, t2| Op::Gather(t1, t2, dim));
         Ok(from_storage(storage, indexes.shape(), op, false))
     }

@@ -1104,11 +1034,7 @@ impl Tensor {
         )?;
         let mut dims = self.dims().to_vec();
         dims[dim] = indexes_len;
-        let op = if indexes.track_op() || self.track_op() {
-            Some(Op::IndexSelect(self.clone(), indexes.clone(), dim))
-        } else {
-            None
-        };
+        let op = BackpropOp::new2(self, indexes, |t1, t2| Op::IndexSelect(t1, t2, dim));
         Ok(from_storage(storage, dims, op, false))
     }

@@ -1404,11 +1330,7 @@ impl Tensor {
     pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Tensor> {
         let dim1 = dim1.to_index(self.shape(), "transpose")?;
         let dim2 = dim2.to_index(self.shape(), "transpose")?;
-        let op = if self.track_op() {
-            Some(Op::Transpose(self.clone(), dim1, dim2))
-        } else {
-            None
-        };
+        let op = BackpropOp::new1(self, |t| Op::Transpose(t, dim1, dim2));
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage: self.storage.clone(),
@@ -1434,11 +1356,7 @@ impl Tensor {
     /// Compared to clone, this copies the actual storage but may fail because of running out of
     /// memory.
     pub fn copy(&self) -> Result<Tensor> {
-        let op = if self.track_op() {
-            Some(Op::Copy(self.clone()))
-        } else {
-            None
-        };
+        let op = BackpropOp::new1(self, Op::Copy);
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage: Arc::new(RwLock::new(self.storage().try_clone(self.layout())?)),
@@ -1458,7 +1376,7 @@ impl Tensor {
             id: TensorId::new(),
             storage: self.storage.clone(),
             layout: self.layout.clone(),
-            op: None,
+            op: BackpropOp::none(),
             is_variable: false,
             dtype: self.dtype,
             device: self.device.clone(),
@@ -1484,11 +1402,7 @@ impl Tensor {
             }
             (Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
         };
-        let op = if self.track_op() {
-            Some(Op::ToDevice(self.clone()))
-        } else {
-            None
-        };
+        let op = BackpropOp::new1(self, Op::ToDevice);
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage: Arc::new(RwLock::new(storage)),
@@ -1519,16 +1433,11 @@ impl Tensor {
     /// any value, the dimension `t_a` must be equal to `i_a` if `i_a` is different from 1. If
     /// `i_a` is equal to 1, any value can be used.
     pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
-        let op = if self.track_op() {
-            Some(Op::Broadcast(self.clone()))
-        } else {
-            None
-        };
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage: self.storage.clone(),
             layout: self.layout.broadcast_as(shape)?,
-            op,
+            op: BackpropOp::new1(self, Op::Broadcast),
             is_variable: false,
             dtype: self.dtype,
             device: self.device.clone(),
@@ -1557,11 +1466,7 @@ impl Tensor {
         } else {
             let shape = self.shape();
             let storage = self.storage().to_dtype(self.layout(), dtype)?;
-            let op = if self.track_op() {
-                Some(Op::ToDType(self.clone()))
-            } else {
-                None
-            };
+            let op = BackpropOp::new1(self, Op::ToDType);
             Ok(from_storage(storage, shape.clone(), op, false))
         }
     }
@@ -1576,11 +1481,7 @@ impl Tensor {
             let mut storage = self.device().zeros(shape, self.dtype())?;
             self.storage()
                 .copy_strided_src(&mut storage, 0, self.layout())?;
-            let op = if self.track_op() {
-                Some(Op::Copy(self.clone()))
-            } else {
-                None
-            };
+            let op = BackpropOp::new1(self, Op::Copy);
             Ok(from_storage(storage, shape.clone(), op, false))
         }
     }
@@ -1612,11 +1513,7 @@ impl Tensor {
             }
             .bt());
         }
-        let op = if self.track_op() {
-            Some(Op::Reshape(self.clone()))
-        } else {
-            None
-        };
+        let op = BackpropOp::new1(self, Op::Reshape);
         if self.is_contiguous() {
             let tensor_ = Tensor_ {
                 id: TensorId::new(),
@@ -1820,12 +1717,7 @@ impl Tensor {
             offsets.push(next_offset);
         }
         let shape = Shape::from(cat_dims);
-        let op = if args.iter().any(|arg| arg.as_ref().track_op()) {
-            let args: Vec<Tensor> = args.iter().map(|arg| arg.as_ref().clone()).collect();
-            Some(Op::Cat(args, 0))
-        } else {
-            None
-        };
+        let op = BackpropOp::new(args, |args| Op::Cat(args, 0));
         let mut storage = device.zeros(&shape, dtype)?;
         for (arg, &offset) in args.iter().zip(offsets.iter()) {
             let arg = arg.as_ref();
@@ -1865,11 +1757,7 @@ impl Tensor {
         let (storage, shape) = self
             .storage()
             .custom_op1(self.layout(), c.as_ref().as_ref())?;
-        let op = if self.track_op() {
-            Some(Op::CustomOp1(self.clone(), c))
-        } else {
-            None
-        };
+        let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
         Ok(from_storage(storage, shape, op, false))
     }

@@ -1885,11 +1773,7 @@ impl Tensor {
             rhs.layout(),
             c.as_ref().as_ref(),
         )?;
-        let op = if self.track_op() {
-            Some(Op::CustomOp2(self.clone(), rhs.clone(), c))
-        } else {
-            None
-        };
+        let op = BackpropOp::new2(self, rhs, |t1, t2| Op::CustomOp2(t1, t2, c.clone()));
         Ok(from_storage(storage, shape, op, false))
     }

@@ -1907,11 +1791,9 @@ impl Tensor {
             t3.layout(),
             c.as_ref().as_ref(),
         )?;
-        let op = if self.track_op() {
-            Some(Op::CustomOp3(self.clone(), t2.clone(), t3.clone(), c))
-        } else {
-            None
-        };
+        let op = BackpropOp::new3(self, t2, t3, |t1, t2, t3| {
+            Op::CustomOp3(t1, t2, t3, c.clone())
+        });
         Ok(from_storage(storage, shape, op, false))
     }
