mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
implement if, and pad reflect mode (#2251)
* implement if, and pad reflect mode The intent of this change is to allow eval of the current silero_vad.onnx (v4). This onnx file uses 'If' and 'Pad' nodes, which had not been supported by simple_eval until now * Cleanup (fmt, clippy, minor test tweaks). --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -4,10 +4,12 @@ extern crate intel_mkl_src;
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use candle::test_utils::to_vec2_round;
|
||||
use candle::{DType, Device, NdArray, Result, Tensor};
|
||||
use candle_onnx::onnx;
|
||||
use candle_onnx::onnx::attribute_proto::AttributeType;
|
||||
use candle_onnx::onnx::tensor_proto::DataType;
|
||||
use candle_onnx::onnx::tensor_shape_proto::{dimension, Dimension};
|
||||
use candle_onnx::onnx::{type_proto, TensorProto, TensorShapeProto, TypeProto};
|
||||
use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto};
|
||||
use std::collections::HashMap;
|
||||
|
||||
@ -35,14 +37,11 @@ fn create_model_proto_with_graph(graph: Option<GraphProto>) -> ModelProto {
|
||||
#[test]
|
||||
fn test_evaluation_fails_without_defined_graph() -> Result<()> {
|
||||
let manual_graph = create_model_proto_with_graph(None);
|
||||
|
||||
let inputs: HashMap<String, Tensor> = HashMap::new();
|
||||
|
||||
match candle_onnx::simple_eval(&manual_graph, inputs) {
|
||||
Err(err) => assert_eq!(err.to_string(), "no graph defined in proto"),
|
||||
Ok(_) => panic!("Expected an error due to undefined graph"),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -81,14 +80,8 @@ fn test_add_operation() -> Result<()> {
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
let first = z
|
||||
.to_vec1::<f64>()?
|
||||
.to_vec()
|
||||
.get(0)
|
||||
.expect("Failed to get first element")
|
||||
.clone();
|
||||
let first = z.to_vec1::<f64>()?[0];
|
||||
assert_eq!(first, 4.0f64);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -127,14 +120,8 @@ fn test_sub_operation() -> Result<()> {
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
let first = z
|
||||
.to_vec1::<f64>()?
|
||||
.to_vec()
|
||||
.get(0)
|
||||
.expect("Failed to get first element")
|
||||
.clone();
|
||||
let first = z.to_vec1::<f64>()?[0];
|
||||
assert_eq!(first, 0.0f64);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -173,14 +160,8 @@ fn test_mul_operation() -> Result<()> {
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
let first = z
|
||||
.to_vec1::<f64>()?
|
||||
.to_vec()
|
||||
.get(0)
|
||||
.expect("Failed to get first element")
|
||||
.clone();
|
||||
let first = z.to_vec1::<f64>()?[0];
|
||||
assert_eq!(first, 4.0f64);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -219,15 +200,8 @@ fn test_div_operation() -> Result<()> {
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
let first = z
|
||||
.to_vec1::<f64>()?
|
||||
.to_vec()
|
||||
.get(0)
|
||||
.expect("Failed to get first element")
|
||||
.clone();
|
||||
|
||||
let first = z.to_vec1::<f64>()?[0];
|
||||
assert_eq!(first, 1.0f64);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -272,7 +246,7 @@ fn test_exp_operation() -> Result<()> {
|
||||
|
||||
assert_eq!(results[0][0], 0.36787944f32);
|
||||
assert_eq!(results[0][1], 1.0f32);
|
||||
assert_eq!(results[1], vec![std::f32::consts::E, 7.38905609f32]);
|
||||
assert_eq!(results[1], vec![std::f32::consts::E, 7.389056f32]);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@ -914,7 +888,7 @@ fn test_constant_of_shape() -> Result<()> {
|
||||
),
|
||||
_ => panic!("unsupported DType in test"),
|
||||
};
|
||||
let tensor = onnx::TensorProto {
|
||||
let tensor = TensorProto {
|
||||
data_type: data_type.into(),
|
||||
dims: tensor.dims().iter().map(|v| *v as i64).collect(),
|
||||
raw_data: value,
|
||||
@ -1293,14 +1267,7 @@ fn test_cos_operation() -> Result<()> {
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
|
||||
let results = z.to_vec2::<f32>()?;
|
||||
|
||||
assert_eq!(
|
||||
results,
|
||||
vec![vec![1.0, 0.54030234], vec![-0.41614684, -0.9899925]]
|
||||
);
|
||||
|
||||
assert_eq!(to_vec2_round(z, 4)?, [[1.0, 0.5403], [-0.4161, -0.99]]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -1342,19 +1309,12 @@ fn test_sin_operation() -> Result<()> {
|
||||
quantization_annotation: vec![],
|
||||
}));
|
||||
let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?;
|
||||
|
||||
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||
inputs.insert(INPUT_X.to_string(), x);
|
||||
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
|
||||
let results = z.to_vec2::<f32>()?;
|
||||
|
||||
assert_eq!(results, vec![vec![0.0, 0.841471], vec![0.9092974, 0.14112]]);
|
||||
|
||||
assert_eq!(to_vec2_round(z, 4)?, [[0.0, 0.8415], [0.9093, 0.1411]]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -3150,3 +3110,165 @@ fn test_leakyrelu() -> Result<()> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// "If"
|
||||
#[test]
|
||||
fn test_if() -> Result<()> {
|
||||
let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
|
||||
let y = vec![5.0, 4.0, 3.0, 2.0, 1.0];
|
||||
let output_type_proto = Some(TypeProto {
|
||||
value: Some(type_proto::Value::TensorType(type_proto::Tensor {
|
||||
elem_type: DataType::Float.into(),
|
||||
shape: Some(TensorShapeProto {
|
||||
dim: vec![Dimension {
|
||||
denotation: "".to_string(),
|
||||
value: Some(dimension::Value::DimValue(5)),
|
||||
}],
|
||||
}),
|
||||
})),
|
||||
denotation: "".to_string(),
|
||||
});
|
||||
let then_branch = GraphProto {
|
||||
output: vec![ValueInfoProto {
|
||||
name: "then_out".to_string(),
|
||||
r#type: output_type_proto.clone(),
|
||||
doc_string: "".to_string(),
|
||||
}],
|
||||
node: vec![NodeProto {
|
||||
op_type: "Constant".to_string(),
|
||||
input: vec![],
|
||||
output: vec!["then_out".to_string()],
|
||||
attribute: vec![AttributeProto {
|
||||
name: "value".to_string(),
|
||||
r#type: AttributeType::Tensor.into(),
|
||||
t: Some(TensorProto {
|
||||
dims: vec![x.len() as i64],
|
||||
float_data: x.clone(),
|
||||
data_type: DataType::Float.into(),
|
||||
..TensorProto::default()
|
||||
}),
|
||||
..AttributeProto::default()
|
||||
}],
|
||||
..NodeProto::default()
|
||||
}],
|
||||
..GraphProto::default()
|
||||
};
|
||||
let else_branch = GraphProto {
|
||||
output: vec![ValueInfoProto {
|
||||
name: "else_out".to_string(),
|
||||
r#type: output_type_proto.clone(),
|
||||
doc_string: "".to_string(),
|
||||
}],
|
||||
node: vec![NodeProto {
|
||||
op_type: "Constant".to_string(),
|
||||
input: vec![],
|
||||
output: vec!["else_out".to_string()],
|
||||
attribute: vec![AttributeProto {
|
||||
name: "value".to_string(),
|
||||
r#type: AttributeType::Tensor.into(),
|
||||
t: Some(TensorProto {
|
||||
dims: vec![y.len() as i64],
|
||||
float_data: y.clone(),
|
||||
data_type: DataType::Float.into(),
|
||||
..TensorProto::default()
|
||||
}),
|
||||
..AttributeProto::default()
|
||||
}],
|
||||
..NodeProto::default()
|
||||
}],
|
||||
..GraphProto::default()
|
||||
};
|
||||
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||
node: vec![NodeProto {
|
||||
op_type: "If".to_string(),
|
||||
attribute: vec![
|
||||
AttributeProto {
|
||||
name: "then_branch".to_string(),
|
||||
r#type: AttributeType::Graph.into(),
|
||||
g: Some(then_branch),
|
||||
..AttributeProto::default()
|
||||
},
|
||||
AttributeProto {
|
||||
name: "else_branch".to_string(),
|
||||
r#type: AttributeType::Graph.into(),
|
||||
g: Some(else_branch),
|
||||
..AttributeProto::default()
|
||||
},
|
||||
],
|
||||
input: vec!["cond".to_string()],
|
||||
output: vec!["res".to_string()],
|
||||
..NodeProto::default()
|
||||
}],
|
||||
input: vec![],
|
||||
output: vec![ValueInfoProto {
|
||||
name: "res".to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: output_type_proto.clone(),
|
||||
}],
|
||||
..GraphProto::default()
|
||||
}));
|
||||
|
||||
for cond in [1u8, 0] {
|
||||
let inputs =
|
||||
HashMap::from_iter([("cond".to_string(), Tensor::full(cond, (1,), &Device::Cpu)?)]);
|
||||
let outputs = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
let expected = if cond != 0 { &x } else { &y };
|
||||
let Some(res) = outputs.get("res") else {
|
||||
candle::bail!("outputs didn't contain expected key `res`: {outputs:?}");
|
||||
};
|
||||
assert_eq!(&res.to_vec1::<f32>()?, expected);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pad() -> Result<()> {
|
||||
let data = Tensor::from_vec(vec![1.0, 1.2, 2.3, 3.4, 4.5, 5.7], (3, 2), &Device::Cpu)?;
|
||||
let pads = Tensor::from_vec(vec![0i64, 2, 0, 0], (4,), &Device::Cpu)?;
|
||||
let mode = "reflect";
|
||||
|
||||
let expected = Tensor::from_vec(
|
||||
vec![1.0, 1.2, 1.0, 1.2, 2.3, 3.4, 2.3, 3.4, 4.5, 5.7, 4.5, 5.7],
|
||||
(3, 4),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
|
||||
let model = create_model_proto_with_graph(Some(GraphProto {
|
||||
input: vec![
|
||||
ValueInfoProto {
|
||||
name: "data".to_string(),
|
||||
..ValueInfoProto::default()
|
||||
},
|
||||
ValueInfoProto {
|
||||
name: "pads".to_string(),
|
||||
..ValueInfoProto::default()
|
||||
},
|
||||
],
|
||||
output: vec![ValueInfoProto {
|
||||
name: "output".to_string(),
|
||||
..ValueInfoProto::default()
|
||||
}],
|
||||
node: vec![NodeProto {
|
||||
op_type: "Pad".to_string(),
|
||||
input: vec!["data".to_string(), "pads".to_string()],
|
||||
output: vec!["output".to_string()],
|
||||
attribute: vec![AttributeProto {
|
||||
name: "mode".to_string(),
|
||||
r#type: AttributeType::String.into(),
|
||||
s: mode.as_bytes().to_vec(),
|
||||
..AttributeProto::default()
|
||||
}],
|
||||
..NodeProto::default()
|
||||
}],
|
||||
..GraphProto::default()
|
||||
}));
|
||||
|
||||
let inputs = HashMap::from_iter([("data".to_string(), data), ("pads".to_string(), pads)]);
|
||||
let res = candle_onnx::simple_eval(&model, inputs)?;
|
||||
let Some(actual) = res.get("output") else {
|
||||
candle::bail!("outputs didn't contain expected key `output`: {res:?}");
|
||||
};
|
||||
|
||||
assert_eq!(actual.to_vec2::<f64>()?, expected.to_vec2::<f64>()?);
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user