mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
implement if, and pad reflect mode (#2251)
* implement if, and pad reflect mode The intent of this change is to allow eval of the current silero_vad.onnx (v4). This onnx file uses 'If' and 'Pad' nodes, which had not been supported by simple_eval until now * Cleanup (fmt, clippy, minor test tweaks). --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -1,6 +1,6 @@
|
||||
use crate::onnx;
|
||||
use crate::onnx::attribute_proto::AttributeType;
|
||||
use crate::onnx::tensor_proto::DataType;
|
||||
use crate::onnx::{self, GraphProto};
|
||||
use candle::{bail, DType, Device, Result, Tensor};
|
||||
use std::{collections::HashMap, usize};
|
||||
|
||||
@ -56,6 +56,15 @@ impl Attr for str {
|
||||
}
|
||||
}
|
||||
|
||||
impl Attr for GraphProto {
|
||||
const TYPE: AttributeType = AttributeType::Graph;
|
||||
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
|
||||
attr.g
|
||||
.as_ref()
|
||||
.ok_or_else(|| candle::Error::Msg("attribute does not contain graph".to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
impl AttrOwned for Tensor {
|
||||
const TYPE: AttributeType = AttributeType::Tensor;
|
||||
fn get(attr: &onnx::AttributeProto) -> Result<Self> {
|
||||
@ -214,13 +223,19 @@ pub fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
|
||||
// anymore.
|
||||
pub fn simple_eval(
|
||||
model: &onnx::ModelProto,
|
||||
inputs: HashMap<String, Value>,
|
||||
mut inputs: HashMap<String, Value>,
|
||||
) -> Result<HashMap<String, Value>> {
|
||||
let graph = match &model.graph {
|
||||
None => bail!("no graph defined in proto"),
|
||||
Some(graph) => graph,
|
||||
};
|
||||
let mut values = inputs;
|
||||
simple_eval_(graph, &mut inputs)
|
||||
}
|
||||
|
||||
fn simple_eval_(
|
||||
graph: &onnx::GraphProto,
|
||||
values: &mut HashMap<String, Value>,
|
||||
) -> Result<HashMap<String, Value>> {
|
||||
for t in graph.initializer.iter() {
|
||||
let tensor = get_tensor(t, t.name.as_str())?;
|
||||
values.insert(t.name.to_string(), tensor);
|
||||
@ -958,6 +973,86 @@ pub fn simple_eval(
|
||||
let input = get(&node.input[0])?;
|
||||
values.insert(node.output[0].clone(), input.clone());
|
||||
}
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#if
|
||||
"If" => {
|
||||
// protobuf encodes boolean false as 0 and true as 1
|
||||
let cond = get(&node.input[0])?.get(0)?.to_scalar::<u8>()?;
|
||||
let attr_name = if cond != 0 {
|
||||
"then_branch"
|
||||
} else {
|
||||
"else_branch"
|
||||
};
|
||||
let sub_graph = get_attr::<GraphProto>(node, attr_name)?;
|
||||
if sub_graph.output.len() != node.output.len() {
|
||||
bail!(
|
||||
"If node {:?} is malformed: branch outputs ({}) don't match node outputs ({})",
|
||||
node.name,
|
||||
sub_graph.output.len(),
|
||||
node.output.len()
|
||||
);
|
||||
}
|
||||
let branch_out = simple_eval_(sub_graph, values)?;
|
||||
for (i, out) in node.output.iter().enumerate() {
|
||||
values.insert(
|
||||
out.clone(),
|
||||
branch_out.get(&sub_graph.output[i].name).unwrap().clone(),
|
||||
);
|
||||
}
|
||||
}
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#pad
|
||||
"Pad" => {
|
||||
let mode = get_attr_opt(node, "mode")?.unwrap_or("constant");
|
||||
let data = get(&node.input[0])?;
|
||||
let pads = get(&node.input[1])?;
|
||||
if node.input.len() > 2 {
|
||||
bail!(
|
||||
"unsupported number of inputs {} for Pad node {:?}, expected 2",
|
||||
node.input.len(),
|
||||
node.name
|
||||
);
|
||||
}
|
||||
if pads.rank() != 1 {
|
||||
bail!("Pad expects 'pads' input to be 1D vector: {pads:?}");
|
||||
}
|
||||
if pads.dim(0).unwrap() != 2 * data.rank() {
|
||||
bail!("Pad expects 'pads' input len to be 2 * rank of 'data' input: pads: {}, data rank: {}", pads, data.rank());
|
||||
}
|
||||
|
||||
let pads = pads.to_vec1::<i64>()?;
|
||||
let (pads_pre, pads_post) = pads.split_at(pads.len() / 2);
|
||||
|
||||
match mode {
|
||||
"reflect" => {
|
||||
let mut out = data.clone();
|
||||
for (i, &dim) in data.dims().iter().enumerate().rev() {
|
||||
if pads_pre[i] == 0 && pads_post[i] == 0 {
|
||||
continue;
|
||||
}
|
||||
fn zigzag(min: i64, max: i64) -> impl Iterator<Item = i64> {
|
||||
std::iter::repeat((min..max).chain((min + 1..=max).rev())).flatten()
|
||||
}
|
||||
let idx = if dim > 1 {
|
||||
let cycle_len = dim * 2 - 1;
|
||||
let skip = (pads_pre[i] as usize) % cycle_len;
|
||||
let idx = zigzag(0, (dim - 1) as i64)
|
||||
.skip(skip)
|
||||
.take((pads_pre[i] as usize) + dim + (pads_post[i] as usize));
|
||||
Tensor::from_iter(idx, out.device())?
|
||||
} else {
|
||||
Tensor::full(0i64, (dim,), out.device())?
|
||||
};
|
||||
|
||||
out = out.index_select(&idx, i)?;
|
||||
}
|
||||
|
||||
values.insert(node.output[0].clone(), out);
|
||||
}
|
||||
_ => bail!(
|
||||
"unsupported 'mode' value {mode:?} for Pad node {:?}",
|
||||
node.name
|
||||
),
|
||||
}
|
||||
}
|
||||
// https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-13
|
||||
// TODO: This version is only compatible with ReduceMean V13 and below.
|
||||
"ReduceMean" => {
|
||||
|
Reference in New Issue
Block a user