[ONNX] Support a couple more ops. (#1284)

* Support the shape op in ONNX.

* Share the axis normalization bits.

* Add some limited support for gather.

* Unsqueeze.

* Comparison with broadcasting.

* Add Not + handle i32.
This commit is contained in:
Laurent Mazare
2023-11-06 22:44:58 +01:00
committed by GitHub
parent 5a363dbc26
commit a773a4b22b
4 changed files with 137 additions and 27 deletions

View File

@ -477,6 +477,12 @@ impl Tensor {
broadcast_binary_op!(broadcast_div, div);
broadcast_binary_op!(broadcast_maximum, maximum);
broadcast_binary_op!(broadcast_minimum, minimum);
broadcast_binary_op!(broadcast_eq, eq);
broadcast_binary_op!(broadcast_ne, ne);
broadcast_binary_op!(broadcast_lt, lt);
broadcast_binary_op!(broadcast_le, le);
broadcast_binary_op!(broadcast_gt, gt);
broadcast_binary_op!(broadcast_ge, ge);
unary_op!(recip, Recip);
unary_op!(neg, Neg);
@ -2406,6 +2412,23 @@ impl Tensor {
) -> Result<Self> {
self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
}
/// Normalize a 'relative' axis value: positive values are kept, negative
/// values means counting the dimensions from the back.
pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
let rank = self.rank() as i64;
if rank <= axis {
crate::bail!("axis {axis} is too large, tensor rank {rank}")
} else if 0 <= axis {
Ok(axis as usize)
} else {
let naxis = rank + axis;
if naxis < 0 {
crate::bail!("axis {axis} is too small, tensor rank {rank}")
}
Ok(naxis as usize)
}
}
}
macro_rules! bin_trait {