mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Add more gradient tests + bugfixes. (#211)
* Add more gradient tests + bugfixes. * More tests and fixes. * More tests.
This commit is contained in:
@ -146,7 +146,7 @@ impl Tensor {
|
||||
*lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
|
||||
let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
|
||||
let rhs_sum_grad = grads.or_insert(rhs)?;
|
||||
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
|
||||
*rhs_sum_grad = rhs_sum_grad.sub(&rhs_grad)?;
|
||||
}
|
||||
Op::WhereCond(pred, t, f) => {
|
||||
let zeros = grad.zeros_like()?;
|
||||
@ -162,6 +162,7 @@ impl Tensor {
|
||||
let dim = *dim;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
// TODO: This is very very very inefficient, have some dedicated kernel for this.
|
||||
// https://pytorch.org/docs/stable/generated/torch.Tensor.index_add.html
|
||||
let indexes = indexes.to_vec1::<u32>()?;
|
||||
for (dst_index, src_index) in indexes.iter().enumerate() {
|
||||
let src_index = *src_index as usize;
|
||||
@ -318,7 +319,7 @@ impl Tensor {
|
||||
*sum_grad = sum_grad.add(&arg_grad)?
|
||||
}
|
||||
Op::Unary(arg, UnaryOp::Sqrt) => {
|
||||
let arg_grad = grad.div(arg)?.affine(0.5, 0.)?;
|
||||
let arg_grad = grad.div(node)?.affine(0.5, 0.)?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
*sum_grad = sum_grad.add(&arg_grad)?
|
||||
}
|
||||
|
Reference in New Issue
Block a user