Add ReduceMean onnx operation (#2049)

* Add ReduceMean onnx operation

* Format code with rustfmt
This commit is contained in:
Gabriel
2024-04-13 11:00:25 +02:00
committed by GitHub
parent 26cbbf8d84
commit e6d412b156
2 changed files with 201 additions and 1 deletions

View File

@ -2,7 +2,7 @@ use crate::onnx;
use crate::onnx::attribute_proto::AttributeType;
use crate::onnx::tensor_proto::DataType;
use candle::{bail, DType, Device, Result, Tensor};
use std::collections::HashMap;
use std::{collections::HashMap, usize};
pub type Value = Tensor;
@ -797,6 +797,29 @@ pub fn simple_eval(
let input = get(&node.input[0])?;
values.insert(node.output[0].clone(), input.clone());
}
// https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-13
// TODO: This version is only compatible with ReduceMean V13 and below.
"ReduceMean" => {
let input = get(&node.input[0])?;
let axes = get_attr_opt::<[i64]>(node, "axes")?;
let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1);
let n_dims = input.dims().len();
let axes: Vec<usize> = if let Some(axes) = axes {
axes.iter()
.map(|e| (if e < &0 { (n_dims as i64) + *e } else { *e }) as usize)
.collect()
} else {
(0..n_dims).collect()
};
let output = if keepdims == 1 {
input.mean_keepdim(axes)?
} else {
input.mean(axes)?
};
values.insert(node.output[0].clone(), output);
}
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
}
}