Tensor mutability (#154)

* Working towards tensor mutability.

* Use a ref-cell to provide tensor mutability.
This commit is contained in:
Laurent Mazare
2023-07-13 11:04:40 +01:00
committed by GitHub
parent a3663ce2f2
commit 50b0946a2d
14 changed files with 124 additions and 88 deletions

View File

@ -222,7 +222,7 @@ impl Tensor {
} else {
let mut dims = arg_dims.to_vec();
dims[dim] = start_idx;
Some(Tensor::zeros(dims, grad.dtype(), &grad.device())?)
Some(Tensor::zeros(dims, grad.dtype(), grad.device())?)
};
let right_pad = arg_dims[dim] - start_idx - len;
let right_pad = if right_pad == 0 {
@ -230,7 +230,7 @@ impl Tensor {
} else {
let mut dims = arg_dims.to_vec();
dims[dim] = right_pad;
Some(Tensor::zeros(dims, grad.dtype(), &grad.device())?)
Some(Tensor::zeros(dims, grad.dtype(), grad.device())?)
};
let arg_grad = match (left_pad, right_pad) {
(None, None) => grad,
@ -264,7 +264,7 @@ impl Tensor {
}
Op::ToDevice(arg) => {
let sum_grad = grads.or_insert(arg)?;
let arg_grad = grad.to_device(&sum_grad.device())?;
let arg_grad = grad.to_device(sum_grad.device())?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Transpose(arg, dim1, dim2) => {