mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
implement if, and pad reflect mode (#2251)
* implement if, and pad reflect mode The intent of this change is to allow eval of the current silero_vad.onnx (v4). This onnx file uses 'If' and 'Pad' nodes, which had not been supported by simple_eval until now * Cleanup (fmt, clippy, minor test tweaks). --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -1,6 +1,6 @@
|
||||
use crate::onnx;
|
||||
use crate::onnx::attribute_proto::AttributeType;
|
||||
use crate::onnx::tensor_proto::DataType;
|
||||
use crate::onnx::{self, GraphProto};
|
||||
use candle::{bail, DType, Device, Result, Tensor};
|
||||
use std::{collections::HashMap, usize};
|
||||
|
||||
@ -56,6 +56,15 @@ impl Attr for str {
|
||||
}
|
||||
}
|
||||
|
||||
impl Attr for GraphProto {
|
||||
const TYPE: AttributeType = AttributeType::Graph;
|
||||
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
|
||||
attr.g
|
||||
.as_ref()
|
||||
.ok_or_else(|| candle::Error::Msg("attribute does not contain graph".to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
impl AttrOwned for Tensor {
|
||||
const TYPE: AttributeType = AttributeType::Tensor;
|
||||
fn get(attr: &onnx::AttributeProto) -> Result<Self> {
|
||||
@ -214,13 +223,19 @@ pub fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
|
||||
// anymore.
|
||||
pub fn simple_eval(
|
||||
model: &onnx::ModelProto,
|
||||
inputs: HashMap<String, Value>,
|
||||
mut inputs: HashMap<String, Value>,
|
||||
) -> Result<HashMap<String, Value>> {
|
||||
let graph = match &model.graph {
|
||||
None => bail!("no graph defined in proto"),
|
||||
Some(graph) => graph,
|
||||
};
|
||||
let mut values = inputs;
|
||||
simple_eval_(graph, &mut inputs)
|
||||
}
|
||||
|
||||
fn simple_eval_(
|
||||
graph: &onnx::GraphProto,
|
||||
values: &mut HashMap<String, Value>,
|
||||
) -> Result<HashMap<String, Value>> {
|
||||
for t in graph.initializer.iter() {
|
||||
let tensor = get_tensor(t, t.name.as_str())?;
|
||||
values.insert(t.name.to_string(), tensor);
|
||||
@ -958,6 +973,86 @@ pub fn simple_eval(
|
||||
let input = get(&node.input[0])?;
|
||||
values.insert(node.output[0].clone(), input.clone());
|
||||
}
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#if
|
||||
"If" => {
|
||||
// protobuf encodes boolean false as 0 and true as 1
|
||||
let cond = get(&node.input[0])?.get(0)?.to_scalar::<u8>()?;
|
||||
let attr_name = if cond != 0 {
|
||||
"then_branch"
|
||||
} else {
|
||||
"else_branch"
|
||||
};
|
||||
let sub_graph = get_attr::<GraphProto>(node, attr_name)?;
|
||||
if sub_graph.output.len() != node.output.len() {
|
||||
bail!(
|
||||
"If node {:?} is malformed: branch outputs ({}) don't match node outputs ({})",
|
||||
node.name,
|
||||
sub_graph.output.len(),
|
||||
node.output.len()
|
||||
);
|
||||
}
|
||||
let branch_out = simple_eval_(sub_graph, values)?;
|
||||
for (i, out) in node.output.iter().enumerate() {
|
||||
values.insert(
|
||||
out.clone(),
|
||||
branch_out.get(&sub_graph.output[i].name).unwrap().clone(),
|
||||
);
|
||||
}
|
||||
}
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#pad
|
||||
"Pad" => {
|
||||
let mode = get_attr_opt(node, "mode")?.unwrap_or("constant");
|
||||
let data = get(&node.input[0])?;
|
||||
let pads = get(&node.input[1])?;
|
||||
if node.input.len() > 2 {
|
||||
bail!(
|
||||
"unsupported number of inputs {} for Pad node {:?}, expected 2",
|
||||
node.input.len(),
|
||||
node.name
|
||||
);
|
||||
}
|
||||
if pads.rank() != 1 {
|
||||
bail!("Pad expects 'pads' input to be 1D vector: {pads:?}");
|
||||
}
|
||||
if pads.dim(0).unwrap() != 2 * data.rank() {
|
||||
bail!("Pad expects 'pads' input len to be 2 * rank of 'data' input: pads: {}, data rank: {}", pads, data.rank());
|
||||
}
|
||||
|
||||
let pads = pads.to_vec1::<i64>()?;
|
||||
let (pads_pre, pads_post) = pads.split_at(pads.len() / 2);
|
||||
|
||||
match mode {
|
||||
"reflect" => {
|
||||
let mut out = data.clone();
|
||||
for (i, &dim) in data.dims().iter().enumerate().rev() {
|
||||
if pads_pre[i] == 0 && pads_post[i] == 0 {
|
||||
continue;
|
||||
}
|
||||
fn zigzag(min: i64, max: i64) -> impl Iterator<Item = i64> {
|
||||
std::iter::repeat((min..max).chain((min + 1..=max).rev())).flatten()
|
||||
}
|
||||
let idx = if dim > 1 {
|
||||
let cycle_len = dim * 2 - 1;
|
||||
let skip = (pads_pre[i] as usize) % cycle_len;
|
||||
let idx = zigzag(0, (dim - 1) as i64)
|
||||
.skip(skip)
|
||||
.take((pads_pre[i] as usize) + dim + (pads_post[i] as usize));
|
||||
Tensor::from_iter(idx, out.device())?
|
||||
} else {
|
||||
Tensor::full(0i64, (dim,), out.device())?
|
||||
};
|
||||
|
||||
out = out.index_select(&idx, i)?;
|
||||
}
|
||||
|
||||
values.insert(node.output[0].clone(), out);
|
||||
}
|
||||
_ => bail!(
|
||||
"unsupported 'mode' value {mode:?} for Pad node {:?}",
|
||||
node.name
|
||||
),
|
||||
}
|
||||
}
|
||||
// https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-13
|
||||
// TODO: This version is only compatible with ReduceMean V13 and below.
|
||||
"ReduceMean" => {
|
||||
|
@ -4,10 +4,12 @@ extern crate intel_mkl_src;
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use candle::test_utils::to_vec2_round;
|
||||
use candle::{DType, Device, NdArray, Result, Tensor};
|
||||
use candle_onnx::onnx;
|
||||
use candle_onnx::onnx::attribute_proto::AttributeType;
|
||||
use candle_onnx::onnx::tensor_proto::DataType;
|
||||
use candle_onnx::onnx::tensor_shape_proto::{dimension, Dimension};
|
||||
use candle_onnx::onnx::{type_proto, TensorProto, TensorShapeProto, TypeProto};
|
||||
use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto};
|
||||
use std::collections::HashMap;
|
||||
|
||||
@ -35,14 +37,11 @@ fn create_model_proto_with_graph(graph: Option<GraphProto>) -> ModelProto {
|
||||
#[test]
|
||||
fn test_evaluation_fails_without_defined_graph() -> Result<()> {
|
||||
let manual_graph = create_model_proto_with_graph(None);
|
||||
|
||||
let inputs: HashMap<String, Tensor> = HashMap::new();
|
||||
|
||||
match candle_onnx::simple_eval(&manual_graph, inputs) {
|
||||
Err(err) => assert_eq!(err.to_string(), "no graph defined in proto"),
|
||||
Ok(_) => panic!("Expected an error due to undefined graph"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -81,14 +80,8 @@ fn test_add_operation() -> Result<()> {
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
let first = z
|
||||
.to_vec1::<f64>()?
|
||||
.to_vec()
|
||||
.get(0)
|
||||
.expect("Failed to get first element")
|
||||
.clone();
|
||||
let first = z.to_vec1::<f64>()?[0];
|
||||
assert_eq!(first, 4.0f64);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -127,14 +120,8 @@ fn test_sub_operation() -> Result<()> {
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
let first = z
|
||||
.to_vec1::<f64>()?
|
||||
.to_vec()
|
||||
.get(0)
|
||||
.expect("Failed to get first element")
|
||||
.clone();
|
||||
let first = z.to_vec1::<f64>()?[0];
|
||||
assert_eq!(first, 0.0f64);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -173,14 +160,8 @@ fn test_mul_operation() -> Result<()> {
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
let first = z
|
||||
.to_vec1::<f64>()?
|
||||
.to_vec()
|
||||
.get(0)
|
||||
.expect("Failed to get first element")
|
||||
.clone();
|
||||
let first = z.to_vec1::<f64>()?[0];
|
||||
assert_eq!(first, 4.0f64);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -219,15 +200,8 @@ fn test_div_operation() -> Result<()> {
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
let first = z
|
||||
.to_vec1::<f64>()?
|
||||
.to_vec()
|
||||
.get(0)
|
||||
.expect("Failed to get first element")
|
||||
.clone();
|
||||
|
||||
let first = z.to_vec1::<f64>()?[0];
|
||||
assert_eq!(first, 1.0f64);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -272,7 +246,7 @@ fn test_exp_operation() -> Result<()> {
|
||||
|
||||
assert_eq!(results[0][0], 0.36787944f32);
|
||||
assert_eq!(results[0][1], 1.0f32);
|
||||
assert_eq!(results[1], vec![std::f32::consts::E, 7.38905609f32]);
|
||||
assert_eq!(results[1], vec![std::f32::consts::E, 7.389056f32]);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -914,7 +888,7 @@ fn test_constant_of_shape() -> Result<()> {
|
||||
),
|
||||
_ => panic!("unsupported DType in test"),
|
||||
};
|
||||
let tensor = onnx::TensorProto {
|
||||
let tensor = TensorProto {
|
||||
data_type: data_type.into(),
|
||||
dims: tensor.dims().iter().map(|v| *v as i64).collect(),
|
||||
raw_data: value,
|
||||
@ -1293,14 +1267,7 @@ fn test_cos_operation() -> Result<()> {
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
|
||||
let results = z.to_vec2::<f32>()?;
|
||||
|
||||
assert_eq!(
|
||||
results,
|
||||
vec![vec![1.0, 0.54030234], vec![-0.41614684, -0.9899925]]
|
||||
);
|
||||
|
||||
assert_eq!(to_vec2_round(z, 4)?, [[1.0, 0.5403], [-0.4161, -0.99]]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -1342,19 +1309,12 @@ fn test_sin_operation() -> Result<()> {
|
||||
quantization_annotation: vec![],
|
||||
}));
|
||||
let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?;
|
||||
|
||||
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||
inputs.insert(INPUT_X.to_string(), x);
|
||||
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
|
||||
let results = z.to_vec2::<f32>()?;
|
||||
|
||||
assert_eq!(results, vec![vec![0.0, 0.841471], vec![0.9092974, 0.14112]]);
|
||||
|
||||
assert_eq!(to_vec2_round(z, 4)?, [[0.0, 0.8415], [0.9093, 0.1411]]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -3150,3 +3110,165 @@ fn test_leakyrelu() -> Result<()> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// "If"
|
||||
#[test]
|
||||
fn test_if() -> Result<()> {
|
||||
let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
|
||||
let y = vec![5.0, 4.0, 3.0, 2.0, 1.0];
|
||||
let output_type_proto = Some(TypeProto {
|
||||
value: Some(type_proto::Value::TensorType(type_proto::Tensor {
|
||||
elem_type: DataType::Float.into(),
|
||||
shape: Some(TensorShapeProto {
|
||||
dim: vec![Dimension {
|
||||
denotation: "".to_string(),
|
||||
value: Some(dimension::Value::DimValue(5)),
|
||||
}],
|
||||
}),
|
||||
})),
|
||||
denotation: "".to_string(),
|
||||
});
|
||||
let then_branch = GraphProto {
|
||||
output: vec![ValueInfoProto {
|
||||
name: "then_out".to_string(),
|
||||
r#type: output_type_proto.clone(),
|
||||
doc_string: "".to_string(),
|
||||
}],
|
||||
node: vec![NodeProto {
|
||||
op_type: "Constant".to_string(),
|
||||
input: vec![],
|
||||
output: vec!["then_out".to_string()],
|
||||
attribute: vec![AttributeProto {
|
||||
name: "value".to_string(),
|
||||
r#type: AttributeType::Tensor.into(),
|
||||
t: Some(TensorProto {
|
||||
dims: vec![x.len() as i64],
|
||||
float_data: x.clone(),
|
||||
data_type: DataType::Float.into(),
|
||||
..TensorProto::default()
|
||||
}),
|
||||
..AttributeProto::default()
|
||||
}],
|
||||
..NodeProto::default()
|
||||
}],
|
||||
..GraphProto::default()
|
||||
};
|
||||
let else_branch = GraphProto {
|
||||
output: vec![ValueInfoProto {
|
||||
name: "else_out".to_string(),
|
||||
r#type: output_type_proto.clone(),
|
||||
doc_string: "".to_string(),
|
||||
}],
|
||||
node: vec![NodeProto {
|
||||
op_type: "Constant".to_string(),
|
||||
input: vec![],
|
||||
output: vec!["else_out".to_string()],
|
||||
attribute: vec![AttributeProto {
|
||||
name: "value".to_string(),
|
||||
r#type: AttributeType::Tensor.into(),
|
||||
t: Some(TensorProto {
|
||||
dims: vec![y.len() as i64],
|
||||
float_data: y.clone(),
|
||||
data_type: DataType::Float.into(),
|
||||
..TensorProto::default()
|
||||
}),
|
||||
..AttributeProto::default()
|
||||
}],
|
||||
..NodeProto::default()
|
||||
}],
|
||||
..GraphProto::default()
|
||||
};
|
||||
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||
node: vec![NodeProto {
|
||||
op_type: "If".to_string(),
|
||||
attribute: vec![
|
||||
AttributeProto {
|
||||
name: "then_branch".to_string(),
|
||||
r#type: AttributeType::Graph.into(),
|
||||
g: Some(then_branch),
|
||||
..AttributeProto::default()
|
||||
},
|
||||
AttributeProto {
|
||||
name: "else_branch".to_string(),
|
||||
r#type: AttributeType::Graph.into(),
|
||||
g: Some(else_branch),
|
||||
..AttributeProto::default()
|
||||
},
|
||||
],
|
||||
input: vec!["cond".to_string()],
|
||||
output: vec!["res".to_string()],
|
||||
..NodeProto::default()
|
||||
}],
|
||||
input: vec![],
|
||||
output: vec![ValueInfoProto {
|
||||
name: "res".to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: output_type_proto.clone(),
|
||||
}],
|
||||
..GraphProto::default()
|
||||
}));
|
||||
|
||||
for cond in [1u8, 0] {
|
||||
let inputs =
|
||||
HashMap::from_iter([("cond".to_string(), Tensor::full(cond, (1,), &Device::Cpu)?)]);
|
||||
let outputs = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
let expected = if cond != 0 { &x } else { &y };
|
||||
let Some(res) = outputs.get("res") else {
|
||||
candle::bail!("outputs didn't contain expected key `res`: {outputs:?}");
|
||||
};
|
||||
assert_eq!(&res.to_vec1::<f32>()?, expected);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pad() -> Result<()> {
|
||||
let data = Tensor::from_vec(vec![1.0, 1.2, 2.3, 3.4, 4.5, 5.7], (3, 2), &Device::Cpu)?;
|
||||
let pads = Tensor::from_vec(vec![0i64, 2, 0, 0], (4,), &Device::Cpu)?;
|
||||
let mode = "reflect";
|
||||
|
||||
let expected = Tensor::from_vec(
|
||||
vec![1.0, 1.2, 1.0, 1.2, 2.3, 3.4, 2.3, 3.4, 4.5, 5.7, 4.5, 5.7],
|
||||
(3, 4),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
|
||||
let model = create_model_proto_with_graph(Some(GraphProto {
|
||||
input: vec![
|
||||
ValueInfoProto {
|
||||
name: "data".to_string(),
|
||||
..ValueInfoProto::default()
|
||||
},
|
||||
ValueInfoProto {
|
||||
name: "pads".to_string(),
|
||||
..ValueInfoProto::default()
|
||||
},
|
||||
],
|
||||
output: vec![ValueInfoProto {
|
||||
name: "output".to_string(),
|
||||
..ValueInfoProto::default()
|
||||
}],
|
||||
node: vec![NodeProto {
|
||||
op_type: "Pad".to_string(),
|
||||
input: vec!["data".to_string(), "pads".to_string()],
|
||||
output: vec!["output".to_string()],
|
||||
attribute: vec![AttributeProto {
|
||||
name: "mode".to_string(),
|
||||
r#type: AttributeType::String.into(),
|
||||
s: mode.as_bytes().to_vec(),
|
||||
..AttributeProto::default()
|
||||
}],
|
||||
..NodeProto::default()
|
||||
}],
|
||||
..GraphProto::default()
|
||||
}));
|
||||
|
||||
let inputs = HashMap::from_iter([("data".to_string(), data), ("pads".to_string(), pads)]);
|
||||
let res = candle_onnx::simple_eval(&model, inputs)?;
|
||||
let Some(actual) = res.get("output") else {
|
||||
candle::bail!("outputs didn't contain expected key `output`: {res:?}");
|
||||
};
|
||||
|
||||
assert_eq!(actual.to_vec2::<f64>()?, expected.to_vec2::<f64>()?);
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user