Add slice-scatter. (#927)

* Add slice-scatter.

* Add the op.

* Make transpose be a no-op when the dimensions are identical.

* Add the backprop.

* And add some gradient test.
This commit is contained in:
Laurent Mazare
2023-09-22 12:18:16 +01:00
committed by GitHub
parent a96878f235
commit 8601537e31
5 changed files with 142 additions and 1 deletions

View File

@ -1132,6 +1132,74 @@ impl Tensor {
Ok(from_storage(storage, self.shape(), op, false))
}
/// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension.
pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: usize, start: usize) -> Result<Self> {
let dim = dim.to_index(self.shape(), "slice-scatter")?;
if dim == 0 {
self.slice_scatter0(src, start)
} else {
// TODO: Maybe we want to add a more efficient implementation at some point.
self.transpose(0, dim)?
.slice_scatter0(&src.transpose(0, dim)?, start)?
.transpose(0, dim)
}
}
/// Embeds the values of the `src` tensor into the `self` tensor on the first dimension.
pub fn slice_scatter0(&self, src: &Self, start: usize) -> Result<Self> {
if self.dtype() != src.dtype() {
Err(Error::DTypeMismatchBinaryOp {
lhs: self.dtype(),
rhs: src.dtype(),
op: "slice-scatter",
}
.bt())?
}
if self.device().location() != src.device.location() {
Err(Error::DeviceMismatchBinaryOp {
lhs: self.device().location(),
rhs: src.device().location(),
op: "slice-scatter",
}
.bt())?
}
if self.rank() != src.rank() {
Err(Error::UnexpectedNumberOfDims {
expected: self.rank(),
got: src.rank(),
shape: src.shape().clone(),
}
.bt())?
}
let shape_ok =
self.dims()
.iter()
.zip(src.dims().iter())
.enumerate()
.all(|(dim_idx, (&d1, &d2))| {
if 0 == dim_idx {
d2 + start <= d1
} else {
d1 == d2
}
});
if !shape_ok {
Err(Error::ShapeMismatchBinaryOp {
op: "slice-scatter (self, src)",
lhs: self.shape().clone(),
rhs: src.shape().clone(),
})?
}
let mut storage = self.device().zeros(self.shape(), self.dtype())?;
self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?;
let offset = start * src.dims()[1..].iter().product::<usize>();
src.storage()
.copy_strided_src(&mut storage, offset, src.layout())?;
let op = BackpropOp::new2(self, src, |t1, t2| Op::SliceScatter0(t1, t2, start));
Ok(from_storage(storage, self.shape(), op, false))
}
/// Accumulate element from `source` at indexes `indexes` and add them to `self`.
pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
let dim = dim.to_index(self.shape(), "index-add")?;
@ -1548,6 +1616,9 @@ impl Tensor {
pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Tensor> {
let dim1 = dim1.to_index(self.shape(), "transpose")?;
let dim2 = dim2.to_index(self.shape(), "transpose")?;
if dim1 == dim2 {
return Ok(self.clone());
}
let op = BackpropOp::new1(self, |t| Op::Transpose(t, dim1, dim2));
let tensor_ = Tensor_ {
id: TensorId::new(),