mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Merge pull request #26 from LaurentMazare/narrow-grad
Add the grad for narrow.
This commit is contained in:
@ -208,8 +208,31 @@ impl Tensor {
|
|||||||
let sum_grad = grads.or_insert(arg)?;
|
let sum_grad = grads.or_insert(arg)?;
|
||||||
*sum_grad = sum_grad.sub(&grad)?
|
*sum_grad = sum_grad.sub(&grad)?
|
||||||
}
|
}
|
||||||
Op::Narrow(_arg, _, _, _) => {
|
&Op::Narrow(ref arg, dim, start_idx, len) => {
|
||||||
return Err(Error::BackwardNotSupported { op: "narrow" })
|
let arg_dims = arg.dims();
|
||||||
|
let left_pad = if start_idx == 0 {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
let mut dims = arg_dims.to_vec();
|
||||||
|
dims[dim] = start_idx;
|
||||||
|
Some(Tensor::zeros(dims, grad.dtype(), &grad.device())?)
|
||||||
|
};
|
||||||
|
let right_pad = arg_dims[dim] - start_idx - len;
|
||||||
|
let right_pad = if right_pad == 0 {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
let mut dims = arg_dims.to_vec();
|
||||||
|
dims[dim] = right_pad;
|
||||||
|
Some(Tensor::zeros(dims, grad.dtype(), &grad.device())?)
|
||||||
|
};
|
||||||
|
let arg_grad = match (left_pad, right_pad) {
|
||||||
|
(None, None) => grad,
|
||||||
|
(Some(l), None) => Tensor::cat(&[&l, &grad], dim)?,
|
||||||
|
(None, Some(r)) => Tensor::cat(&[&grad, &r], dim)?,
|
||||||
|
(Some(l), Some(r)) => Tensor::cat(&[&l, &grad, &r], dim)?,
|
||||||
|
};
|
||||||
|
let sum_grad = grads.or_insert(arg)?;
|
||||||
|
*sum_grad = sum_grad.add(&arg_grad)?
|
||||||
}
|
}
|
||||||
Op::Softmax(_arg, _) => {
|
Op::Softmax(_arg, _) => {
|
||||||
return Err(Error::BackwardNotSupported { op: "softmax" })
|
return Err(Error::BackwardNotSupported { op: "softmax" })
|
||||||
|
Reference in New Issue
Block a user