Use broadcasted scalars for const tensors.

This commit is contained in:
laurent
2023-06-29 11:56:40 +01:00
parent 3872dc4751
commit 2741b39ad3
5 changed files with 12 additions and 14 deletions

View File

@ -20,7 +20,6 @@ fn matmul_grad(device: &Device) -> Result<()> {
let x = Tensor::var_from_slice(&data, (2, 2, 3), device)?;
let data: Vec<_> = (0..12).map(|i| i as f32).collect();
let y = Tensor::var_from_slice(&data, (2, 3, 2), device)?;
let c = x.matmul(&y)?;
let grads = c.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;