mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
[ONNX] Support a couple more ops. (#1284)
* Support the shape op in ONNX. * Share the axis normalization bits. * Add some limited support for gather. * Unsqueeze. * Comparison with broadcasting. * Add Not + handle i32.
This commit is contained in:
@ -477,6 +477,12 @@ impl Tensor {
|
||||
broadcast_binary_op!(broadcast_div, div);
|
||||
broadcast_binary_op!(broadcast_maximum, maximum);
|
||||
broadcast_binary_op!(broadcast_minimum, minimum);
|
||||
broadcast_binary_op!(broadcast_eq, eq);
|
||||
broadcast_binary_op!(broadcast_ne, ne);
|
||||
broadcast_binary_op!(broadcast_lt, lt);
|
||||
broadcast_binary_op!(broadcast_le, le);
|
||||
broadcast_binary_op!(broadcast_gt, gt);
|
||||
broadcast_binary_op!(broadcast_ge, ge);
|
||||
|
||||
unary_op!(recip, Recip);
|
||||
unary_op!(neg, Neg);
|
||||
@ -2406,6 +2412,23 @@ impl Tensor {
|
||||
) -> Result<Self> {
|
||||
self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
|
||||
}
|
||||
|
||||
/// Normalize a 'relative' axis value: positive values are kept, negative
|
||||
/// values means counting the dimensions from the back.
|
||||
pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
|
||||
let rank = self.rank() as i64;
|
||||
if rank <= axis {
|
||||
crate::bail!("axis {axis} is too large, tensor rank {rank}")
|
||||
} else if 0 <= axis {
|
||||
Ok(axis as usize)
|
||||
} else {
|
||||
let naxis = rank + axis;
|
||||
if naxis < 0 {
|
||||
crate::bail!("axis {axis} is too small, tensor rank {rank}")
|
||||
}
|
||||
Ok(naxis as usize)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! bin_trait {
|
||||
|
Reference in New Issue
Block a user