Add a yolo-v3 example. (#528)

* Add a couple functions required for yolo.

* Add the yolo-v3 example.

* Add minimum and maximum.

* Use the newly introduced maximum.

* Cuda support for min/max + add some testing.

* Allow for more tests to work with accelerate.

* Fix a typo.
This commit is contained in:
Laurent Mazare
2023-08-20 18:19:37 +01:00
committed by GitHub
parent e3d2786ffb
commit a1812f934f
24 changed files with 1497 additions and 8 deletions

View File

@ -123,6 +123,42 @@ mod ffi {
_: c_long,
_: c_ulong,
);
pub fn vDSP_vminD(
_: *const c_double,
_: c_long,
_: *const c_double,
_: c_long,
_: *mut c_double,
_: c_long,
_: c_ulong,
);
pub fn vDSP_vmin(
_: *const c_float,
_: c_long,
_: *const c_float,
_: c_long,
_: *mut c_float,
_: c_long,
_: c_ulong,
);
pub fn vDSP_vmaxD(
_: *const c_double,
_: c_long,
_: *const c_double,
_: c_long,
_: *mut c_double,
_: c_long,
_: c_ulong,
);
pub fn vDSP_vmax(
_: *const c_float,
_: c_long,
_: *const c_float,
_: c_long,
_: *mut c_float,
_: c_long,
_: c_ulong,
);
}
}
@ -348,3 +384,7 @@ binary_op!(vs_mul, f32, vDSP_vmul);
binary_op!(vd_mul, f64, vDSP_vmulD);
binary_op!(vs_div, f32, vDSP_vdiv);
binary_op!(vd_div, f64, vDSP_vdivD);
binary_op!(vs_max, f32, vDSP_vmax);
binary_op!(vd_max, f64, vDSP_vmaxD);
binary_op!(vs_min, f32, vDSP_vmin);
binary_op!(vd_min, f64, vDSP_vminD);

View File

@ -162,6 +162,16 @@ impl Tensor {
let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.sub(&rhs_grad)?;
}
Op::Binary(lhs, rhs, BinaryOp::Minimum)
| Op::Binary(lhs, rhs, BinaryOp::Maximum) => {
let lhs_grad = node.eq(lhs)?.to_dtype(grad.dtype())?.mul(&grad)?;
let lhs_sum_grad = grads.or_insert(lhs)?;
*lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
let rhs_grad = node.eq(rhs)?.to_dtype(grad.dtype())?.mul(&grad)?;
let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.sub(&rhs_grad)?;
}
Op::WhereCond(pred, t, f) => {
let zeros = grad.zeros_like()?;
let t_sum_grad = grads.or_insert(t)?;

View File

@ -25,6 +25,10 @@ mod ffi {
pub fn vdMul(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
pub fn vsDiv(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
pub fn vdDiv(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
pub fn vsFmax(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
pub fn vdFmax(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
pub fn vsFmin(n: c_int, a: *const c_float, b: *const c_float, y: *mut c_float);
pub fn vdFmin(n: c_int, a: *const c_double, b: *const c_double, y: *mut c_double);
pub fn sgemm_(
transa: *const c_char,
@ -376,3 +380,7 @@ binary_op!(vs_mul, f32, vsMul);
binary_op!(vd_mul, f64, vdMul);
binary_op!(vs_div, f32, vsDiv);
binary_op!(vd_div, f64, vdDiv);
binary_op!(vs_max, f32, vsFmax);
binary_op!(vd_max, f64, vdFmax);
binary_op!(vs_min, f32, vsFmin);
binary_op!(vd_min, f64, vdFmin);

View File

@ -40,6 +40,8 @@ pub enum BinaryOp {
Mul,
Sub,
Div,
Maximum,
Minimum,
}
// Unary ops with no argument
@ -291,6 +293,8 @@ pub(crate) struct Add;
pub(crate) struct Div;
pub(crate) struct Mul;
pub(crate) struct Sub;
pub(crate) struct Maximum;
pub(crate) struct Minimum;
pub(crate) struct Exp;
pub(crate) struct Log;
pub(crate) struct Sin;
@ -371,6 +375,20 @@ bin_op!(Add, "add", |v1, v2| v1 + v2, vs_add, vd_add);
bin_op!(Sub, "sub", |v1, v2| v1 - v2, vs_sub, vd_sub);
bin_op!(Mul, "mul", |v1, v2| v1 * v2, vs_mul, vd_mul);
bin_op!(Div, "div", |v1, v2| v1 / v2, vs_div, vd_div);
bin_op!(
Minimum,
"minimum",
|v1, v2| if v1 > v2 { v2 } else { v1 },
vs_min,
vd_min
);
bin_op!(
Maximum,
"maximum",
|v1, v2| if v1 < v2 { v2 } else { v1 },
vs_max,
vd_max
);
macro_rules! unary_op {
($op: ident, $name: literal, $a: ident, $e: expr) => {

View File

@ -444,10 +444,14 @@ impl Tensor {
binary_op!(mul, Mul);
binary_op!(sub, Sub);
binary_op!(div, Div);
binary_op!(maximum, Maximum);
binary_op!(minimum, Minimum);
broadcast_binary_op!(broadcast_add, add);
broadcast_binary_op!(broadcast_mul, mul);
broadcast_binary_op!(broadcast_sub, sub);
broadcast_binary_op!(broadcast_div, div);
broadcast_binary_op!(broadcast_maximum, maximum);
broadcast_binary_op!(broadcast_minimum, minimum);
unary_op!(recip, Recip);
unary_op!(neg, Neg);