onnx: ReduceMin/Max Ops (#2563)

* Stella_en_1.5B_v5

* Separated  creation. This is a critical step for numerical accuracy and would be documented in the readme

* EmbedDim would require clone and copy

* WIP: example

* Examples added

* a litte more in README

* WIP: ONNX Reduce-max ops

* WIP: tests for ReduceMin

* Reduce min/ max v18+

* Reformatting tests for better review readability

* Error on empty set, backward compatibility (13 and below) with 'axes'
This commit is contained in:
Anubhab Bandyopadhyay
2024-10-15 14:04:07 +05:30
committed by GitHub
parent 3d1dc06cdb
commit a01aa89799
2 changed files with 1211 additions and 1 deletions

View File

@ -2,7 +2,7 @@ use crate::onnx::attribute_proto::AttributeType;
use crate::onnx::tensor_proto::DataType;
use crate::onnx::{self, GraphProto};
use candle::{bail, DType, Device, Result, Tensor};
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
pub type Value = Tensor;
@ -1189,6 +1189,92 @@ fn simple_eval_(
}
values.insert(node.output[0].clone(), out);
}
// https://onnx.ai/onnx/operators/onnx__ReduceMax.html#reducemax
"ReduceMax" => {
let input = get(&node.input[0])?;
let axes = get_opt(1);
let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1) == 1;
let axes = if let Some(Ok(axes)) = axes {
// Satisfies version 18+
axes.to_vec1::<i64>().ok()
} else if let Ok(Some(axes)) = get_attr_opt::<[i64]>(node, "axes") {
// Backward compatiblity with version 13 and below
Some(axes.to_vec())
} else {
None
};
let axes = if let Some(axes) = axes {
let rank = input.rank();
let mut axes_set = HashSet::new();
let mut axes = axes
.iter()
.map(|a| {
let axis = if *a < 0 {
(rank as i64 + *a) as usize
} else {
*a as usize
};
axes_set.insert(axis);
axis
})
.collect::<Vec<_>>();
if axes_set.len() < axes.len() {
bail!("Duplicate value in 'axes'");
}
if axes.len() > 1 {
axes.sort();
}
Some(axes)
} else {
None
};
// TODO: Handle empty set
// Definition:
// "Reduction over an empty set of values yields minus infinity (if supported by the datatype) or the minimum value of the data type otherwise"
// For now, this will throw an error
if input.elem_count() == 0 {
bail!("reduction over zero-size tensor not supported");
}
let output = if let Some(axes) = axes {
let mut result = input.clone();
for &axis in axes.iter().rev() {
result = if keepdims {
result.max_keepdim(axis)?
} else {
result.max(axis)?
}
}
result
} else {
// If `axes` is empty and `noop_with_empty_axes` is set to `true (1)`
// ""input tensor will not be reduced,and the output tensor would be equivalent to input tensor.""
if get_attr_opt::<i64>(node, "noop_with_empty_axes")?.copied() == Some(1) {
input.clone()
} else {
let mut result = input.flatten_all()?;
if keepdims {
result = result.max_keepdim(0)?;
// If keepdims is true, reshape to match input dimensions
let shape = vec![1; input.rank()];
result.reshape(shape)?
} else {
result.max(0)?
}
}
};
values.insert(node.output[0].clone(), output);
}
// https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-13
// TODO: This version is only compatible with ReduceMean V13 and below.
"ReduceMean" => {
@ -1212,6 +1298,92 @@ fn simple_eval_(
};
values.insert(node.output[0].clone(), output);
}
// https://onnx.ai/onnx/operators/onnx__ReduceMin.html#reducemin
"ReduceMin" => {
let input = get(&node.input[0])?;
let axes = get_opt(1);
let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1) == 1;
let axes = if let Some(Ok(axes)) = axes {
// Satisfies version 18+
axes.to_vec1::<i64>().ok()
} else if let Ok(Some(axes)) = get_attr_opt::<[i64]>(node, "axes") {
// Backward compatiblity with version 13 and below
Some(axes.to_vec())
} else {
None
};
let axes = if let Some(axes) = axes {
let rank = input.rank();
let mut axes_set = HashSet::new();
let mut axes = axes
.iter()
.map(|a| {
let axis = if *a < 0 {
(rank as i64 + *a) as usize
} else {
*a as usize
};
axes_set.insert(axis);
axis
})
.collect::<Vec<_>>();
if axes_set.len() < axes.len() {
bail!("Duplicate value in 'axes'");
}
if axes.len() > 1 {
axes.sort();
}
Some(axes)
} else {
None
};
// TODO: Handle empty set
// Definition:
// "Reduction over an empty set of values yields positive infinity (if supported by the datatype) or the max value of the data type otherwise"
// For now, this will throw an error
if input.elem_count() == 0 {
bail!("reduction over zero-size tensor not supported");
}
let output = if let Some(axes) = axes {
let mut result = input.clone();
for &axis in axes.iter().rev() {
result = if keepdims {
result.min_keepdim(axis)?
} else {
result.min(axis)?
}
}
result
} else {
// If `axes` is empty and `noop_with_empty_axes` is set to `true (1)`
// ""input tensor will not be reduced,and the output tensor would be equivalent to input tensor.""
if get_attr_opt::<i64>(node, "noop_with_empty_axes")?.copied() == Some(1) {
input.clone()
} else {
let mut result = input.flatten_all()?;
if keepdims {
result = result.min_keepdim(0)?;
// If keepdims is true, reshape to match input dimensions
let shape = vec![1; input.rank()];
result.reshape(shape)?
} else {
result.min(0)?
}
}
};
values.insert(node.output[0].clone(), output);
}
//https://github.com/onnx/onnx/blob/main/docs/Operators.md#Split
// Version 18 impl
"Split" => {

File diff suppressed because it is too large Load Diff