mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Tensor mutability (#154)
* Working towards tensor mutability. * Use a ref-cell to provide tensor mutability.
This commit is contained in:
@ -222,7 +222,7 @@ impl Tensor {
|
||||
} else {
|
||||
let mut dims = arg_dims.to_vec();
|
||||
dims[dim] = start_idx;
|
||||
Some(Tensor::zeros(dims, grad.dtype(), &grad.device())?)
|
||||
Some(Tensor::zeros(dims, grad.dtype(), grad.device())?)
|
||||
};
|
||||
let right_pad = arg_dims[dim] - start_idx - len;
|
||||
let right_pad = if right_pad == 0 {
|
||||
@ -230,7 +230,7 @@ impl Tensor {
|
||||
} else {
|
||||
let mut dims = arg_dims.to_vec();
|
||||
dims[dim] = right_pad;
|
||||
Some(Tensor::zeros(dims, grad.dtype(), &grad.device())?)
|
||||
Some(Tensor::zeros(dims, grad.dtype(), grad.device())?)
|
||||
};
|
||||
let arg_grad = match (left_pad, right_pad) {
|
||||
(None, None) => grad,
|
||||
@ -264,7 +264,7 @@ impl Tensor {
|
||||
}
|
||||
Op::ToDevice(arg) => {
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
let arg_grad = grad.to_device(&sum_grad.device())?;
|
||||
let arg_grad = grad.to_device(sum_grad.device())?;
|
||||
*sum_grad = sum_grad.add(&arg_grad)?
|
||||
}
|
||||
Op::Transpose(arg, dim1, dim2) => {
|
||||
|
Reference in New Issue
Block a user