Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 02:58:50 +00:00)
Add slice-scatter. (#927)
* Add slice-scatter.
* Add the op.
* Make transpose be a no-op when the dimensions are identical.
* Add the backprop.
* And add some gradient tests.
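For context, the sketch below shows what the new `slice_scatter0` call does from the caller's side. It is an illustrative example assuming a CPU device, not part of the patch; the expected values mirror the `start = 1` case from the test added further down.

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // A (4, 3) destination and a (2, 3) source to embed into it.
    let t = Tensor::arange(0f32, 12f32, &dev)?.reshape((4, 3))?;
    let src = Tensor::arange(100f32, 106f32, &dev)?.reshape((2, 3))?;
    // Overwrite rows 1..3 of `t` with `src`; all other rows are copied unchanged.
    let out = t.slice_scatter0(&src, 1)?;
    assert_eq!(
        out.to_vec2::<f32>()?,
        &[
            [0.0, 1.0, 2.0],
            [100.0, 101.0, 102.0],
            [103.0, 104.0, 105.0],
            [9.0, 10.0, 11.0]
        ]
    );
    Ok(())
}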
@@ -69,7 +69,8 @@ impl Tensor {
                     | Op::Binary(lhs, rhs, _)
                     | Op::Gather(lhs, rhs, _)
                     | Op::IndexSelect(lhs, rhs, _)
-                    | Op::Matmul(lhs, rhs) => {
+                    | Op::Matmul(lhs, rhs)
+                    | Op::SliceScatter0(lhs, rhs, _) => {
                         let (tg, nodes) = walk(lhs, nodes, already_seen);
                         track_grad |= tg;
                         let (tg, nodes) = walk(rhs, nodes, already_seen);
@@ -270,6 +271,15 @@ impl Tensor {
                 Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
                     op: "upsample-nearest2d",
                 })?,
+                Op::SliceScatter0(lhs, rhs, start_rhs) => {
+                    let rhs_sum_grad = grads.or_insert(rhs)?;
+                    let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
+                    *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+
+                    let lhs_sum_grad = grads.or_insert(lhs)?;
+                    let lhs_grad = grad.slice_scatter0(&rhs.zeros_like()?, *start_rhs)?;
+                    *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?
+                }
                 Op::Gather(arg, indexes, dim) => {
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
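In words: the output of slice-scatter copies `rhs` over rows `start_rhs..start_rhs + rhs.dim(0)` and keeps `lhs` everywhere else, so the incoming gradient splits into two disjoint pieces. A standalone sketch of the same rule, using only calls that appear in this patch (the helper name and free function form are illustrative, not part of the commit):

use candle_core::{Result, Tensor};

// For out = lhs.slice_scatter0(&rhs, start):
// - rhs only influences the rows it was written to, so its gradient is the
//   matching slice of the output gradient;
// - lhs influences every other row, so its gradient is the output gradient
//   with that slice zeroed out.
fn slice_scatter0_grads(grad_out: &Tensor, rhs: &Tensor, start: usize) -> Result<(Tensor, Tensor)> {
    let grad_rhs = grad_out.narrow(0, start, rhs.dim(0)?)?;
    let grad_lhs = grad_out.slice_scatter0(&rhs.zeros_like()?, start)?;
    Ok((grad_lhs, grad_rhs))
}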
@@ -133,6 +133,7 @@ pub enum Op {
     Copy(Tensor),
     Broadcast(Tensor),
     Narrow(Tensor, usize, usize, usize),
+    SliceScatter0(Tensor, Tensor, usize),
     Reshape(Tensor),
     ToDevice(Tensor),
     Transpose(Tensor, usize, usize),
@@ -1132,6 +1132,74 @@ impl Tensor {
         Ok(from_storage(storage, self.shape(), op, false))
     }
 
+    /// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension.
+    pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: D, start: usize) -> Result<Self> {
+        let dim = dim.to_index(self.shape(), "slice-scatter")?;
+        if dim == 0 {
+            self.slice_scatter0(src, start)
+        } else {
+            // TODO: Maybe we want to add a more efficient implementation at some point.
+            self.transpose(0, dim)?
+                .slice_scatter0(&src.transpose(0, dim)?, start)?
+                .transpose(0, dim)
+        }
+    }
+
+    /// Embeds the values of the `src` tensor into the `self` tensor on the first dimension.
+    pub fn slice_scatter0(&self, src: &Self, start: usize) -> Result<Self> {
+        if self.dtype() != src.dtype() {
+            Err(Error::DTypeMismatchBinaryOp {
+                lhs: self.dtype(),
+                rhs: src.dtype(),
+                op: "slice-scatter",
+            }
+            .bt())?
+        }
+        if self.device().location() != src.device().location() {
+            Err(Error::DeviceMismatchBinaryOp {
+                lhs: self.device().location(),
+                rhs: src.device().location(),
+                op: "slice-scatter",
+            }
+            .bt())?
+        }
+        if self.rank() != src.rank() {
+            Err(Error::UnexpectedNumberOfDims {
+                expected: self.rank(),
+                got: src.rank(),
+                shape: src.shape().clone(),
+            }
+            .bt())?
+        }
+        let shape_ok =
+            self.dims()
+                .iter()
+                .zip(src.dims().iter())
+                .enumerate()
+                .all(|(dim_idx, (&d1, &d2))| {
+                    if 0 == dim_idx {
+                        d2 + start <= d1
+                    } else {
+                        d1 == d2
+                    }
+                });
+        if !shape_ok {
+            Err(Error::ShapeMismatchBinaryOp {
+                op: "slice-scatter (self, src)",
+                lhs: self.shape().clone(),
+                rhs: src.shape().clone(),
+            })?
+        }
+        let mut storage = self.device().zeros(self.shape(), self.dtype())?;
+        self.storage()
+            .copy_strided_src(&mut storage, 0, self.layout())?;
+        let offset = start * src.dims()[1..].iter().product::<usize>();
+        src.storage()
+            .copy_strided_src(&mut storage, offset, src.layout())?;
+        let op = BackpropOp::new2(self, src, |t1, t2| Op::SliceScatter0(t1, t2, start));
+        Ok(from_storage(storage, self.shape(), op, false))
+    }
+
     /// Accumulate element from `source` at indexes `indexes` and add them to `self`.
     pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
         let dim = dim.to_index(self.shape(), "index-add")?;
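The `dim != 0` case goes through the transpose fallback above and is not exercised by the tests in this patch. A minimal sketch of that path, assuming a CPU device (the expected values follow from the definition, not from the diff):

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let t = Tensor::arange(0f32, 12f32, &dev)?.reshape((4, 3))?;
    // A (4, 1) column to embed; the ranks of self and src must match.
    let src = Tensor::arange(100f32, 104f32, &dev)?.reshape((4, 1))?;
    // Overwrite column 1 of `t`; internally this transposes, scatters on dim 0,
    // and transposes back.
    let out = t.slice_scatter(&src, 1, 1)?;
    assert_eq!(
        out.to_vec2::<f32>()?,
        &[
            [0.0, 100.0, 2.0],
            [3.0, 101.0, 5.0],
            [6.0, 102.0, 8.0],
            [9.0, 103.0, 11.0]
        ]
    );
    Ok(())
}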
@@ -1548,6 +1616,9 @@ impl Tensor {
     pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Tensor> {
         let dim1 = dim1.to_index(self.shape(), "transpose")?;
         let dim2 = dim2.to_index(self.shape(), "transpose")?;
+        if dim1 == dim2 {
+            return Ok(self.clone());
+        }
         let op = BackpropOp::new1(self, |t| Op::Transpose(t, dim1, dim2));
         let tensor_ = Tensor_ {
             id: TensorId::new(),
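With this early return, transposing a dimension with itself hands back a cheap clone instead of recording a `Transpose` op. A quick sketch of the behaviour, assuming a CPU device (not taken from the commit's tests):

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
    // Transposing dim 1 with itself is now a no-op: same shape, same data.
    let same = t.transpose(1, 1)?;
    assert_eq!(same.dims(), &[2, 3]);
    assert_eq!(same.to_vec2::<f32>()?, t.to_vec2::<f32>()?);
    Ok(())
}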
@@ -218,6 +218,22 @@ fn binary_grad(device: &Device) -> Result<()> {
     let grad_x = grads.get(x).context("no grad for x")?;
     assert_eq!(y.to_vec1::<f32>()?, [3., 1., -4., -1.]);
     assert_eq!(grad_x.to_vec1::<f32>()?, [1., 1., 1., 1.]);
+
+    let x_var = Var::new(&[3f32, 1., -4., -1., 5., 9.], device)?;
+    let x = x_var.as_tensor();
+    let y_var = Var::new(&[2f32, 7., 1.], device)?;
+    let y = y_var.as_tensor();
+
+    let ss = x
+        .reshape((2, 3))?
+        .slice_scatter0(&y.reshape((1, 3))?, 1)?
+        .sqr()?;
+    let grads = ss.backward()?;
+    let grad_x = grads.get(x).context("no grad for x")?;
+    let grad_y = grads.get(y).context("no grad for y")?;
+    assert_eq!(ss.to_vec2::<f32>()?, [[9., 1., 16.], [4., 49., 1.]]);
+    assert_eq!(grad_x.to_vec1::<f32>()?, [6.0, 2.0, -8.0, 0.0, 0.0, 0.0]);
+    assert_eq!(grad_y.to_vec1::<f32>()?, [4.0, 14.0, 2.0]);
     Ok(())
 }
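To see where these expected values come from: reshaping `x` to (2, 3) gives [[3, 1, -4], [-1, 5, 9]], and scattering `y` at row 1 replaces the second row with [2, 7, 1], so squaring yields [[9, 1, 16], [4, 49, 1]]. `backward()` seeds the output gradient with ones and the derivative of the square is 2·value, so the first row of `x` (still visible in the output) gets [6, 2, -8], its overwritten second row gets zeros, and `y` gets 2·[2, 7, 1] = [4, 14, 2].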
@@ -674,6 +674,48 @@ fn index_add(device: &Device) -> Result<()> {
     Ok(())
 }
 
+fn slice_scatter(device: &Device) -> Result<()> {
+    let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
+    assert_eq!(
+        t.to_vec2::<f32>()?,
+        &[
+            [0.0, 1.0, 2.0],
+            [3.0, 4.0, 5.0],
+            [6.0, 7.0, 8.0],
+            [9.0, 10.0, 11.0]
+        ]
+    );
+    let src = Tensor::arange(100f32, 106f32, device)?.reshape((2, 3))?;
+    assert_eq!(
+        t.slice_scatter0(&src, 0)?.to_vec2::<f32>()?,
+        &[
+            [100.0, 101.0, 102.0],
+            [103.0, 104.0, 105.0],
+            [6.0, 7.0, 8.0],
+            [9.0, 10.0, 11.0]
+        ]
+    );
+    assert_eq!(
+        t.slice_scatter0(&src, 1)?.to_vec2::<f32>()?,
+        &[
+            [0.0, 1.0, 2.0],
+            [100.0, 101.0, 102.0],
+            [103.0, 104.0, 105.0],
+            [9.0, 10.0, 11.0]
+        ]
+    );
+    assert_eq!(
+        t.slice_scatter0(&src, 2)?.to_vec2::<f32>()?,
+        &[
+            [0.0, 1.0, 2.0],
+            [3.0, 4.0, 5.0],
+            [100.0, 101.0, 102.0],
+            [103.0, 104.0, 105.0],
+        ]
+    );
+    Ok(())
+}
+
 fn scatter_add(device: &Device) -> Result<()> {
     let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
     assert_eq!(
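Out-of-range starts are rejected by the `d2 + start <= d1` shape check in `slice_scatter0`. A small sketch, not part of the commit's tests: with a (4, 3) destination and a (2, 3) source, `start = 3` would need rows 3..5 and therefore errors.

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let t = Tensor::arange(0f32, 12f32, &dev)?.reshape((4, 3))?;
    let src = Tensor::arange(100f32, 106f32, &dev)?.reshape((2, 3))?;
    // 3 + 2 > 4, so the shape check fails and an error is returned.
    assert!(t.slice_scatter0(&src, 3).is_err());
    Ok(())
}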
@@ -946,6 +988,7 @@ test_device!(index_select, index_select_cpu, index_select_gpu);
 test_device!(index_add, index_add_cpu, index_add_gpu);
 test_device!(gather, gather_cpu, gather_gpu);
 test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
+test_device!(slice_scatter, slice_scatter_cpu, slice_scatter_gpu);
 test_device!(randn, randn_cpu, randn_gpu);
 test_device!(clamp, clamp_cpu, clamp_gpu);