#[cfg(feature = "mkl")] extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; use candle::{DType, Device, NdArray, Result, Tensor}; use candle_onnx::onnx; use candle_onnx::onnx::attribute_proto::AttributeType; use candle_onnx::onnx::tensor_proto::DataType; use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto}; use std::collections::HashMap; const INPUT_X: &str = "x"; const INPUT_Y: &str = "y"; const INPUT_A: &str = "a"; const OUTPUT_Z: &str = "z"; fn create_model_proto_with_graph(graph: Option) -> ModelProto { ModelProto { metadata_props: vec![], training_info: vec![], functions: vec![], ir_version: 0, opset_import: vec![], producer_name: "".to_string(), producer_version: "".to_string(), domain: "".to_string(), model_version: 0, doc_string: "".to_string(), graph, } } #[test] fn test_evaluation_fails_without_defined_graph() -> Result<()> { let manual_graph = create_model_proto_with_graph(None); let inputs: HashMap = HashMap::new(); match candle_onnx::simple_eval(&manual_graph, inputs) { Err(err) => assert_eq!(err.to_string(), "no graph defined in proto"), Ok(_) => panic!("Expected an error due to undefined graph"), } Ok(()) } // "Add" #[test] fn test_add_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Add".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string(), INPUT_Y.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), Tensor::new(&[2.], &Device::Cpu)?); 
inputs.insert(INPUT_Y.to_string(), Tensor::new(&[2.], &Device::Cpu)?); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let first = z .to_vec1::()? .to_vec() .get(0) .expect("Failed to get first element") .clone(); assert_eq!(first, 4.0f64); Ok(()) } // "Sub" #[test] fn test_sub_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Sub".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string(), INPUT_Y.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), Tensor::new(&[2.], &Device::Cpu)?); inputs.insert(INPUT_Y.to_string(), Tensor::new(&[2.], &Device::Cpu)?); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let first = z .to_vec1::()? 
.to_vec() .get(0) .expect("Failed to get first element") .clone(); assert_eq!(first, 0.0f64); Ok(()) } // "Mul" #[test] fn test_mul_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Mul".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string(), INPUT_Y.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), Tensor::new(&[2.], &Device::Cpu)?); inputs.insert(INPUT_Y.to_string(), Tensor::new(&[2.], &Device::Cpu)?); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let first = z .to_vec1::()? 
.to_vec() .get(0) .expect("Failed to get first element") .clone(); assert_eq!(first, 4.0f64); Ok(()) } // "Div" #[test] fn test_div_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Div".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string(), INPUT_Y.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), Tensor::new(&[2.], &Device::Cpu)?); inputs.insert(INPUT_Y.to_string(), Tensor::new(&[2.], &Device::Cpu)?); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let first = z .to_vec1::()? 
.to_vec() .get(0) .expect("Failed to get first element") .clone(); assert_eq!(first, 1.0f64); Ok(()) } // "Exp" #[test] fn test_exp_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Exp".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let x = Tensor::from_vec( vec![-1.0f32, 0.0f32, 1.0f32, 2.0f32], &[2, 2], &Device::Cpu, )?; let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), x); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let results = z.to_vec2::()?; assert_eq!(results[0][0], 0.36787944f32); assert_eq!(results[0][1], 1.0f32); assert_eq!(results[1], vec![std::f32::consts::E, 7.38905609f32]); Ok(()) } // "Equal" #[test] fn test_equal_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Equal".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string(), INPUT_Y.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), Tensor::new(&[2.], &Device::Cpu)?); 
inputs.insert(INPUT_Y.to_string(), Tensor::new(&[2.], &Device::Cpu)?); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let first = z.to_dtype(candle::DType::U8)?.to_vec1::()?.to_vec()[0]; assert_eq!(first, 1); Ok(()) } // "Not" #[test] fn test_not_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Not".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), Tensor::new(&[0.], &Device::Cpu)?); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let first = z.to_dtype(candle::DType::U8)?.to_vec1::()?.to_vec()[0]; assert_eq!(first, 1); Ok(()) } // "MatMul" #[test] fn test_matmul_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "MatMul".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string(), INPUT_Y.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let mut inputs: HashMap = HashMap::new(); 
inputs.insert( INPUT_X.to_string(), Tensor::from_vec( // vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu, )?, ); inputs.insert( INPUT_Y.to_string(), Tensor::from_vec( // vec![5.0f32, 6.0f32, 7.0f32, 8.0f32], &[2, 2], &Device::Cpu, )?, ); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let results = z.to_vec2::()?; assert_eq!(results, vec![vec![19.0, 22.0], vec![43.0, 50.0]]); Ok(()) } // "Reshape" #[test] fn test_reshape_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Reshape".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string(), INPUT_Y.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![ ValueInfoProto { name: INPUT_X.to_string(), doc_string: "".to_string(), r#type: None, }, ValueInfoProto { name: INPUT_Y.to_string(), doc_string: "".to_string(), r#type: None, }, ], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let x = Tensor::from_vec( // vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu, )?; let y = Tensor::from_vec( // vec![4i64], &[1], &Device::Cpu, )?; let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), x); inputs.insert(INPUT_Y.to_string(), y); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let results = z.to_vec1::()?; assert_eq!(results, vec![1.0, 2.0, 3.0, 4.0]); Ok(()) } // "LogSoftmax" #[test] fn test_logsoftmax_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto 
{ op_type: "LogSoftmax".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![ ValueInfoProto { name: INPUT_X.to_string(), doc_string: "".to_string(), r#type: None, }, ValueInfoProto { name: INPUT_Y.to_string(), doc_string: "".to_string(), r#type: None, }, ], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let x = Tensor::from_vec( // vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu, )?; let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), x); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let results = z.to_vec2::()?; assert_eq!( results, vec![vec![0.26894143, 0.7310586], vec![0.26894143, 0.7310586]] ); Ok(()) } // "Softmax" #[test] fn test_softmax_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Softmax".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![ ValueInfoProto { name: INPUT_X.to_string(), doc_string: "".to_string(), r#type: None, }, ValueInfoProto { name: INPUT_Y.to_string(), doc_string: "".to_string(), r#type: None, }, ], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let x = Tensor::from_vec( // vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], 
&Device::Cpu, )?; let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), x); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let results = z.to_vec2::()?; assert_eq!( results, vec![vec![0.26894143, 0.7310586], vec![0.26894143, 0.7310586]] ); Ok(()) } // "Transpose" #[test] fn test_transpose_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Transpose".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![ ValueInfoProto { name: INPUT_X.to_string(), doc_string: "".to_string(), r#type: None, }, ValueInfoProto { name: INPUT_Y.to_string(), doc_string: "".to_string(), r#type: None, }, ], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let x = Tensor::from_vec( // vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu, )?; let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), x); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let results = z.to_vec2::()?; assert_eq!(results, vec![vec![1.0, 3.0], vec![2.0, 4.0]]); Ok(()) } // "Dropout" #[test] fn test_dropout_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Dropout".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: 
vec![], input: vec![ ValueInfoProto { name: INPUT_X.to_string(), doc_string: "".to_string(), r#type: None, }, ValueInfoProto { name: INPUT_Y.to_string(), doc_string: "".to_string(), r#type: None, }, ], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let x = Tensor::from_vec( // vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu, )?; let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), x); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let results = z.to_vec2::()?; assert_eq!(results, vec![vec![1.0, 2.0], vec![3.0, 4.0]]); Ok(()) } // "Flatten" #[test] fn test_flatten_operation() -> Result<()> { let mut att_axis = AttributeProto { name: "axis".to_string(), ref_attr_name: "axis".to_string(), i: 0, doc_string: "axis".to_string(), r#type: 2, f: 0.0, s: vec![], t: None, g: None, sparse_tensor: None, tp: None, floats: vec![], ints: vec![], strings: vec![], tensors: vec![], graphs: vec![], sparse_tensors: vec![], type_protos: vec![], }; let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Flatten".to_string(), domain: "".to_string(), attribute: vec![att_axis.clone()], input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![ ValueInfoProto { name: INPUT_X.to_string(), doc_string: "".to_string(), r#type: None, }, ValueInfoProto { name: INPUT_Y.to_string(), doc_string: "".to_string(), r#type: None, }, ], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: 
vec![], })); let x = Tensor::from_vec( vec![ 1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32, 6.0f32, 7.0f32, 8.0f32, ], &[2, 2, 2], &Device::Cpu, )?; let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), x); let eval = candle_onnx::simple_eval(&manual_graph, inputs.clone())?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let results = z.to_vec2::()?; assert_eq!(results, vec![vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]]); att_axis.i = 1; let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Flatten".to_string(), domain: "".to_string(), attribute: vec![att_axis.clone()], input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![ ValueInfoProto { name: INPUT_X.to_string(), doc_string: "".to_string(), r#type: None, }, ValueInfoProto { name: INPUT_Y.to_string(), doc_string: "".to_string(), r#type: None, }, ], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let results = z.to_vec2::()?; assert_eq!( results, vec![vec![1.0, 2.0, 3.0, 4.0], vec![5.0, 6.0, 7.0, 8.0]] ); Ok(()) } // Below are ops that are implemented but not tested yet // "MaxPool" // #[test] // "AveragePool" // #[test] // "BatchNormalization" // #[test] // "Squeeze" // #[test] // "ConstantOfShape" #[test] fn test_constant_of_shape() -> Result<()> { // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-31 test(&[4i64, 3, 2], Some(1.), &[1., 1., 1.])?; // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-31 test(&[0.], Some(0i64), &[0i64])?; // 
"value" defaults to 0 f32 test(&[1i64, 2, 3, 4], None as Option, &[0., 0., 0., 0.])?; fn test( input: impl NdArray, value: Option, expected: impl NdArray, ) -> Result<()> { let mut attribute = vec![]; if let Some(value) = value { let tensor = Tensor::new(value, &Device::Cpu)?; let (value, data_type) = match tensor.dtype() { DType::U8 => ( tensor.to_vec0::()?.to_le_bytes().to_vec(), DataType::Uint8, ), DType::U32 => ( tensor.to_vec0::()?.to_le_bytes().to_vec(), DataType::Uint32, ), DType::I64 => ( tensor.to_vec0::()?.to_le_bytes().to_vec(), DataType::Int64, ), DType::F32 => ( tensor.to_vec0::()?.to_le_bytes().to_vec(), DataType::Float, ), DType::F64 => ( tensor.to_vec0::()?.to_le_bytes().to_vec(), DataType::Double, ), _ => panic!("unsupported DType in test"), }; let tensor = onnx::TensorProto { data_type: data_type.into(), dims: tensor.dims().iter().map(|v| *v as i64).collect(), raw_data: value, segment: None, float_data: vec![], int32_data: vec![], string_data: vec![], int64_data: vec![], name: "".to_string(), doc_string: "".to_string(), external_data: vec![], data_location: 0, double_data: vec![], uint64_data: vec![], }; attribute.push(AttributeProto { name: "value".to_string(), ref_attr_name: "value".to_string(), i: 0, doc_string: "value".to_string(), r#type: AttributeType::Tensor.into(), f: 0.0, s: vec![], t: Some(tensor), g: None, sparse_tensor: None, tp: None, floats: vec![], ints: vec![], strings: vec![], tensors: vec![], graphs: vec![], sparse_tensors: vec![], type_protos: vec![], }) } let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "ConstantOfShape".to_string(), domain: "".to_string(), attribute, input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], 
doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), Tensor::new(input, &Device::Cpu)?); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval .get(OUTPUT_Z) .expect("Output 'z' not found") .to_dtype(DType::F64)?; let expected = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?; match expected.dims().len() { 0 => assert_eq!(z.to_vec0::()?, expected.to_vec0::()?), 1 => assert_eq!(z.to_vec1::()?, expected.to_vec1::()?), 2 => assert_eq!(z.to_vec2::()?, expected.to_vec2::()?), 3 => assert_eq!(z.to_vec3::()?, expected.to_vec3::()?), _ => unreachable!(), }; Ok(()) } Ok(()) } // "Unsqueeze" // #[test] // "Clip" // #[test] // "Gather" #[test] fn test_gather_operation() -> Result<()> { // test taken from https://onnx.ai/onnx/operators/onnx__Gather.html#summary. test( &[[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]], &[[0i64, 1], [1, 2]], 0, &[[[1.0, 1.2], [2.3, 3.4]], [[2.3, 3.4], [4.5, 5.7]]], )?; // test taken from https://onnx.ai/onnx/operators/onnx__Gather.html#summary. test( &[[1.0, 1.2, 1.9], [2.3, 3.4, 3.9], [4.5, 5.7, 5.9]], &[[0i64, 2]], 1, &[[[1.0, 1.9]], [[2.3, 3.9]], [[4.5, 5.9]]], )?; // all the tests below are generated from numpy.take, which works like // onnx's Gather operation. 
test(&[1.0, 2.0, 3.0, 4.0], 3i64, 0, 4.0)?; test(&[[1.0, 2.0, 3.0, 4.0]], 3i64, 1, &[4.0])?; test( &[[1.0], [2.0], [3.0], [4.0]], &[3i64, 2], 0, &[[4.0], [3.0]], )?; test( &[ [[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]], [[9.0, 10.0], [11.0, 12.0]], [[13.0, 14.0], [15.0, 16.0]], ], 1i64, 0, &[[5.0, 6.0], [7.0, 8.0]], )?; test( &[ [[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]], [[9.0, 10.0], [11.0, 12.0]], [[13.0, 14.0], [15.0, 16.0]], ], &[1i64, 0], 0, &[[[5.0, 6.0], [7.0, 8.0]], [[1.0, 2.0], [3.0, 4.0]]], )?; fn test( data: impl NdArray, indices: impl NdArray, axis: i64, expected: impl NdArray, ) -> Result<()> { let att_axis = AttributeProto { name: "axis".to_string(), ref_attr_name: "axis".to_string(), i: axis, doc_string: "axis".to_string(), r#type: 2, f: 0.0, s: vec![], t: None, g: None, sparse_tensor: None, tp: None, floats: vec![], ints: vec![], strings: vec![], tensors: vec![], graphs: vec![], sparse_tensors: vec![], type_protos: vec![], }; let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Gather".to_string(), domain: "".to_string(), attribute: vec![att_axis], input: vec![INPUT_X.to_string(), INPUT_Y.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?); inputs.insert(INPUT_Y.to_string(), Tensor::new(indices, &Device::Cpu)?); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let expected = Tensor::new(expected, &Device::Cpu)?; match expected.dims().len() { 0 => 
assert_eq!(z.to_vec0::()?, expected.to_vec0::()?), 1 => assert_eq!(z.to_vec1::()?, expected.to_vec1::()?), 2 => assert_eq!(z.to_vec2::()?, expected.to_vec2::()?), 3 => assert_eq!(z.to_vec3::()?, expected.to_vec3::()?), _ => unreachable!(), }; Ok(()) } Ok(()) } // "Shape" #[test] fn test_shape_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Shape".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![ValueInfoProto { name: INPUT_X.to_string(), doc_string: "".to_string(), r#type: None, }], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let x = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu)?; let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), x); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let results = z.to_vec1::()?; assert_eq!(results, vec![2, 2]); Ok(()) } // "Conv" // #[test] // "Concat" // #[test] // "Abs" #[test] fn test_abs_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Abs".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![ ValueInfoProto { name: INPUT_X.to_string(), doc_string: "".to_string(), r#type: None, }, ValueInfoProto { name: INPUT_Y.to_string(), doc_string: "".to_string(), r#type: None, }, ], output: 
vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let x = Tensor::from_vec( vec![-1.0f32, 2.0f32, -3.0f32, 4.0f32], &[2, 2], &Device::Cpu, )?; let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), x); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let results = z.to_vec2::()?; assert_eq!(results, vec![vec![1.0, 2.0], vec![3.0, 4.0]]); Ok(()) } // "Cos" #[test] fn test_cos_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Cos".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![ ValueInfoProto { name: INPUT_X.to_string(), doc_string: "".to_string(), r#type: None, }, ValueInfoProto { name: INPUT_Y.to_string(), doc_string: "".to_string(), r#type: None, }, ], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?; let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), x); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let results = z.to_vec2::()?; assert_eq!( results, vec![vec![1.0, 0.54030234], vec![-0.41614684, -0.9899925]] ); Ok(()) } // "Sin" #[test] fn test_sin_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: 
vec![NodeProto { op_type: "Sin".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![ ValueInfoProto { name: INPUT_X.to_string(), doc_string: "".to_string(), r#type: None, }, ValueInfoProto { name: INPUT_Y.to_string(), doc_string: "".to_string(), r#type: None, }, ], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?; let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), x); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let results = z.to_vec2::()?; assert_eq!(results, vec![vec![0.0, 0.841471], vec![0.9092974, 0.14112]]); Ok(()) }

/// Declares a graph input or output by name only, leaving its type unspecified.
fn untyped_value_info(name: &str) -> ValueInfoProto {
    ValueInfoProto {
        name: name.to_string(),
        doc_string: "".to_string(),
        r#type: None,
    }
}

/// Builds a model whose graph is a single `op_type` node writing to `OUTPUT_Z`.
///
/// * `attribute` - attributes attached to the node, in the given order.
/// * `node_inputs` - names the node reads, in operator argument order.
/// * `graph_inputs` - names declared as graph-level inputs.
///
/// `OUTPUT_Z` is the only declared graph output.
fn single_node_model(
    op_type: &str,
    attribute: Vec<AttributeProto>,
    node_inputs: &[&str],
    graph_inputs: &[&str],
) -> ModelProto {
    create_model_proto_with_graph(Some(GraphProto {
        node: vec![NodeProto {
            op_type: op_type.to_string(),
            domain: "".to_string(),
            attribute,
            input: node_inputs.iter().map(|&name| name.to_string()).collect(),
            output: vec![OUTPUT_Z.to_string()],
            name: "".to_string(),
            doc_string: "".to_string(),
        }],
        name: "".to_string(),
        initializer: vec![],
        input: graph_inputs
            .iter()
            .copied()
            .map(untyped_value_info)
            .collect(),
        output: vec![untyped_value_info(OUTPUT_Z)],
        value_info: vec![],
        doc_string: "".to_string(),
        sparse_initializer: vec![],
        quantization_annotation: vec![],
    }))
}

/// Returns an attribute called `name` of the given `kind` with every payload field empty.
///
/// Callers fill in the one payload field they need with struct-update syntax, e.g.
/// `AttributeProto { i: 1, ..blank_attribute("keepdims", AttributeType::Int) }`.
fn blank_attribute(name: &str, kind: AttributeType) -> AttributeProto {
    AttributeProto {
        name: name.to_string(),
        ref_attr_name: name.to_string(),
        i: 0,
        doc_string: name.to_string(),
        r#type: kind as i32,
        f: 0.0,
        s: vec![],
        t: None,
        g: None,
        sparse_tensor: None,
        tp: None,
        floats: vec![],
        ints: vec![],
        strings: vec![],
        tensors: vec![],
        graphs: vec![],
        sparse_tensors: vec![],
        type_protos: vec![],
    }
}

/// Evaluates `model` on `inputs` and returns the tensor bound to `OUTPUT_Z`.
///
/// Panics if the evaluation does not yield exactly one output or if that output is not
/// named `OUTPUT_Z`; evaluation errors are propagated.
fn eval_single_output(model: &ModelProto, inputs: HashMap<String, Tensor>) -> Result<Tensor> {
    let mut eval = candle_onnx::simple_eval(model, inputs)?;
    assert_eq!(eval.len(), 1);
    Ok(eval.remove(OUTPUT_Z).expect("Output 'z' not found"))
}

/// Applies the unary operator `op_type` to a 2x2 f32 tensor built from `data` (row-major)
/// and returns the result as rows.
///
/// Only `INPUT_X` is declared as a graph input: a unary node reads a single tensor, so
/// declaring a second, never-supplied input would describe a graph that cannot be fed.
fn eval_unary_2x2(op_type: &str, data: [f32; 4]) -> Result<Vec<Vec<f32>>> {
    let model = single_node_model(op_type, vec![], &[INPUT_X], &[INPUT_X]);
    let x = Tensor::from_vec(data.to_vec(), &[2, 2], &Device::Cpu)?;
    let inputs = HashMap::from([(INPUT_X.to_string(), x)]);
    eval_single_output(&model, inputs)?.to_vec2::<f32>()
}

/// Asserts that two f64 tensors hold exactly the same values.
///
/// The rank of `expected` selects the comparison; ranks 0 through 3 are supported and any
/// other rank panics.
fn assert_f64_tensor_eq(actual: &Tensor, expected: &Tensor) -> Result<()> {
    match expected.dims().len() {
        0 => assert_eq!(actual.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
        1 => assert_eq!(actual.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
        2 => assert_eq!(actual.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
        3 => assert_eq!(actual.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
        rank => panic!("expected a tensor of rank 0 to 3, got rank {rank}"),
    }
    Ok(())
}

// "Neg"
#[test]
fn test_neg_operation() -> Result<()> {
    let results = eval_unary_2x2("Neg", [1.0, 2.0, 3.0, 4.0])?;
    assert_eq!(results, vec![vec![-1.0, -2.0], vec![-3.0, -4.0]]);
    Ok(())
}

// "Erf"
// #[test]

// "Tanh"
#[test]
fn test_tanh_operation() -> Result<()> {
    let results = eval_unary_2x2("Tanh", [0.0, 1.0, 2.0, 3.0])?;
    assert_eq!(
        results,
        vec![vec![0.0, 0.7615942], vec![0.9640276, 0.9950548]]
    );
    Ok(())
}

// "Sigmoid"
#[test]
fn test_sigmoid_operation() -> Result<()> {
    let results = eval_unary_2x2("Sigmoid", [0.0, 1.0, 2.0, 3.0])?;
    assert_eq!(
        results,
        vec![vec![0.5, 0.7310586], vec![0.880797, 0.95257413]]
    );
    Ok(())
}

// "Gelu"
#[test]
fn test_gelu_operation() -> Result<()> {
    let results = eval_unary_2x2("Gelu", [0.0, 1.0, 2.0, 3.0])?;
    assert_eq!(
        results,
        vec![vec![0.0, 0.8413448], vec![1.9544997, 2.9959502]]
    );
    Ok(())
}

// "Relu"
#[test]
fn test_relu_operation() -> Result<()> {
    let results = eval_unary_2x2("Relu", [-1.0, 1.0, -2.0, 3.0])?;
    assert_eq!(results, vec![vec![0.0, 1.0], vec![0.0, 3.0]]);
    Ok(())
}

// "Constant"
// #[test]

// "Cast"
// #[test]

// "ReduceMean"
#[test]
fn test_reduce_mean() -> Result<()> {
    /// Runs ReduceMean over `data` with the given `axes` and `keepdims` attributes and
    /// compares the result with `expected`.
    fn test(
        data: impl NdArray,
        axes: Option<Vec<i64>>,
        keepdims: i64,
        expected: impl NdArray,
    ) -> Result<()> {
        let att_keepdims = AttributeProto {
            i: keepdims,
            ..blank_attribute("keepdims", AttributeType::Int)
        };
        // The "axes" attribute is left off the node entirely when no axes are requested,
        // so those cases exercise the operator's behaviour without it.
        let attribute = match axes {
            Some(axes) => vec![
                AttributeProto {
                    ints: axes,
                    ..blank_attribute("axes", AttributeType::Ints)
                },
                att_keepdims,
            ],
            None => vec![att_keepdims],
        };

        let model = single_node_model("ReduceMean", attribute, &[INPUT_X], &[]);
        let inputs = HashMap::from([(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?)]);
        let z = eval_single_output(&model, inputs)?;

        let expected = Tensor::new(expected, &Device::Cpu)?;
        assert_f64_tensor_eq(&z, &expected)
    }

    // Shared 3x2x2 fixture used by the ONNX reference examples.
    const DATA: [[[f64; 2]; 2]; 3] = [
        [[5., 1.], [20., 2.]],
        [[30., 1.], [40., 2.]],
        [[55., 1.], [60., 2.]],
    ];

    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-120 default_axes_keepdims
    test(&DATA, None, 1, &[[[18.25]]])?;

    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-120 do_no_keepdims
    test(
        &DATA,
        Some(vec![1]),
        0,
        &[[12.5, 1.5], [35.0, 1.5], [57.5, 1.5]],
    )?;

    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-120 keepdims
    test(
        &DATA,
        Some(vec![1]),
        1,
        &[[[12.5, 1.5]], [[35.0, 1.5]], [[57.5, 1.5]]],
    )?;

    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-120 negative_axes_keepdims
    test(
        &DATA,
        Some(vec![-2]),
        1,
        &[[[12.5, 1.5]], [[35.0, 1.5]], [[57.5, 1.5]]],
    )?;

    // All the test data below was generated based on numpy's np.mean
    test(&DATA, Some(vec![1, 2]), 0, &[7.0, 18.25, 29.5])?;
    test(&DATA, Some(vec![1, 2]), 1, &[[[7.0]], [[18.25]], [[29.5]]])?;
    test(&[1., 2., 3.], None, 1, &[2.0])?;

    Ok(())
}

// "Sqrt"
#[test]
fn test_sqrt() -> Result<()> {
    /// Runs Sqrt on `data` and compares the result with `expected`.
    fn test(data: impl NdArray, expected: impl NdArray) -> Result<()> {
        let model = single_node_model("Sqrt", vec![], &[INPUT_X], &[]);
        let inputs = HashMap::from([(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?)]);
        let z = eval_single_output(&model, inputs)?;

        let expected = Tensor::new(expected, &Device::Cpu)?;
        assert_f64_tensor_eq(&z, &expected)
    }

    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-155
    test(&[1., 4., 9.], &[1., 2., 3.])?;

    Ok(())
}

// "RandomUniform"
#[test]
fn test_random_uniform() -> Result<()> {
    /// Generates a RandomUniform tensor of the given `shape` and checks that every value
    /// lies within the requested bounds (0.0 and 1.0 are assumed when a bound is omitted).
    fn test(shape: Vec<i64>, low: Option<f32>, high: Option<f32>) -> Result<()> {
        let mut attrs = vec![
            AttributeProto {
                ints: shape,
                ..blank_attribute("shape", AttributeType::Ints)
            },
            AttributeProto {
                i: DataType::Double as i64,
                ..blank_attribute("dtype", AttributeType::Int)
            },
        ];
        // The bounds are only attached when the caller provides them, so the `None`
        // cases run the operator without a "low" / "high" attribute.
        if let Some(low) = low {
            attrs.push(AttributeProto {
                f: low,
                ..blank_attribute("low", AttributeType::Float)
            });
        }
        if let Some(high) = high {
            attrs.push(AttributeProto {
                f: high,
                ..blank_attribute("high", AttributeType::Float)
            });
        }

        let model = single_node_model("RandomUniform", attrs, &[], &[]);
        let z = eval_single_output(&model, HashMap::new())?;

        let values: Vec<f64> = z.flatten_all()?.to_vec1()?;
        let min = values
            .iter()
            .copied()
            .reduce(f64::min)
            .expect("RandomUniform produced no values");
        let max = values
            .iter()
            .copied()
            .reduce(f64::max)
            .expect("RandomUniform produced no values");

        assert!(min >= f64::from(low.unwrap_or(0.0)));
        assert!(max <= f64::from(high.unwrap_or(1.0)));
        // A constant output would satisfy the bounds without being random at all.
        assert_ne!(min, max);
        Ok(())
    }

    test(vec![3, 2, 1, 4], None, None)?;
    test(vec![2, 2, 2, 2], Some(-10.0), None)?;
    test(vec![2, 2, 2, 2], None, Some(10.0))?;
    test(vec![1, 2, 3, 4], Some(-10.0), Some(10.0))?;

    Ok(())
}

// "Range"
#[test]
fn test_range() -> Result<()> {
    /// Runs Range(`start`, `limit`, `delta`) and compares the result with `expected`.
    /// Both sides are converted to f64 so integer and float cases share one comparison.
    fn test(
        start: impl NdArray,
        limit: impl NdArray,
        delta: impl NdArray,
        expected: impl NdArray,
    ) -> Result<()> {
        let model = single_node_model("Range", vec![], &[INPUT_X, INPUT_Y, INPUT_A], &[]);
        let inputs = HashMap::from([
            (INPUT_X.to_string(), Tensor::new(start, &Device::Cpu)?),
            (INPUT_Y.to_string(), Tensor::new(limit, &Device::Cpu)?),
            (INPUT_A.to_string(), Tensor::new(delta, &Device::Cpu)?),
        ]);
        let z = eval_single_output(&model, inputs)?.to_dtype(DType::F64)?;

        let expected = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?;
        assert_f64_tensor_eq(&z, &expected)
    }

    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-113
    test(1., 5., 2., &[1., 3.])?;

    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-113
    test(10i64, 6i64, -3i64, &[10i64, 7i64])?;

    Ok(())
}

// "Greater"
#[test]
fn test_greater() -> Result<()> {
    /// Runs Greater(`a`, `b`) and compares the result with `expected`.
    /// Both sides are converted to f64 before the comparison.
    fn test(a: impl NdArray, b: impl NdArray, expected: impl NdArray) -> Result<()> {
        let model = single_node_model("Greater", vec![], &[INPUT_X, INPUT_Y], &[]);
        let inputs = HashMap::from([
            (INPUT_X.to_string(), Tensor::new(a, &Device::Cpu)?),
            (INPUT_Y.to_string(), Tensor::new(b, &Device::Cpu)?),
        ]);
        let z = eval_single_output(&model, inputs)?.to_dtype(DType::F64)?;

        let expected = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?;
        assert_f64_tensor_eq(&z, &expected)
    }

    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-63
    test(&[1., 2., 3.], &[3., 2., 1.], &[0u8, 0, 1])?;

    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-63
    test(&[1., 2., 3.], 2., &[0u8, 0, 1])?;

    Ok(())
}

// "Less"
#[test]
fn test_less() -> Result<()> {
    /// Runs Less(`a`, `b`) and compares the result with `expected`.
    /// Both sides are converted to f64 before the comparison.
    fn test(a: impl NdArray, b: impl NdArray, expected: impl NdArray) -> Result<()> {
        let model = single_node_model("Less", vec![], &[INPUT_X, INPUT_Y], &[]);
        let inputs = HashMap::from([
            (INPUT_X.to_string(), Tensor::new(a, &Device::Cpu)?),
            (INPUT_Y.to_string(), Tensor::new(b, &Device::Cpu)?),
        ]);
        let z = eval_single_output(&model, inputs)?.to_dtype(DType::F64)?;

        let expected = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?;
        assert_f64_tensor_eq(&z, &expected)
    }

    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-81
    test(&[1., 2., 3.], &[3., 2., 1.], &[1u8, 0, 0])?;

    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-81
    test(&[1., 2., 3.], 2., &[1u8, 0, 0])?;

    Ok(())
}

// "Log"
#[test]
fn test_log() -> Result<()> {
    /// Runs Log on `data` and compares the result with `expected`.
    fn test(data: impl NdArray, expected: impl NdArray) -> Result<()> {
        let model = single_node_model("Log", vec![], &[INPUT_X], &[]);
        let inputs = HashMap::from([(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?)]);
        let z = eval_single_output(&model, inputs)?;

        let expected = Tensor::new(expected, &Device::Cpu)?;
        assert_f64_tensor_eq(&z, &expected)
    }

    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-82
    test(&[1., 10.], &[0., std::f64::consts::LN_10])?;

    Ok(())
}

// "Min"
#[test]
fn test_min() -> Result<()> {
    /// Runs Min(`a`, `b`, `c`) and compares the result with `expected`.
    fn test(
        a: impl NdArray,
        b: impl NdArray,
        c: impl NdArray,
        expected: impl NdArray,
    ) -> Result<()> {
        let model = single_node_model("Min", vec![], &[INPUT_X, INPUT_Y, INPUT_A], &[]);
        let inputs = HashMap::from([
            (INPUT_X.to_string(), Tensor::new(a, &Device::Cpu)?),
            (INPUT_Y.to_string(), Tensor::new(b, &Device::Cpu)?),
            (INPUT_A.to_string(), Tensor::new(c, &Device::Cpu)?),
        ]);
        let z = eval_single_output(&model, inputs)?;

        let expected = Tensor::new(expected, &Device::Cpu)?;
        assert_f64_tensor_eq(&z, &expected)
    }

    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-94
    test(&[3., 2., 1.], &[1., 4., 4.], &[2., 5., 0.], &[1., 2., 0.])?;

    Ok(())
}

// "Where"
#[test]
fn test_where() -> Result<()> {
    /// Runs Where(`condition`, `x`, `y`) and compares the result with `expected`.
    /// Both sides are converted to f64 so integer and float cases share one comparison.
    fn test(
        condition: impl NdArray,
        x: impl NdArray,
        y: impl NdArray,
        expected: impl NdArray,
    ) -> Result<()> {
        let model = single_node_model("Where", vec![], &[INPUT_X, INPUT_Y, INPUT_A], &[]);
        let inputs = HashMap::from([
            (INPUT_X.to_string(), Tensor::new(condition, &Device::Cpu)?),
            (INPUT_Y.to_string(), Tensor::new(x, &Device::Cpu)?),
            (INPUT_A.to_string(), Tensor::new(y, &Device::Cpu)?),
        ]);
        let z = eval_single_output(&model, inputs)?.to_dtype(DType::F64)?;

        let expected = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?;
        assert_f64_tensor_eq(&z, &expected)
    }

    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-173
    test(
        &[[1u8, 0], [1, 1]],
        &[[1i64, 2], [3, 4]],
        &[[9i64, 8], [7, 6]],
        &[[1i64, 8], [3, 4]],
    )?;

    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-173
    test(
        &[[1u8, 0], [1, 1]],
        &[[1., 2.], [3., 4.]],
        &[[9., 8.], [7., 6.]],
        &[[1., 8.], [3., 4.]],
    )?;

    Ok(())
}