Add missing onnx operations (#2096)

* Add missing onnx operations

* Add tests and fix errors

* Run rustfmt
Gabriel
2024-04-20 18:44:22 +02:00
committed by GitHub
parent 52ae332910
commit 9215e9ce8c
2 changed files with 736 additions and 9 deletions


@@ -23,6 +23,11 @@ trait Attr {
fn get(attr: &onnx::AttributeProto) -> Result<&Self>;
}
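// Owned counterpart of `Attr` for attribute values that have to be materialized
// rather than borrowed (e.g. a candle `Tensor` built from a `TensorProto`).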
trait AttrOwned: Sized {
const TYPE: AttributeType;
fn get(attr: &onnx::AttributeProto) -> Result<Self>;
}
impl Attr for i64 {
const TYPE: AttributeType = AttributeType::Int;
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
@@ -51,6 +56,50 @@ impl Attr for str {
}
}
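// Decode a TENSOR attribute: validate the declared data type, reject negative
// dimensions, and build the tensor from the raw little-endian buffer on the CPU.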
impl AttrOwned for Tensor {
const TYPE: AttributeType = AttributeType::Tensor;
fn get(attr: &onnx::AttributeProto) -> Result<Self> {
let tensor_proto = match &attr.t {
Some(value) => value,
None => bail!(
"attribute {} was of type TENSOR, but no tensor was found",
attr.name
),
};
let data_type = match DataType::try_from(tensor_proto.data_type) {
Ok(value) => value,
Err(_) => bail!(
"attribute {} of type TENSOR was an invalid data_type number {}",
attr.name,
tensor_proto.data_type
),
};
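// Map the ONNX data type onto a candle DType, bailing out on types candle
// does not support.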
let dtype = match dtype(data_type) {
Some(value) => value,
None => bail!(
"attribute {} of type TENSOR has an unsupported data_type {}",
attr.name,
data_type.as_str_name()
),
};
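// Proto dimensions are i64; reject negatives before casting to usize.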
let mut dims = Vec::with_capacity(tensor_proto.dims.len());
for dim in &tensor_proto.dims {
if *dim < 0 {
bail!(
"attribute {} of type TENSOR has a negative dimension, which is unsupported",
attr.name
)
}
dims.push(*dim as usize)
}
Tensor::from_raw_buffer(&tensor_proto.raw_data, dtype, &dims, &Device::Cpu)
}
}
fn get_attr_<'a>(node: &'a onnx::NodeProto, name: &str) -> Result<&'a onnx::AttributeProto> {
match node.attribute.iter().find(|attr| attr.name == name) {
None => {
@@ -98,6 +147,24 @@ fn get_attr_opt<'a, T: Attr + ?Sized>(
}
}
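// Like `get_attr_opt`, but returns an owned value through `AttrOwned`.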
fn get_attr_opt_owned<T: AttrOwned>(node: &onnx::NodeProto, name: &str) -> Result<Option<T>> {
match node.attribute.iter().find(|attr| attr.name == name) {
None => Ok(None),
Some(attr) => {
if attr.r#type() != T::TYPE {
bail!(
"unsupported type {:?} for '{name}' attribute in '{}' for {}",
attr.r#type,
node.op_type,
node.name
)
}
let val = T::get(attr)?;
Ok(Some(val))
}
}
}
pub fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect();
match DataType::try_from(t.data_type) {
@@ -458,14 +525,17 @@ pub fn simple_eval(
}
values.insert(node.output[0].clone(), xs);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#ConstantOfShape
"ConstantOfShape" => {
- let dims = get(&node.input[0])?;
- let shape = dims
- .to_vec1::<i64>()?
- .into_iter()
- .map(|v| v as usize)
- .collect::<Vec<_>>();
- let xs = Tensor::zeros(shape, DType::F32, dims.device())?;
let input = get(&node.input[0])?;
let value = get_attr_opt_owned::<Tensor>(node, "value")?.unwrap_or(Tensor::zeros(
(),
DType::F32,
&Device::Cpu,
)?);
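// Broadcast-multiplying a ones tensor by the (possibly scalar) value fills the
// output with `value` while inheriting its dtype.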
let xs = Tensor::ones(input.shape(), value.dtype(), input.device())?
.broadcast_mul(&value)?;
values.insert(node.output[0].clone(), xs);
}
"Unsqueeze" => {
@@ -552,6 +622,82 @@ pub fn simple_eval(
let dims = Tensor::from_vec(dims, xs.rank(), xs.device())?;
values.insert(node.output[0].clone(), dims);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sqrt
"Sqrt" => {
let xs = get(&node.input[0])?;
let output = xs.sqrt()?;
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Range
"Range" => {
let start = get(&node.input[0])?;
let limit = get(&node.input[1])?;
let delta = get(&node.input[2])?;
macro_rules! arange_step {
($t: ty) => {
Tensor::arange_step(
start.to_vec0::<$t>()?,
limit.to_vec0::<$t>()?,
delta.to_vec0::<$t>()?,
&Device::Cpu,
)?
};
}
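// Dispatch on the start tensor's dtype; the half-precision arms take the f32 path.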
let output = match start.dtype() {
DType::U8 => arange_step!(u8),
DType::U32 => arange_step!(u32),
DType::I64 => arange_step!(i64),
DType::BF16 => arange_step!(f32),
DType::F16 => arange_step!(f32),
DType::F32 => arange_step!(f32),
DType::F64 => arange_step!(f64),
};
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Greater
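// Broadcasting comparison; the result is a u8 tensor of 0s and 1s.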
"Greater" => {
let a = get(&node.input[0])?;
let b = get(&node.input[1])?;
let output = a.broadcast_gt(b)?;
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Less
"Less" => {
let a = get(&node.input[0])?;
let b = get(&node.input[1])?;
let output = a.broadcast_lt(b)?;
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Log
"Log" => {
let a = get(&node.input[0])?;
let output = a.log()?;
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Min
"Min" => {
let mut output = get(&node.input[0])?.clone();
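// The loop revisits the first input, which is harmless: min(x, x) == x.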
for input in node.input.iter() {
let input = get(input)?;
output = output.broadcast_minimum(input)?
}
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Where
"Where" => {
let cond = get(&node.input[0])?;
let a = get(&node.input[1])?;
let b = get(&node.input[2])?;
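// Select from `a` where the condition is non-zero, from `b` elsewhere.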
let output = cond.where_cond(a, b)?;
values.insert(node.output[0].clone(), output);
}
"Conv" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;


@@ -4,12 +4,16 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
- use candle::{Device, NdArray, Result, Tensor};
use candle::{DType, Device, NdArray, Result, Tensor};
use candle_onnx::onnx;
use candle_onnx::onnx::attribute_proto::AttributeType;
use candle_onnx::onnx::tensor_proto::DataType;
use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto};
use std::collections::HashMap;
const INPUT_X: &str = "x";
const INPUT_Y: &str = "y";
const INPUT_A: &str = "a";
const OUTPUT_Z: &str = "z";
fn create_model_proto_with_graph(graph: Option<GraphProto>) -> ModelProto {
@@ -820,7 +824,137 @@ fn test_flatten_operation() -> Result<()> {
// #[test]
// "ConstantOfShape"
- // #[test]
#[test]
fn test_constant_of_shape() -> Result<()> {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-31
test(&[4i64, 3, 2], Some(1.), &[1., 1., 1.])?;
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-31
test(&[0.], Some(0i64), &[0i64])?;
// "value" defaults to 0 f32
test(&[1i64, 2, 3, 4], None as Option<i64>, &[0., 0., 0., 0.])?;
fn test(
input: impl NdArray,
value: Option<impl NdArray>,
expected: impl NdArray,
) -> Result<()> {
let mut attribute = vec![];
if let Some(value) = value {
let tensor = Tensor::new(value, &Device::Cpu)?;
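// Serialize the scalar to little-endian bytes and record the matching ONNX data type.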
let (value, data_type) = match tensor.dtype() {
DType::U8 => (
tensor.to_vec0::<u8>()?.to_le_bytes().to_vec(),
DataType::Uint8,
),
DType::U32 => (
tensor.to_vec0::<u32>()?.to_le_bytes().to_vec(),
DataType::Uint32,
),
DType::I64 => (
tensor.to_vec0::<i64>()?.to_le_bytes().to_vec(),
DataType::Int64,
),
DType::F32 => (
tensor.to_vec0::<f32>()?.to_le_bytes().to_vec(),
DataType::Float,
),
DType::F64 => (
tensor.to_vec0::<f64>()?.to_le_bytes().to_vec(),
DataType::Double,
),
_ => panic!("unsupported DType in test"),
};
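// Wrap the raw bytes in a minimal TensorProto to use as the "value" attribute.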
let tensor = onnx::TensorProto {
data_type: data_type.into(),
dims: tensor.dims().iter().map(|v| *v as i64).collect(),
raw_data: value,
segment: None,
float_data: vec![],
int32_data: vec![],
string_data: vec![],
int64_data: vec![],
name: "".to_string(),
doc_string: "".to_string(),
external_data: vec![],
data_location: 0,
double_data: vec![],
uint64_data: vec![],
};
attribute.push(AttributeProto {
name: "value".to_string(),
ref_attr_name: "value".to_string(),
i: 0,
doc_string: "value".to_string(),
r#type: AttributeType::Tensor.into(),
f: 0.0,
s: vec![],
t: Some(tensor),
g: None,
sparse_tensor: None,
tp: None,
floats: vec![],
ints: vec![],
strings: vec![],
tensors: vec![],
graphs: vec![],
sparse_tensors: vec![],
type_protos: vec![],
})
}
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "ConstantOfShape".to_string(),
domain: "".to_string(),
attribute,
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), Tensor::new(input, &Device::Cpu)?);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval
.get(OUTPUT_Z)
.expect("Output 'z' not found")
.to_dtype(DType::F64)?;
let expected = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?;
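// Compare through f64 so integer and float expectations share one code path.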
match expected.dims().len() {
0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
_ => unreachable!(),
};
Ok(())
}
Ok(())
}
// "Unsqueeze"
// #[test]
@@ -1639,3 +1773,450 @@ fn test_reduce_mean() -> Result<()> {
Ok(())
}
// "Sqrt"
#[test]
fn test_sqrt() -> Result<()> {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-155
test(&[1., 4., 9.], &[1., 2., 3.])?;
fn test(data: impl NdArray, expected: impl NdArray) -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Sqrt".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let expected = Tensor::new(expected, &Device::Cpu)?;
match expected.dims().len() {
0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
_ => unreachable!(),
};
Ok(())
}
Ok(())
}
// "Range"
#[test]
fn test_range() -> Result<()> {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-113
test(1., 5., 2., &[1., 3.])?;
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-113
test(10i64, 6i64, -3i64, &[10i64, 7i64])?;
fn test(
start: impl NdArray,
limit: impl NdArray,
delta: impl NdArray,
expected: impl NdArray,
) -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Range".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![
INPUT_X.to_string(),
INPUT_Y.to_string(),
INPUT_A.to_string(),
],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), Tensor::new(start, &Device::Cpu)?);
inputs.insert(INPUT_Y.to_string(), Tensor::new(limit, &Device::Cpu)?);
inputs.insert(INPUT_A.to_string(), Tensor::new(delta, &Device::Cpu)?);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval
.get(OUTPUT_Z)
.expect("Output 'z' not found")
.to_dtype(DType::F64)?;
let expected = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?;
match expected.dims().len() {
0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
_ => unreachable!(),
};
Ok(())
}
Ok(())
}
// "Greater"
#[test]
fn test_greater() -> Result<()> {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-63
test(&[1., 2., 3.], &[3., 2., 1.], &[0u8, 0, 1])?;
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-63
test(&[1., 2., 3.], 2., &[0u8, 0, 1])?;
fn test(a: impl NdArray, b: impl NdArray, expected: impl NdArray) -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Greater".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), Tensor::new(a, &Device::Cpu)?);
inputs.insert(INPUT_Y.to_string(), Tensor::new(b, &Device::Cpu)?);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval
.get(OUTPUT_Z)
.expect("Output 'z' not found")
.to_dtype(DType::F64)?;
let expected = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?;
match expected.dims().len() {
0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
_ => unreachable!(),
};
Ok(())
}
Ok(())
}
// "Less"
#[test]
fn test_less() -> Result<()> {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-81
test(&[1., 2., 3.], &[3., 2., 1.], &[1u8, 0, 0])?;
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-81
test(&[1., 2., 3.], 2., &[1u8, 0, 0])?;
fn test(a: impl NdArray, b: impl NdArray, expected: impl NdArray) -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Less".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), Tensor::new(a, &Device::Cpu)?);
inputs.insert(INPUT_Y.to_string(), Tensor::new(b, &Device::Cpu)?);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval
.get(OUTPUT_Z)
.expect("Output 'z' not found")
.to_dtype(DType::F64)?;
let expected = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?;
match expected.dims().len() {
0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
_ => unreachable!(),
};
Ok(())
}
Ok(())
}
// "Log"
#[test]
fn test_log() -> Result<()> {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-82
test(&[1., 10.], &[0., std::f64::consts::LN_10])?;
fn test(data: impl NdArray, expected: impl NdArray) -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Log".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let expected = Tensor::new(expected, &Device::Cpu)?;
match expected.dims().len() {
0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
_ => unreachable!(),
};
Ok(())
}
Ok(())
}
// "Min"
#[test]
fn test_min() -> Result<()> {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-94
test(&[3., 2., 1.], &[1., 4., 4.], &[2., 5., 0.], &[1., 2., 0.])?;
fn test(
a: impl NdArray,
b: impl NdArray,
c: impl NdArray,
expected: impl NdArray,
) -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Min".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![
INPUT_X.to_string(),
INPUT_Y.to_string(),
INPUT_A.to_string(),
],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), Tensor::new(a, &Device::Cpu)?);
inputs.insert(INPUT_Y.to_string(), Tensor::new(b, &Device::Cpu)?);
inputs.insert(INPUT_A.to_string(), Tensor::new(c, &Device::Cpu)?);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let expected = Tensor::new(expected, &Device::Cpu)?;
match expected.dims().len() {
0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
_ => unreachable!(),
};
Ok(())
}
Ok(())
}
// "Where"
#[test]
fn test_where() -> Result<()> {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-173
test(
&[[1u8, 0], [1, 1]],
&[[1i64, 2], [3, 4]],
&[[9i64, 8], [7, 6]],
&[[1i64, 8], [3, 4]],
)?;
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-173
test(
&[[1u8, 0], [1, 1]],
&[[1., 2.], [3., 4.]],
&[[9., 8.], [7., 6.]],
&[[1., 8.], [3., 4.]],
)?;
fn test(
condition: impl NdArray,
x: impl NdArray,
y: impl NdArray,
expected: impl NdArray,
) -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Where".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![
INPUT_X.to_string(),
INPUT_Y.to_string(),
INPUT_A.to_string(),
],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), Tensor::new(condition, &Device::Cpu)?);
inputs.insert(INPUT_Y.to_string(), Tensor::new(x, &Device::Cpu)?);
inputs.insert(INPUT_A.to_string(), Tensor::new(y, &Device::Cpu)?);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval
.get(OUTPUT_Z)
.expect("Output 'z' not found")
.to_dtype(DType::F64)?;
let expected = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?;
match expected.dims().len() {
0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
_ => unreachable!(),
};
Ok(())
}
Ok(())
}