mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Add ReduceMean onnx operation (#2049)
* Add ReduceMean onnx operation * Format code with rustfmt
This commit is contained in:
@ -2,7 +2,7 @@ use crate::onnx;
|
||||
use crate::onnx::attribute_proto::AttributeType;
|
||||
use crate::onnx::tensor_proto::DataType;
|
||||
use candle::{bail, DType, Device, Result, Tensor};
|
||||
use std::collections::HashMap;
|
||||
use std::{collections::HashMap, usize};
|
||||
|
||||
pub type Value = Tensor;
|
||||
|
||||
@ -797,6 +797,29 @@ pub fn simple_eval(
|
||||
let input = get(&node.input[0])?;
|
||||
values.insert(node.output[0].clone(), input.clone());
|
||||
}
|
||||
// https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-13
|
||||
// TODO: This version is only compatible with ReduceMean V13 and below.
|
||||
"ReduceMean" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let axes = get_attr_opt::<[i64]>(node, "axes")?;
|
||||
let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1);
|
||||
|
||||
let n_dims = input.dims().len();
|
||||
|
||||
let axes: Vec<usize> = if let Some(axes) = axes {
|
||||
axes.iter()
|
||||
.map(|e| (if e < &0 { (n_dims as i64) + *e } else { *e }) as usize)
|
||||
.collect()
|
||||
} else {
|
||||
(0..n_dims).collect()
|
||||
};
|
||||
let output = if keepdims == 1 {
|
||||
input.mean_keepdim(axes)?
|
||||
} else {
|
||||
input.mean(axes)?
|
||||
};
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user