mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Add a yolo-v3 example. (#528)
* Add a couple functions required for yolo. * Add the yolo-v3 example. * Add minimum and maximum. * Use the newly introduced maximum. * Cuda support for min/max + add some testing. * Allow for more tests to work with accelerate. * Fix a typo.
This commit is contained in:
@ -162,6 +162,16 @@ impl Tensor {
|
||||
let rhs_sum_grad = grads.or_insert(rhs)?;
|
||||
*rhs_sum_grad = rhs_sum_grad.sub(&rhs_grad)?;
|
||||
}
|
||||
Op::Binary(lhs, rhs, BinaryOp::Minimum)
|
||||
| Op::Binary(lhs, rhs, BinaryOp::Maximum) => {
|
||||
let lhs_grad = node.eq(lhs)?.to_dtype(grad.dtype())?.mul(&grad)?;
|
||||
let lhs_sum_grad = grads.or_insert(lhs)?;
|
||||
*lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
|
||||
|
||||
let rhs_grad = node.eq(rhs)?.to_dtype(grad.dtype())?.mul(&grad)?;
|
||||
let rhs_sum_grad = grads.or_insert(rhs)?;
|
||||
*rhs_sum_grad = rhs_sum_grad.sub(&rhs_grad)?;
|
||||
}
|
||||
Op::WhereCond(pred, t, f) => {
|
||||
let zeros = grad.zeros_like()?;
|
||||
let t_sum_grad = grads.or_insert(t)?;
|
||||
|
Reference in New Issue
Block a user