Repository: https://github.com/huggingface/candle.git
Add slice-scatter. (#927)

* Add slice-scatter.
* Add the op.
* Make transpose be a no-op when the dimensions are identical.
* Add the backprop.
* Add some gradient tests.
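As a quick orientation before the diff, here is a minimal usage sketch of the new op. It is not part of the commit; it assumes the `candle_core` crate name and only uses APIs exercised by the tests added below. `slice_scatter0` returns a copy of `self` with the rows of `src` written in starting at row `start`, and `slice_scatter` does the same along an arbitrary dimension.

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // 4x3 base tensor with rows [0 1 2], [3 4 5], [6 7 8], [9 10 11].
    let t = Tensor::arange(0f32, 12f32, &dev)?.reshape((4, 3))?;
    // 2x3 patch to embed.
    let src = Tensor::arange(100f32, 106f32, &dev)?.reshape((2, 3))?;
    // Rows 1 and 2 of `t` are overwritten by `src`; rows 0 and 3 are untouched.
    let out = t.slice_scatter0(&src, 1)?;
    assert_eq!(
        out.to_vec2::<f32>()?,
        &[
            [0.0, 1.0, 2.0],
            [100.0, 101.0, 102.0],
            [103.0, 104.0, 105.0],
            [9.0, 10.0, 11.0]
        ]
    );
    Ok(())
}
```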
@@ -69,7 +69,8 @@ impl Tensor {
             Op::Binary(lhs, rhs, _)
             | Op::Gather(lhs, rhs, _)
             | Op::IndexSelect(lhs, rhs, _)
-            | Op::Matmul(lhs, rhs) => {
+            | Op::Matmul(lhs, rhs)
+            | Op::SliceScatter0(lhs, rhs, _) => {
                 let (tg, nodes) = walk(lhs, nodes, already_seen);
                 track_grad |= tg;
                 let (tg, nodes) = walk(rhs, nodes, already_seen);
@@ -270,6 +271,15 @@ impl Tensor {
             Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
                 op: "upsample-nearest2d",
             })?,
+            Op::SliceScatter0(lhs, rhs, start_rhs) => {
+                let rhs_sum_grad = grads.or_insert(rhs)?;
+                let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;
+                *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
+
+                let lhs_sum_grad = grads.or_insert(lhs)?;
+                let lhs_grad = grad.slice_scatter0(&rhs.zeros_like()?, *start_rhs)?;
+                *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?
+            }
             Op::Gather(arg, indexes, dim) => {
                 let sum_grad = grads.or_insert(arg)?;
                 *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
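Reading the new `SliceScatter0` arm: for `out = lhs.slice_scatter0(rhs, start)`, the slice of the incoming gradient covering the overwritten rows, `grad.narrow(0, start, rhs.dim(0))`, flows to `rhs`, while `lhs` receives the incoming gradient with exactly those rows zeroed out, which is what scattering `rhs.zeros_like()` into `grad` produces. Both pieces are accumulated into the gradient store with `add`.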
@@ -133,6 +133,7 @@ pub enum Op {
     Copy(Tensor),
     Broadcast(Tensor),
     Narrow(Tensor, usize, usize, usize),
+    SliceScatter0(Tensor, Tensor, usize),
     Reshape(Tensor),
     ToDevice(Tensor),
     Transpose(Tensor, usize, usize),
@@ -1132,6 +1132,74 @@ impl Tensor {
         Ok(from_storage(storage, self.shape(), op, false))
     }
 
+    /// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension.
+    pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: D, start: usize) -> Result<Self> {
+        let dim = dim.to_index(self.shape(), "slice-scatter")?;
+        if dim == 0 {
+            self.slice_scatter0(src, start)
+        } else {
+            // TODO: Maybe we want to add a more efficient implementation at some point.
+            self.transpose(0, dim)?
+                .slice_scatter0(&src.transpose(0, dim)?, start)?
+                .transpose(0, dim)
+        }
+    }
+
+    /// Embeds the values of the `src` tensor into the `self` tensor on the first dimension.
+    pub fn slice_scatter0(&self, src: &Self, start: usize) -> Result<Self> {
+        if self.dtype() != src.dtype() {
+            Err(Error::DTypeMismatchBinaryOp {
+                lhs: self.dtype(),
+                rhs: src.dtype(),
+                op: "slice-scatter",
+            }
+            .bt())?
+        }
+        if self.device().location() != src.device().location() {
+            Err(Error::DeviceMismatchBinaryOp {
+                lhs: self.device().location(),
+                rhs: src.device().location(),
+                op: "slice-scatter",
+            }
+            .bt())?
+        }
+        if self.rank() != src.rank() {
+            Err(Error::UnexpectedNumberOfDims {
+                expected: self.rank(),
+                got: src.rank(),
+                shape: src.shape().clone(),
+            }
+            .bt())?
+        }
+        let shape_ok =
+            self.dims()
+                .iter()
+                .zip(src.dims().iter())
+                .enumerate()
+                .all(|(dim_idx, (&d1, &d2))| {
+                    if 0 == dim_idx {
+                        d2 + start <= d1
+                    } else {
+                        d1 == d2
+                    }
+                });
+        if !shape_ok {
+            Err(Error::ShapeMismatchBinaryOp {
+                op: "slice-scatter (self, src)",
+                lhs: self.shape().clone(),
+                rhs: src.shape().clone(),
+            })?
+        }
+        let mut storage = self.device().zeros(self.shape(), self.dtype())?;
+        self.storage()
+            .copy_strided_src(&mut storage, 0, self.layout())?;
+        let offset = start * src.dims()[1..].iter().product::<usize>();
+        src.storage()
+            .copy_strided_src(&mut storage, offset, src.layout())?;
+        let op = BackpropOp::new2(self, src, |t1, t2| Op::SliceScatter0(t1, t2, start));
+        Ok(from_storage(storage, self.shape(), op, false))
+    }
+
     /// Accumulate element from `source` at indexes `indexes` and add them to `self`.
     pub fn index_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
         let dim = dim.to_index(self.shape(), "index-add")?;
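A hedged sketch of the non-zero `dim` path above, again assuming `candle_core` and the APIs used by this commit's tests; the helper name `scatter_along_dim1` is made up for illustration. Scattering along dim 1 transposes, scatters on dim 0, and transposes back.

```rust
use candle_core::{Device, Result, Tensor};

fn scatter_along_dim1() -> Result<()> {
    let dev = Device::Cpu;
    let t = Tensor::arange(0f32, 12f32, &dev)?.reshape((4, 3))?;
    // A 4x1 column to embed at column index 1.
    let src = Tensor::arange(100f32, 104f32, &dev)?.reshape((4, 1))?;
    // Internally: t.transpose(0, 1)?.slice_scatter0(&src.transpose(0, 1)?, 1)?.transpose(0, 1)
    let out = t.slice_scatter(&src, 1, 1)?;
    assert_eq!(
        out.to_vec2::<f32>()?,
        &[
            [0.0, 100.0, 2.0],
            [3.0, 101.0, 5.0],
            [6.0, 102.0, 8.0],
            [9.0, 103.0, 11.0]
        ]
    );
    Ok(())
}
```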
@@ -1548,6 +1616,9 @@ impl Tensor {
     pub fn transpose<D1: Dim, D2: Dim>(&self, dim1: D1, dim2: D2) -> Result<Tensor> {
         let dim1 = dim1.to_index(self.shape(), "transpose")?;
         let dim2 = dim2.to_index(self.shape(), "transpose")?;
+        if dim1 == dim2 {
+            return Ok(self.clone());
+        }
         let op = BackpropOp::new1(self, |t| Op::Transpose(t, dim1, dim2));
         let tensor_ = Tensor_ {
             id: TensorId::new(),
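With the early return added above, `transpose(d, d)` now simply returns a clone of `self` (sharing the same storage and layout) and records no `Op::Transpose` node for backprop, instead of going through the usual permutation path.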
@@ -218,6 +218,22 @@ fn binary_grad(device: &Device) -> Result<()> {
     let grad_x = grads.get(x).context("no grad for x")?;
     assert_eq!(y.to_vec1::<f32>()?, [3., 1., -4., -1.]);
     assert_eq!(grad_x.to_vec1::<f32>()?, [1., 1., 1., 1.]);
+
+    let x_var = Var::new(&[3f32, 1., -4., -1., 5., 9.], device)?;
+    let x = x_var.as_tensor();
+    let y_var = Var::new(&[2f32, 7., 1.], device)?;
+    let y = y_var.as_tensor();
+
+    let ss = x
+        .reshape((2, 3))?
+        .slice_scatter0(&y.reshape((1, 3))?, 1)?
+        .sqr()?;
+    let grads = ss.backward()?;
+    let grad_x = grads.get(x).context("no grad for x")?;
+    let grad_y = grads.get(y).context("no grad for y")?;
+    assert_eq!(ss.to_vec2::<f32>()?, [[9., 1., 16.], [4., 49., 1.]]);
+    assert_eq!(grad_x.to_vec1::<f32>()?, [6.0, 2.0, -8.0, 0.0, 0.0, 0.0]);
+    assert_eq!(grad_y.to_vec1::<f32>()?, [4.0, 14.0, 2.0]);
     Ok(())
 }
 
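For reference, the expected numbers follow from the chain rule: `backward` seeds the output gradient with ones, and `ss` squares `[[3, 1, -4], [2, 7, 1]]` element-wise, so each surviving element of `x` receives `2 * x` (giving `6, 2, -8` on the first row) while the three overwritten elements receive `0`, and `y` receives `2 * y = [4, 14, 2]` through the scattered slice.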
@@ -674,6 +674,48 @@ fn index_add(device: &Device) -> Result<()> {
     Ok(())
 }
 
+fn slice_scatter(device: &Device) -> Result<()> {
+    let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
+    assert_eq!(
+        t.to_vec2::<f32>()?,
+        &[
+            [0.0, 1.0, 2.0],
+            [3.0, 4.0, 5.0],
+            [6.0, 7.0, 8.0],
+            [9.0, 10.0, 11.0]
+        ]
+    );
+    let src = Tensor::arange(100f32, 106f32, device)?.reshape((2, 3))?;
+    assert_eq!(
+        t.slice_scatter0(&src, 0)?.to_vec2::<f32>()?,
+        &[
+            [100.0, 101.0, 102.0],
+            [103.0, 104.0, 105.0],
+            [6.0, 7.0, 8.0],
+            [9.0, 10.0, 11.0]
+        ]
+    );
+    assert_eq!(
+        t.slice_scatter0(&src, 1)?.to_vec2::<f32>()?,
+        &[
+            [0.0, 1.0, 2.0],
+            [100.0, 101.0, 102.0],
+            [103.0, 104.0, 105.0],
+            [9.0, 10.0, 11.0]
+        ]
+    );
+    assert_eq!(
+        t.slice_scatter0(&src, 2)?.to_vec2::<f32>()?,
+        &[
+            [0.0, 1.0, 2.0],
+            [3.0, 4.0, 5.0],
+            [100.0, 101.0, 102.0],
+            [103.0, 104.0, 105.0],
+        ]
+    );
+    Ok(())
+}
+
 fn scatter_add(device: &Device) -> Result<()> {
     let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
     assert_eq!(
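Note the valid range of `start` exercised here: with a `(4, 3)` base and a `(2, 3)` `src`, the shape check in `slice_scatter0` requires `src.dim(0) + start <= self.dim(0)`, so `start` may be 0, 1, or 2; `start = 3` would return a `ShapeMismatchBinaryOp` error instead of writing out of bounds.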
@@ -946,6 +988,7 @@ test_device!(index_select, index_select_cpu, index_select_gpu);
 test_device!(index_add, index_add_cpu, index_add_gpu);
 test_device!(gather, gather_cpu, gather_gpu);
 test_device!(scatter_add, scatter_add_cpu, scatter_add_gpu);
+test_device!(slice_scatter, slice_scatter_cpu, slice_scatter_gpu);
 test_device!(randn, randn_cpu, randn_gpu);
 test_device!(clamp, clamp_cpu, clamp_gpu);