mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
onnx: ReduceMin/Max Ops (#2563)
* Stella_en_1.5B_v5 * Separated creation. This is a critical step for numerical accuracy and would be documented in the readme * EmbedDim would require clone and copy * WIP: example * Examples added * a litte more in README * WIP: ONNX Reduce-max ops * WIP: tests for ReduceMin * Reduce min/ max v18+ * Reformatting tests for better review readability * Error on empty set, backward compatibility (13 and below) with 'axes'
This commit is contained in:

committed by
GitHub

parent
3d1dc06cdb
commit
a01aa89799
@ -2,7 +2,7 @@ use crate::onnx::attribute_proto::AttributeType;
|
||||
use crate::onnx::tensor_proto::DataType;
|
||||
use crate::onnx::{self, GraphProto};
|
||||
use candle::{bail, DType, Device, Result, Tensor};
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
pub type Value = Tensor;
|
||||
|
||||
@ -1189,6 +1189,92 @@ fn simple_eval_(
|
||||
}
|
||||
values.insert(node.output[0].clone(), out);
|
||||
}
|
||||
// https://onnx.ai/onnx/operators/onnx__ReduceMax.html#reducemax
|
||||
"ReduceMax" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let axes = get_opt(1);
|
||||
let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1) == 1;
|
||||
|
||||
let axes = if let Some(Ok(axes)) = axes {
|
||||
// Satisfies version 18+
|
||||
axes.to_vec1::<i64>().ok()
|
||||
} else if let Ok(Some(axes)) = get_attr_opt::<[i64]>(node, "axes") {
|
||||
// Backward compatiblity with version 13 and below
|
||||
Some(axes.to_vec())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let axes = if let Some(axes) = axes {
|
||||
let rank = input.rank();
|
||||
let mut axes_set = HashSet::new();
|
||||
|
||||
let mut axes = axes
|
||||
.iter()
|
||||
.map(|a| {
|
||||
let axis = if *a < 0 {
|
||||
(rank as i64 + *a) as usize
|
||||
} else {
|
||||
*a as usize
|
||||
};
|
||||
|
||||
axes_set.insert(axis);
|
||||
axis
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if axes_set.len() < axes.len() {
|
||||
bail!("Duplicate value in 'axes'");
|
||||
}
|
||||
|
||||
if axes.len() > 1 {
|
||||
axes.sort();
|
||||
}
|
||||
|
||||
Some(axes)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// TODO: Handle empty set
|
||||
// Definition:
|
||||
// "Reduction over an empty set of values yields minus infinity (if supported by the datatype) or the minimum value of the data type otherwise"
|
||||
// For now, this will throw an error
|
||||
if input.elem_count() == 0 {
|
||||
bail!("reduction over zero-size tensor not supported");
|
||||
}
|
||||
|
||||
let output = if let Some(axes) = axes {
|
||||
let mut result = input.clone();
|
||||
for &axis in axes.iter().rev() {
|
||||
result = if keepdims {
|
||||
result.max_keepdim(axis)?
|
||||
} else {
|
||||
result.max(axis)?
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
} else {
|
||||
// If `axes` is empty and `noop_with_empty_axes` is set to `true (1)`
|
||||
// ""input tensor will not be reduced,and the output tensor would be equivalent to input tensor.""
|
||||
if get_attr_opt::<i64>(node, "noop_with_empty_axes")?.copied() == Some(1) {
|
||||
input.clone()
|
||||
} else {
|
||||
let mut result = input.flatten_all()?;
|
||||
if keepdims {
|
||||
result = result.max_keepdim(0)?;
|
||||
// If keepdims is true, reshape to match input dimensions
|
||||
let shape = vec![1; input.rank()];
|
||||
result.reshape(shape)?
|
||||
} else {
|
||||
result.max(0)?
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
// https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-13
|
||||
// TODO: This version is only compatible with ReduceMean V13 and below.
|
||||
"ReduceMean" => {
|
||||
@ -1212,6 +1298,92 @@ fn simple_eval_(
|
||||
};
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
// https://onnx.ai/onnx/operators/onnx__ReduceMin.html#reducemin
|
||||
"ReduceMin" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let axes = get_opt(1);
|
||||
let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1) == 1;
|
||||
|
||||
let axes = if let Some(Ok(axes)) = axes {
|
||||
// Satisfies version 18+
|
||||
axes.to_vec1::<i64>().ok()
|
||||
} else if let Ok(Some(axes)) = get_attr_opt::<[i64]>(node, "axes") {
|
||||
// Backward compatiblity with version 13 and below
|
||||
Some(axes.to_vec())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let axes = if let Some(axes) = axes {
|
||||
let rank = input.rank();
|
||||
let mut axes_set = HashSet::new();
|
||||
|
||||
let mut axes = axes
|
||||
.iter()
|
||||
.map(|a| {
|
||||
let axis = if *a < 0 {
|
||||
(rank as i64 + *a) as usize
|
||||
} else {
|
||||
*a as usize
|
||||
};
|
||||
|
||||
axes_set.insert(axis);
|
||||
axis
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if axes_set.len() < axes.len() {
|
||||
bail!("Duplicate value in 'axes'");
|
||||
}
|
||||
|
||||
if axes.len() > 1 {
|
||||
axes.sort();
|
||||
}
|
||||
|
||||
Some(axes)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// TODO: Handle empty set
|
||||
// Definition:
|
||||
// "Reduction over an empty set of values yields positive infinity (if supported by the datatype) or the max value of the data type otherwise"
|
||||
// For now, this will throw an error
|
||||
if input.elem_count() == 0 {
|
||||
bail!("reduction over zero-size tensor not supported");
|
||||
}
|
||||
|
||||
let output = if let Some(axes) = axes {
|
||||
let mut result = input.clone();
|
||||
for &axis in axes.iter().rev() {
|
||||
result = if keepdims {
|
||||
result.min_keepdim(axis)?
|
||||
} else {
|
||||
result.min(axis)?
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
} else {
|
||||
// If `axes` is empty and `noop_with_empty_axes` is set to `true (1)`
|
||||
// ""input tensor will not be reduced,and the output tensor would be equivalent to input tensor.""
|
||||
if get_attr_opt::<i64>(node, "noop_with_empty_axes")?.copied() == Some(1) {
|
||||
input.clone()
|
||||
} else {
|
||||
let mut result = input.flatten_all()?;
|
||||
if keepdims {
|
||||
result = result.min_keepdim(0)?;
|
||||
// If keepdims is true, reshape to match input dimensions
|
||||
let shape = vec![1; input.rank()];
|
||||
result.reshape(shape)?
|
||||
} else {
|
||||
result.min(0)?
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
//https://github.com/onnx/onnx/blob/main/docs/Operators.md#Split
|
||||
// Version 18 impl
|
||||
"Split" => {
|
||||
|
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user