implement if, and pad reflect mode (#2251)

* implement if, and pad reflect mode

The intent of this change is to allow eval of the current silero_vad.onnx (v4).
This onnx file uses 'If' and 'Pad' nodes, which had not been supported
by simple_eval until now

* Cleanup (fmt, clippy, minor test tweaks).

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
shua
2024-06-06 22:36:23 +02:00
committed by GitHub
parent f65e90e7ef
commit b9fac7ec00
2 changed files with 271 additions and 54 deletions

View File

@ -1,6 +1,6 @@
use crate::onnx;
use crate::onnx::attribute_proto::AttributeType;
use crate::onnx::tensor_proto::DataType;
use crate::onnx::{self, GraphProto};
use candle::{bail, DType, Device, Result, Tensor};
use std::{collections::HashMap, usize};
@ -56,6 +56,15 @@ impl Attr for str {
}
}
impl Attr for GraphProto {
const TYPE: AttributeType = AttributeType::Graph;
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
attr.g
.as_ref()
.ok_or_else(|| candle::Error::Msg("attribute does not contain graph".to_string()))
}
}
impl AttrOwned for Tensor {
const TYPE: AttributeType = AttributeType::Tensor;
fn get(attr: &onnx::AttributeProto) -> Result<Self> {
@ -214,13 +223,19 @@ pub fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
// anymore.
pub fn simple_eval(
model: &onnx::ModelProto,
inputs: HashMap<String, Value>,
mut inputs: HashMap<String, Value>,
) -> Result<HashMap<String, Value>> {
let graph = match &model.graph {
None => bail!("no graph defined in proto"),
Some(graph) => graph,
};
let mut values = inputs;
simple_eval_(graph, &mut inputs)
}
fn simple_eval_(
graph: &onnx::GraphProto,
values: &mut HashMap<String, Value>,
) -> Result<HashMap<String, Value>> {
for t in graph.initializer.iter() {
let tensor = get_tensor(t, t.name.as_str())?;
values.insert(t.name.to_string(), tensor);
@ -958,6 +973,86 @@ pub fn simple_eval(
let input = get(&node.input[0])?;
values.insert(node.output[0].clone(), input.clone());
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#if
"If" => {
// protobuf encodes boolean false as 0 and true as 1
let cond = get(&node.input[0])?.get(0)?.to_scalar::<u8>()?;
let attr_name = if cond != 0 {
"then_branch"
} else {
"else_branch"
};
let sub_graph = get_attr::<GraphProto>(node, attr_name)?;
if sub_graph.output.len() != node.output.len() {
bail!(
"If node {:?} is malformed: branch outputs ({}) don't match node outputs ({})",
node.name,
sub_graph.output.len(),
node.output.len()
);
}
let branch_out = simple_eval_(sub_graph, values)?;
for (i, out) in node.output.iter().enumerate() {
values.insert(
out.clone(),
branch_out.get(&sub_graph.output[i].name).unwrap().clone(),
);
}
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#pad
"Pad" => {
let mode = get_attr_opt(node, "mode")?.unwrap_or("constant");
let data = get(&node.input[0])?;
let pads = get(&node.input[1])?;
if node.input.len() > 2 {
bail!(
"unsupported number of inputs {} for Pad node {:?}, expected 2",
node.input.len(),
node.name
);
}
if pads.rank() != 1 {
bail!("Pad expects 'pads' input to be 1D vector: {pads:?}");
}
if pads.dim(0).unwrap() != 2 * data.rank() {
bail!("Pad expects 'pads' input len to be 2 * rank of 'data' input: pads: {}, data rank: {}", pads, data.rank());
}
let pads = pads.to_vec1::<i64>()?;
let (pads_pre, pads_post) = pads.split_at(pads.len() / 2);
match mode {
"reflect" => {
let mut out = data.clone();
for (i, &dim) in data.dims().iter().enumerate().rev() {
if pads_pre[i] == 0 && pads_post[i] == 0 {
continue;
}
fn zigzag(min: i64, max: i64) -> impl Iterator<Item = i64> {
std::iter::repeat((min..max).chain((min + 1..=max).rev())).flatten()
}
let idx = if dim > 1 {
let cycle_len = dim * 2 - 1;
let skip = (pads_pre[i] as usize) % cycle_len;
let idx = zigzag(0, (dim - 1) as i64)
.skip(skip)
.take((pads_pre[i] as usize) + dim + (pads_post[i] as usize));
Tensor::from_iter(idx, out.device())?
} else {
Tensor::full(0i64, (dim,), out.device())?
};
out = out.index_select(&idx, i)?;
}
values.insert(node.output[0].clone(), out);
}
_ => bail!(
"unsupported 'mode' value {mode:?} for Pad node {:?}",
node.name
),
}
}
// https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-13
// TODO: This version is only compatible with ReduceMean V13 and below.
"ReduceMean" => {

View File

@ -4,10 +4,12 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle::test_utils::to_vec2_round;
use candle::{DType, Device, NdArray, Result, Tensor};
use candle_onnx::onnx;
use candle_onnx::onnx::attribute_proto::AttributeType;
use candle_onnx::onnx::tensor_proto::DataType;
use candle_onnx::onnx::tensor_shape_proto::{dimension, Dimension};
use candle_onnx::onnx::{type_proto, TensorProto, TensorShapeProto, TypeProto};
use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto};
use std::collections::HashMap;
@ -35,14 +37,11 @@ fn create_model_proto_with_graph(graph: Option<GraphProto>) -> ModelProto {
#[test]
fn test_evaluation_fails_without_defined_graph() -> Result<()> {
let manual_graph = create_model_proto_with_graph(None);
let inputs: HashMap<String, Tensor> = HashMap::new();
match candle_onnx::simple_eval(&manual_graph, inputs) {
Err(err) => assert_eq!(err.to_string(), "no graph defined in proto"),
Ok(_) => panic!("Expected an error due to undefined graph"),
}
Ok(())
}
@ -81,14 +80,8 @@ fn test_add_operation() -> Result<()> {
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let first = z
.to_vec1::<f64>()?
.to_vec()
.get(0)
.expect("Failed to get first element")
.clone();
let first = z.to_vec1::<f64>()?[0];
assert_eq!(first, 4.0f64);
Ok(())
}
@ -127,14 +120,8 @@ fn test_sub_operation() -> Result<()> {
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let first = z
.to_vec1::<f64>()?
.to_vec()
.get(0)
.expect("Failed to get first element")
.clone();
let first = z.to_vec1::<f64>()?[0];
assert_eq!(first, 0.0f64);
Ok(())
}
@ -173,14 +160,8 @@ fn test_mul_operation() -> Result<()> {
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let first = z
.to_vec1::<f64>()?
.to_vec()
.get(0)
.expect("Failed to get first element")
.clone();
let first = z.to_vec1::<f64>()?[0];
assert_eq!(first, 4.0f64);
Ok(())
}
@ -219,15 +200,8 @@ fn test_div_operation() -> Result<()> {
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let first = z
.to_vec1::<f64>()?
.to_vec()
.get(0)
.expect("Failed to get first element")
.clone();
let first = z.to_vec1::<f64>()?[0];
assert_eq!(first, 1.0f64);
Ok(())
}
@ -272,7 +246,7 @@ fn test_exp_operation() -> Result<()> {
assert_eq!(results[0][0], 0.36787944f32);
assert_eq!(results[0][1], 1.0f32);
assert_eq!(results[1], vec![std::f32::consts::E, 7.38905609f32]);
assert_eq!(results[1], vec![std::f32::consts::E, 7.389056f32]);
Ok(())
}
@ -914,7 +888,7 @@ fn test_constant_of_shape() -> Result<()> {
),
_ => panic!("unsupported DType in test"),
};
let tensor = onnx::TensorProto {
let tensor = TensorProto {
data_type: data_type.into(),
dims: tensor.dims().iter().map(|v| *v as i64).collect(),
raw_data: value,
@ -1293,14 +1267,7 @@ fn test_cos_operation() -> Result<()> {
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<f32>()?;
assert_eq!(
results,
vec![vec![1.0, 0.54030234], vec![-0.41614684, -0.9899925]]
);
assert_eq!(to_vec2_round(z, 4)?, [[1.0, 0.5403], [-0.4161, -0.99]]);
Ok(())
}
@ -1342,19 +1309,12 @@ fn test_sin_operation() -> Result<()> {
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<f32>()?;
assert_eq!(results, vec![vec![0.0, 0.841471], vec![0.9092974, 0.14112]]);
assert_eq!(to_vec2_round(z, 4)?, [[0.0, 0.8415], [0.9093, 0.1411]]);
Ok(())
}
@ -3150,3 +3110,165 @@ fn test_leakyrelu() -> Result<()> {
Ok(())
}
// "If"
#[test]
fn test_if() -> Result<()> {
let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let y = vec![5.0, 4.0, 3.0, 2.0, 1.0];
let output_type_proto = Some(TypeProto {
value: Some(type_proto::Value::TensorType(type_proto::Tensor {
elem_type: DataType::Float.into(),
shape: Some(TensorShapeProto {
dim: vec![Dimension {
denotation: "".to_string(),
value: Some(dimension::Value::DimValue(5)),
}],
}),
})),
denotation: "".to_string(),
});
let then_branch = GraphProto {
output: vec![ValueInfoProto {
name: "then_out".to_string(),
r#type: output_type_proto.clone(),
doc_string: "".to_string(),
}],
node: vec![NodeProto {
op_type: "Constant".to_string(),
input: vec![],
output: vec!["then_out".to_string()],
attribute: vec![AttributeProto {
name: "value".to_string(),
r#type: AttributeType::Tensor.into(),
t: Some(TensorProto {
dims: vec![x.len() as i64],
float_data: x.clone(),
data_type: DataType::Float.into(),
..TensorProto::default()
}),
..AttributeProto::default()
}],
..NodeProto::default()
}],
..GraphProto::default()
};
let else_branch = GraphProto {
output: vec![ValueInfoProto {
name: "else_out".to_string(),
r#type: output_type_proto.clone(),
doc_string: "".to_string(),
}],
node: vec![NodeProto {
op_type: "Constant".to_string(),
input: vec![],
output: vec!["else_out".to_string()],
attribute: vec![AttributeProto {
name: "value".to_string(),
r#type: AttributeType::Tensor.into(),
t: Some(TensorProto {
dims: vec![y.len() as i64],
float_data: y.clone(),
data_type: DataType::Float.into(),
..TensorProto::default()
}),
..AttributeProto::default()
}],
..NodeProto::default()
}],
..GraphProto::default()
};
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "If".to_string(),
attribute: vec![
AttributeProto {
name: "then_branch".to_string(),
r#type: AttributeType::Graph.into(),
g: Some(then_branch),
..AttributeProto::default()
},
AttributeProto {
name: "else_branch".to_string(),
r#type: AttributeType::Graph.into(),
g: Some(else_branch),
..AttributeProto::default()
},
],
input: vec!["cond".to_string()],
output: vec!["res".to_string()],
..NodeProto::default()
}],
input: vec![],
output: vec![ValueInfoProto {
name: "res".to_string(),
doc_string: "".to_string(),
r#type: output_type_proto.clone(),
}],
..GraphProto::default()
}));
for cond in [1u8, 0] {
let inputs =
HashMap::from_iter([("cond".to_string(), Tensor::full(cond, (1,), &Device::Cpu)?)]);
let outputs = candle_onnx::simple_eval(&manual_graph, inputs)?;
let expected = if cond != 0 { &x } else { &y };
let Some(res) = outputs.get("res") else {
candle::bail!("outputs didn't contain expected key `res`: {outputs:?}");
};
assert_eq!(&res.to_vec1::<f32>()?, expected);
}
Ok(())
}
#[test]
fn test_pad() -> Result<()> {
let data = Tensor::from_vec(vec![1.0, 1.2, 2.3, 3.4, 4.5, 5.7], (3, 2), &Device::Cpu)?;
let pads = Tensor::from_vec(vec![0i64, 2, 0, 0], (4,), &Device::Cpu)?;
let mode = "reflect";
let expected = Tensor::from_vec(
vec![1.0, 1.2, 1.0, 1.2, 2.3, 3.4, 2.3, 3.4, 4.5, 5.7, 4.5, 5.7],
(3, 4),
&Device::Cpu,
)?;
let model = create_model_proto_with_graph(Some(GraphProto {
input: vec![
ValueInfoProto {
name: "data".to_string(),
..ValueInfoProto::default()
},
ValueInfoProto {
name: "pads".to_string(),
..ValueInfoProto::default()
},
],
output: vec![ValueInfoProto {
name: "output".to_string(),
..ValueInfoProto::default()
}],
node: vec![NodeProto {
op_type: "Pad".to_string(),
input: vec!["data".to_string(), "pads".to_string()],
output: vec!["output".to_string()],
attribute: vec![AttributeProto {
name: "mode".to_string(),
r#type: AttributeType::String.into(),
s: mode.as_bytes().to_vec(),
..AttributeProto::default()
}],
..NodeProto::default()
}],
..GraphProto::default()
}));
let inputs = HashMap::from_iter([("data".to_string(), data), ("pads".to_string(), pads)]);
let res = candle_onnx::simple_eval(&model, inputs)?;
let Some(actual) = res.get("output") else {
candle::bail!("outputs didn't contain expected key `output`: {res:?}");
};
assert_eq!(actual.to_vec2::<f64>()?, expected.to_vec2::<f64>()?);
Ok(())
}