Mirror of https://github.com/huggingface/candle.git, synced 2025-06-15 10:26:33 +00:00
Add ReduceMean onnx operation (#2049)
* Add ReduceMean onnx operation
* Format code with rustfmt
@@ -2,7 +2,7 @@ use crate::onnx;
 use crate::onnx::attribute_proto::AttributeType;
 use crate::onnx::tensor_proto::DataType;
 use candle::{bail, DType, Device, Result, Tensor};
-use std::collections::HashMap;
+use std::{collections::HashMap, usize};
 
 pub type Value = Tensor;
 
@@ -797,6 +797,29 @@ pub fn simple_eval(
                 let input = get(&node.input[0])?;
                 values.insert(node.output[0].clone(), input.clone());
             }
+            // https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-13
+            // TODO: This version is only compatible with ReduceMean V13 and below.
+            "ReduceMean" => {
+                let input = get(&node.input[0])?;
+                let axes = get_attr_opt::<[i64]>(node, "axes")?;
+                let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1);
+
+                let n_dims = input.dims().len();
+
+                let axes: Vec<usize> = if let Some(axes) = axes {
+                    axes.iter()
+                        .map(|e| (if e < &0 { (n_dims as i64) + *e } else { *e }) as usize)
+                        .collect()
+                } else {
+                    (0..n_dims).collect()
+                };
+                let output = if keepdims == 1 {
+                    input.mean_keepdim(axes)?
+                } else {
+                    input.mean(axes)?
+                };
+                values.insert(node.output[0].clone(), output);
+            }
             op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
         }
     }
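As a quick illustration of what the new handler computes, here is a minimal standalone sketch (not part of the commit) that reproduces the same semantics directly with candle's mean / mean_keepdim. The normalize_axes helper and the sample values are illustrative only; the crate name follows the `use candle::...` convention of the file above.

use candle::{Device, Result, Tensor};

// Hypothetical helper mirroring the ONNX convention used in the handler:
// negative axes count from the end, e.g. -2 on a rank-3 tensor means axis 1.
fn normalize_axes(axes: &[i64], n_dims: usize) -> Vec<usize> {
    axes.iter()
        .map(|&a| if a < 0 { (n_dims as i64 + a) as usize } else { a as usize })
        .collect()
}

fn main() -> Result<()> {
    // Same sample data as the ONNX ReduceMean examples in the tests below.
    let x = Tensor::new(
        &[
            [[5f64, 1.], [20., 2.]],
            [[30., 1.], [40., 2.]],
            [[55., 1.], [60., 2.]],
        ],
        &Device::Cpu,
    )?;

    // axes = [-2], keepdims = 1  ->  reduced dims kept with size 1, shape (3, 1, 2)
    let axes = normalize_axes(&[-2], x.dims().len());
    let kept = x.mean_keepdim(axes.clone())?;
    assert_eq!(kept.dims(), &[3, 1, 2]);

    // Same axes, keepdims = 0  ->  reduced dims squeezed away, shape (3, 2)
    let dropped = x.mean(axes)?;
    assert_eq!(
        dropped.to_vec2::<f64>()?,
        vec![vec![12.5, 1.5], vec![35.0, 1.5], vec![57.5, 1.5]]
    );
    Ok(())
}

Negative axes are folded into the 0..n_dims range exactly as in the .map(...) closure above; keepdims only decides whether the reduced dimensions are kept with size 1 or dropped.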
@@ -1462,3 +1462,180 @@ fn test_relu_operation() -> Result<()> {
 
 // "Cast"
 // #[test]
+
+// "ReduceMean"
+#[test]
+fn test_reduce_mean() -> Result<()> {
+    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-120 default_axes_keepdims
+    test(
+        &[
+            [[5., 1.], [20., 2.]],
+            [[30., 1.], [40., 2.]],
+            [[55., 1.], [60., 2.]],
+        ],
+        None,
+        1,
+        &[[[18.25]]],
+    )?;
+
+    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-120 do_no_keepdims
+    test(
+        &[
+            [[5., 1.], [20., 2.]],
+            [[30., 1.], [40., 2.]],
+            [[55., 1.], [60., 2.]],
+        ],
+        Some(vec![1]),
+        0,
+        &[[12.5, 1.5], [35.0, 1.5], [57.5, 1.5]],
+    )?;
+
+    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-120 keepdims
+    test(
+        &[
+            [[5., 1.], [20., 2.]],
+            [[30., 1.], [40., 2.]],
+            [[55., 1.], [60., 2.]],
+        ],
+        Some(vec![1]),
+        1,
+        &[[[12.5, 1.5]], [[35.0, 1.5]], [[57.5, 1.5]]],
+    )?;
+
+    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-120 negative_axes_keepdims
+    test(
+        &[
+            [[5., 1.], [20., 2.]],
+            [[30., 1.], [40., 2.]],
+            [[55., 1.], [60., 2.]],
+        ],
+        Some(vec![-2]),
+        1,
+        &[[[12.5, 1.5]], [[35.0, 1.5]], [[57.5, 1.5]]],
+    )?;
+
+    // All the test data below was generated based on numpy's np.mean
+    test(
+        &[
+            [[5., 1.], [20., 2.]],
+            [[30., 1.], [40., 2.]],
+            [[55., 1.], [60., 2.]],
+        ],
+        Some(vec![1, 2]),
+        0,
+        &[7.0, 18.25, 29.5],
+    )?;
+
+    test(
+        &[
+            [[5., 1.], [20., 2.]],
+            [[30., 1.], [40., 2.]],
+            [[55., 1.], [60., 2.]],
+        ],
+        Some(vec![1, 2]),
+        1,
+        &[[[7.0]], [[18.25]], [[29.5]]],
+    )?;
+
+    test(&[1., 2., 3.], None, 1, &[2.0])?;
+
+    fn test(
+        data: impl NdArray,
+        axes: Option<Vec<i64>>,
+        keepdims: i64,
+        expected: impl NdArray,
+    ) -> Result<()> {
+        let has_axes = axes.is_some();
+
+        let att_axes = AttributeProto {
+            name: "axes".to_string(),
+            ref_attr_name: "axes".to_string(),
+            i: 0,
+            doc_string: "axes".to_string(),
+            r#type: 7,
+            f: 0.0,
+            s: vec![],
+            t: None,
+            g: None,
+            sparse_tensor: None,
+            tp: None,
+            floats: vec![],
+            ints: axes.unwrap_or_default(),
+            strings: vec![],
+            tensors: vec![],
+            graphs: vec![],
+            sparse_tensors: vec![],
+            type_protos: vec![],
+        };
+
+        let att_keepdims = AttributeProto {
+            name: "keepdims".to_string(),
+            ref_attr_name: "keepdims".to_string(),
+            i: keepdims,
+            doc_string: "keepdims".to_string(),
+            r#type: 2,
+            f: 0.0,
+            s: vec![],
+            t: None,
+            g: None,
+            sparse_tensor: None,
+            tp: None,
+            floats: vec![],
+            ints: vec![],
+            strings: vec![],
+            tensors: vec![],
+            graphs: vec![],
+            sparse_tensors: vec![],
+            type_protos: vec![],
+        };
+
+        let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+            node: vec![NodeProto {
+                op_type: "ReduceMean".to_string(),
+                domain: "".to_string(),
+                attribute: if has_axes {
+                    vec![att_axes, att_keepdims]
+                } else {
+                    vec![att_keepdims]
+                },
+                input: vec![INPUT_X.to_string()],
+                output: vec![OUTPUT_Z.to_string()],
+                name: "".to_string(),
+                doc_string: "".to_string(),
+            }],
+            name: "".to_string(),
+            initializer: vec![],
+            input: vec![],
+            output: vec![ValueInfoProto {
+                name: OUTPUT_Z.to_string(),
+                doc_string: "".to_string(),
+                r#type: None,
+            }],
+            value_info: vec![],
+            doc_string: "".to_string(),
+            sparse_initializer: vec![],
+            quantization_annotation: vec![],
+        }));
+
+        let mut inputs: HashMap<String, Tensor> = HashMap::new();
+        inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
+
+        let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+        assert_eq!(eval.len(), 1);
+
+        let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+
+        let expected = Tensor::new(expected, &Device::Cpu)?;
+        match expected.dims().len() {
+            0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
+            1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
+            2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
+            3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
+            _ => unreachable!(),
+        };
+
+        Ok(())
+    }
+
+    Ok(())
+}
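Two quick cross-checks of the hand-written expectations above (a standalone sketch, not part of the commit; plain Rust arithmetic standing in for numpy's np.mean):

// Recompute two of the expected values from the test data by hand.
fn main() {
    let x = [
        [[5.0_f64, 1.0], [20.0, 2.0]],
        [[30.0, 1.0], [40.0, 2.0]],
        [[55.0, 1.0], [60.0, 2.0]],
    ];

    // Default axes (reduce everything), keepdims = 1: mean of all 12 elements.
    let total: f64 = x.iter().flatten().flatten().sum();
    assert_eq!(total / 12.0, 18.25); // matches &[[[18.25]]]

    // axes = [1, 2], keepdims = 0: one mean per outer slice of 4 elements.
    let per_slice: Vec<f64> = x
        .iter()
        .map(|m| m.iter().flatten().sum::<f64>() / 4.0)
        .collect();
    assert_eq!(per_slice, vec![7.0, 18.25, 29.5]); // matches &[7.0, 18.25, 29.5]
}

For reference, the raw r#type values in the AttributeProto literals are the ONNX AttributeType codes: 7 is INTS (the axes list) and 2 is INT (the keepdims flag).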