mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 20:09:50 +00:00
Add slice-scatter. (#927)
* Add slice-scatter. * Add the op. * Make transpose be a no-op when the dimensions are identical. * Add the backprop. * And add some gradient test.
This commit is contained in:
@ -218,6 +218,22 @@ fn binary_grad(device: &Device) -> Result<()> {
|
||||
let grad_x = grads.get(x).context("no grad for x")?;
|
||||
assert_eq!(y.to_vec1::<f32>()?, [3., 1., -4., -1.]);
|
||||
assert_eq!(grad_x.to_vec1::<f32>()?, [1., 1., 1., 1.]);
|
||||
|
||||
let x_var = Var::new(&[3f32, 1., -4., -1., 5., 9.], device)?;
|
||||
let x = x_var.as_tensor();
|
||||
let y_var = Var::new(&[2f32, 7., 1.], device)?;
|
||||
let y = y_var.as_tensor();
|
||||
|
||||
let ss = x
|
||||
.reshape((2, 3))?
|
||||
.slice_scatter0(&y.reshape((1, 3))?, 1)?
|
||||
.sqr()?;
|
||||
let grads = ss.backward()?;
|
||||
let grad_x = grads.get(x).context("no grad for x")?;
|
||||
let grad_y = grads.get(y).context("no grad for y")?;
|
||||
assert_eq!(ss.to_vec2::<f32>()?, [[9., 1., 16.], [4., 49., 1.]]);
|
||||
assert_eq!(grad_x.to_vec1::<f32>()?, [6.0, 2.0, -8.0, 0.0, 0.0, 0.0]);
|
||||
assert_eq!(grad_y.to_vec1::<f32>()?, [4.0, 14.0, 2.0]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user