#[cfg(feature = "mkl")] extern crate intel_mkl_src; #[cfg(feature = "accelerate")] extern crate accelerate_src; use candle::{DType, Device, NdArray, Result, Tensor}; use candle_onnx::onnx; use candle_onnx::onnx::attribute_proto::AttributeType; use candle_onnx::onnx::tensor_proto::DataType; use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto}; use std::collections::HashMap; const INPUT_X: &str = "x"; const INPUT_Y: &str = "y"; const INPUT_A: &str = "a"; const OUTPUT_Z: &str = "z"; fn create_model_proto_with_graph(graph: Option) -> ModelProto { ModelProto { metadata_props: vec![], training_info: vec![], functions: vec![], ir_version: 0, opset_import: vec![], producer_name: "".to_string(), producer_version: "".to_string(), domain: "".to_string(), model_version: 0, doc_string: "".to_string(), graph, } } #[test] fn test_evaluation_fails_without_defined_graph() -> Result<()> { let manual_graph = create_model_proto_with_graph(None); let inputs: HashMap = HashMap::new(); match candle_onnx::simple_eval(&manual_graph, inputs) { Err(err) => assert_eq!(err.to_string(), "no graph defined in proto"), Ok(_) => panic!("Expected an error due to undefined graph"), } Ok(()) } // "Add" #[test] fn test_add_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Add".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string(), INPUT_Y.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), Tensor::new(&[2.], &Device::Cpu)?); inputs.insert(INPUT_Y.to_string(), Tensor::new(&[2.], &Device::Cpu)?); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let first = z .to_vec1::()? .to_vec() .get(0) .expect("Failed to get first element") .clone(); assert_eq!(first, 4.0f64); Ok(()) } // "Sub" #[test] fn test_sub_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Sub".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string(), INPUT_Y.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), Tensor::new(&[2.], &Device::Cpu)?); inputs.insert(INPUT_Y.to_string(), Tensor::new(&[2.], &Device::Cpu)?); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let first = z .to_vec1::()? .to_vec() .get(0) .expect("Failed to get first element") .clone(); assert_eq!(first, 0.0f64); Ok(()) } // "Mul" #[test] fn test_mul_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Mul".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string(), INPUT_Y.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), Tensor::new(&[2.], &Device::Cpu)?); inputs.insert(INPUT_Y.to_string(), Tensor::new(&[2.], &Device::Cpu)?); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let first = z .to_vec1::()? .to_vec() .get(0) .expect("Failed to get first element") .clone(); assert_eq!(first, 4.0f64); Ok(()) } // "Div" #[test] fn test_div_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Div".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string(), INPUT_Y.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), Tensor::new(&[2.], &Device::Cpu)?); inputs.insert(INPUT_Y.to_string(), Tensor::new(&[2.], &Device::Cpu)?); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let first = z .to_vec1::()? .to_vec() .get(0) .expect("Failed to get first element") .clone(); assert_eq!(first, 1.0f64); Ok(()) } // "Exp" #[test] fn test_exp_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Exp".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let x = Tensor::from_vec(vec![-1.0f32, 0.0f32, 1.0f32, 2.0f32], &[2, 2], &Device::Cpu)?; let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), x); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let results = z.to_vec2::()?; assert_eq!(results[0][0], 0.36787944f32); assert_eq!(results[0][1], 1.0f32); assert_eq!(results[1], vec![std::f32::consts::E, 7.38905609f32]); Ok(()) } // "Equal" #[test] fn test_equal_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Equal".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string(), INPUT_Y.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), Tensor::new(&[2.], &Device::Cpu)?); inputs.insert(INPUT_Y.to_string(), Tensor::new(&[2.], &Device::Cpu)?); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let first = z.to_dtype(candle::DType::U8)?.to_vec1::()?.to_vec()[0]; assert_eq!(first, 1); Ok(()) } // "Not" #[test] fn test_not_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Not".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), Tensor::new(&[0.], &Device::Cpu)?); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let first = z.to_dtype(candle::DType::U8)?.to_vec1::()?.to_vec()[0]; assert_eq!(first, 1); Ok(()) } // "MatMul" #[test] fn test_matmul_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "MatMul".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string(), INPUT_Y.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let mut inputs: HashMap = HashMap::new(); inputs.insert( INPUT_X.to_string(), Tensor::from_vec( // vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu, )?, ); inputs.insert( INPUT_Y.to_string(), Tensor::from_vec( // vec![5.0f32, 6.0f32, 7.0f32, 8.0f32], &[2, 2], &Device::Cpu, )?, ); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let results = z.to_vec2::()?; assert_eq!(results, vec![vec![19.0, 22.0], vec![43.0, 50.0]]); Ok(()) } // "Reshape" #[test] fn test_reshape_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Reshape".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string(), INPUT_Y.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![ ValueInfoProto { name: INPUT_X.to_string(), doc_string: "".to_string(), r#type: None, }, ValueInfoProto { name: INPUT_Y.to_string(), doc_string: "".to_string(), r#type: None, }, ], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let x = Tensor::from_vec( // vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu, )?; let y = Tensor::from_vec( // vec![4i64], &[1], &Device::Cpu, )?; let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), x); inputs.insert(INPUT_Y.to_string(), y); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let results = z.to_vec1::()?; assert_eq!(results, vec![1.0, 2.0, 3.0, 4.0]); Ok(()) } // "LogSoftmax" #[test] fn test_logsoftmax_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "LogSoftmax".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![ ValueInfoProto { name: INPUT_X.to_string(), doc_string: "".to_string(), r#type: None, }, ValueInfoProto { name: INPUT_Y.to_string(), doc_string: "".to_string(), r#type: None, }, ], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let x = Tensor::from_vec( // vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu, )?; let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), x); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let results = z.to_vec2::()?; assert_eq!( results, vec![vec![0.26894143, 0.7310586], vec![0.26894143, 0.7310586]] ); Ok(()) } // "Softmax" #[test] fn test_softmax_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Softmax".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![ ValueInfoProto { name: INPUT_X.to_string(), doc_string: "".to_string(), r#type: None, }, ValueInfoProto { name: INPUT_Y.to_string(), doc_string: "".to_string(), r#type: None, }, ], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let x = Tensor::from_vec( // vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu, )?; let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), x); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let results = z.to_vec2::()?; assert_eq!( results, vec![vec![0.26894143, 0.7310586], vec![0.26894143, 0.7310586]] ); Ok(()) } // "Transpose" #[test] fn test_transpose_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Transpose".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![ ValueInfoProto { name: INPUT_X.to_string(), doc_string: "".to_string(), r#type: None, }, ValueInfoProto { name: INPUT_Y.to_string(), doc_string: "".to_string(), r#type: None, }, ], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let x = Tensor::from_vec( // vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu, )?; let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), x); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let results = z.to_vec2::()?; assert_eq!(results, vec![vec![1.0, 3.0], vec![2.0, 4.0]]); Ok(()) } // "Dropout" #[test] fn test_dropout_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Dropout".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![ ValueInfoProto { name: INPUT_X.to_string(), doc_string: "".to_string(), r#type: None, }, ValueInfoProto { name: INPUT_Y.to_string(), doc_string: "".to_string(), r#type: None, }, ], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let x = Tensor::from_vec( // vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu, )?; let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), x); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let results = z.to_vec2::()?; assert_eq!(results, vec![vec![1.0, 2.0], vec![3.0, 4.0]]); Ok(()) } // "Flatten" #[test] fn test_flatten_operation() -> Result<()> { let mut att_axis = AttributeProto { name: "axis".to_string(), ref_attr_name: "axis".to_string(), i: 0, doc_string: "axis".to_string(), r#type: 2, f: 0.0, s: vec![], t: None, g: None, sparse_tensor: None, tp: None, floats: vec![], ints: vec![], strings: vec![], tensors: vec![], graphs: vec![], sparse_tensors: vec![], type_protos: vec![], }; let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Flatten".to_string(), domain: "".to_string(), attribute: vec![att_axis.clone()], input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![ ValueInfoProto { name: INPUT_X.to_string(), doc_string: "".to_string(), r#type: None, }, ValueInfoProto { name: INPUT_Y.to_string(), doc_string: "".to_string(), r#type: None, }, ], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let x = Tensor::from_vec( vec![ 1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32, 6.0f32, 7.0f32, 8.0f32, ], &[2, 2, 2], &Device::Cpu, )?; let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), x); let eval = candle_onnx::simple_eval(&manual_graph, inputs.clone())?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let results = z.to_vec2::()?; assert_eq!(results, vec![vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]]); att_axis.i = 1; let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Flatten".to_string(), domain: "".to_string(), attribute: vec![att_axis.clone()], input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![ ValueInfoProto { name: INPUT_X.to_string(), doc_string: "".to_string(), r#type: None, }, ValueInfoProto { name: INPUT_Y.to_string(), doc_string: "".to_string(), r#type: None, }, ], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let results = z.to_vec2::()?; assert_eq!( results, vec![vec![1.0, 2.0, 3.0, 4.0], vec![5.0, 6.0, 7.0, 8.0]] ); Ok(()) } // Below are ops that are implemented but not tested yet // "MaxPool" // #[test] // "AveragePool" // #[test] // "BatchNormalization" // #[test] // "Squeeze" // #[test] // "ConstantOfShape" #[test] fn test_constant_of_shape() -> Result<()> { // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-31 test(&[4i64, 3, 2], Some(1.), &[1., 1., 1.])?; // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-31 test(&[0.], Some(0i64), &[0i64])?; // "value" defaults to 0 f32 test(&[1i64, 2, 3, 4], None as Option, &[0., 0., 0., 0.])?; fn test( input: impl NdArray, value: Option, expected: impl NdArray, ) -> Result<()> { let mut attribute = vec![]; if let Some(value) = value { let tensor = Tensor::new(value, &Device::Cpu)?; let (value, data_type) = match tensor.dtype() { DType::U8 => ( tensor.to_vec0::()?.to_le_bytes().to_vec(), DataType::Uint8, ), DType::U32 => ( tensor.to_vec0::()?.to_le_bytes().to_vec(), DataType::Uint32, ), DType::I64 => ( tensor.to_vec0::()?.to_le_bytes().to_vec(), DataType::Int64, ), DType::F32 => ( tensor.to_vec0::()?.to_le_bytes().to_vec(), DataType::Float, ), DType::F64 => ( tensor.to_vec0::()?.to_le_bytes().to_vec(), DataType::Double, ), _ => panic!("unsupported DType in test"), }; let tensor = onnx::TensorProto { data_type: data_type.into(), dims: tensor.dims().iter().map(|v| *v as i64).collect(), raw_data: value, segment: None, float_data: vec![], int32_data: vec![], string_data: vec![], int64_data: vec![], name: "".to_string(), doc_string: "".to_string(), external_data: vec![], data_location: 0, double_data: vec![], uint64_data: vec![], }; attribute.push(AttributeProto { name: "value".to_string(), ref_attr_name: "value".to_string(), i: 0, doc_string: "value".to_string(), r#type: AttributeType::Tensor.into(), f: 0.0, s: vec![], t: Some(tensor), g: None, sparse_tensor: None, tp: None, floats: vec![], ints: vec![], strings: vec![], tensors: vec![], graphs: vec![], sparse_tensors: vec![], type_protos: vec![], }) } let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "ConstantOfShape".to_string(), domain: "".to_string(), attribute, input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), Tensor::new(input, &Device::Cpu)?); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval .get(OUTPUT_Z) .expect("Output 'z' not found") .to_dtype(DType::F64)?; let expected = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?; match expected.dims().len() { 0 => assert_eq!(z.to_vec0::()?, expected.to_vec0::()?), 1 => assert_eq!(z.to_vec1::()?, expected.to_vec1::()?), 2 => assert_eq!(z.to_vec2::()?, expected.to_vec2::()?), 3 => assert_eq!(z.to_vec3::()?, expected.to_vec3::()?), _ => unreachable!(), }; Ok(()) } Ok(()) } // "Unsqueeze" // #[test] // "Clip" // #[test] // "Gather" #[test] fn test_gather_operation() -> Result<()> { // test taken from https://onnx.ai/onnx/operators/onnx__Gather.html#summary. test( &[[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]], &[[0i64, 1], [1, 2]], 0, &[[[1.0, 1.2], [2.3, 3.4]], [[2.3, 3.4], [4.5, 5.7]]], )?; // test taken from https://onnx.ai/onnx/operators/onnx__Gather.html#summary. test( &[[1.0, 1.2, 1.9], [2.3, 3.4, 3.9], [4.5, 5.7, 5.9]], &[[0i64, 2]], 1, &[[[1.0, 1.9]], [[2.3, 3.9]], [[4.5, 5.9]]], )?; // all the tests below are generated from numpy.take, which works like // onnx's Gather operation. test(&[1.0, 2.0, 3.0, 4.0], 3i64, 0, 4.0)?; test(&[[1.0, 2.0, 3.0, 4.0]], 3i64, 1, &[4.0])?; test( &[[1.0], [2.0], [3.0], [4.0]], &[3i64, 2], 0, &[[4.0], [3.0]], )?; test( &[ [[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]], [[9.0, 10.0], [11.0, 12.0]], [[13.0, 14.0], [15.0, 16.0]], ], 1i64, 0, &[[5.0, 6.0], [7.0, 8.0]], )?; test( &[ [[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]], [[9.0, 10.0], [11.0, 12.0]], [[13.0, 14.0], [15.0, 16.0]], ], &[1i64, 0], 0, &[[[5.0, 6.0], [7.0, 8.0]], [[1.0, 2.0], [3.0, 4.0]]], )?; fn test( data: impl NdArray, indices: impl NdArray, axis: i64, expected: impl NdArray, ) -> Result<()> { let att_axis = AttributeProto { name: "axis".to_string(), ref_attr_name: "axis".to_string(), i: axis, doc_string: "axis".to_string(), r#type: 2, f: 0.0, s: vec![], t: None, g: None, sparse_tensor: None, tp: None, floats: vec![], ints: vec![], strings: vec![], tensors: vec![], graphs: vec![], sparse_tensors: vec![], type_protos: vec![], }; let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Gather".to_string(), domain: "".to_string(), attribute: vec![att_axis], input: vec![INPUT_X.to_string(), INPUT_Y.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?); inputs.insert(INPUT_Y.to_string(), Tensor::new(indices, &Device::Cpu)?); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let expected = Tensor::new(expected, &Device::Cpu)?; match expected.dims().len() { 0 => assert_eq!(z.to_vec0::()?, expected.to_vec0::()?), 1 => assert_eq!(z.to_vec1::()?, expected.to_vec1::()?), 2 => assert_eq!(z.to_vec2::()?, expected.to_vec2::()?), 3 => assert_eq!(z.to_vec3::()?, expected.to_vec3::()?), _ => unreachable!(), }; Ok(()) } Ok(()) } // "Shape" #[test] fn test_shape_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Shape".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![ValueInfoProto { name: INPUT_X.to_string(), doc_string: "".to_string(), r#type: None, }], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let x = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu)?; let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), x); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let results = z.to_vec1::()?; assert_eq!(results, vec![2, 2]); Ok(()) } // "Conv" // #[test] // "Concat" // #[test] // "Abs" #[test] fn test_abs_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Abs".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![ ValueInfoProto { name: INPUT_X.to_string(), doc_string: "".to_string(), r#type: None, }, ValueInfoProto { name: INPUT_Y.to_string(), doc_string: "".to_string(), r#type: None, }, ], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let x = Tensor::from_vec( vec![-1.0f32, 2.0f32, -3.0f32, 4.0f32], &[2, 2], &Device::Cpu, )?; let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), x); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let results = z.to_vec2::()?; assert_eq!(results, vec![vec![1.0, 2.0], vec![3.0, 4.0]]); Ok(()) } // "Cos" #[test] fn test_cos_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Cos".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![ ValueInfoProto { name: INPUT_X.to_string(), doc_string: "".to_string(), r#type: None, }, ValueInfoProto { name: INPUT_Y.to_string(), doc_string: "".to_string(), r#type: None, }, ], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?; let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), x); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let results = z.to_vec2::()?; assert_eq!( results, vec![vec![1.0, 0.54030234], vec![-0.41614684, -0.9899925]] ); Ok(()) } // "Sin" #[test] fn test_sin_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Sin".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![ ValueInfoProto { name: INPUT_X.to_string(), doc_string: "".to_string(), r#type: None, }, ValueInfoProto { name: INPUT_Y.to_string(), doc_string: "".to_string(), r#type: None, }, ], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?; let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), x); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let results = z.to_vec2::()?; assert_eq!(results, vec![vec![0.0, 0.841471], vec![0.9092974, 0.14112]]); Ok(()) } // "Neg" #[test] fn test_neg_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Neg".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![ ValueInfoProto { name: INPUT_X.to_string(), doc_string: "".to_string(), r#type: None, }, ValueInfoProto { name: INPUT_Y.to_string(), doc_string: "".to_string(), r#type: None, }, ], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let x = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu)?; let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), x); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let results = z.to_vec2::()?; assert_eq!(results, vec![vec![-1.0, -2.0], vec![-3.0, -4.0]]); Ok(()) } // "Erf" // #[test] // "Tanh" #[test] fn test_tanh_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Tanh".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![ ValueInfoProto { name: INPUT_X.to_string(), doc_string: "".to_string(), r#type: None, }, ValueInfoProto { name: INPUT_Y.to_string(), doc_string: "".to_string(), r#type: None, }, ], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?; let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), x); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let results = z.to_vec2::()?; assert_eq!( results, vec![vec![0.0, 0.7615942], vec![0.9640276, 0.9950548]] ); Ok(()) } // "Sigmoid" #[test] fn test_sigmoid_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Sigmoid".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![ ValueInfoProto { name: INPUT_X.to_string(), doc_string: "".to_string(), r#type: None, }, ValueInfoProto { name: INPUT_Y.to_string(), doc_string: "".to_string(), r#type: None, }, ], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?; let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), x); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let results = z.to_vec2::()?; assert_eq!( results, vec![vec![0.5, 0.7310586], vec![0.880797, 0.95257413]] ); Ok(()) } // "Gelu" #[test] fn test_gelu_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Gelu".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![ ValueInfoProto { name: INPUT_X.to_string(), doc_string: "".to_string(), r#type: None, }, ValueInfoProto { name: INPUT_Y.to_string(), doc_string: "".to_string(), r#type: None, }, ], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let x = Tensor::from_vec(vec![0.0f32, 1.0f32, 2.0f32, 3.0f32], &[2, 2], &Device::Cpu)?; let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), x); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let results = z.to_vec2::()?; assert_eq!( results, vec![vec![0.0, 0.8413448], vec![1.9544997, 2.9959502]] ); Ok(()) } // "Relu" #[test] fn test_relu_operation() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Relu".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![ValueInfoProto { name: INPUT_X.to_string(), doc_string: "".to_string(), r#type: None, }], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let x = Tensor::from_vec( vec![-1.0f32, 1.0f32, -2.0f32, 3.0f32], &[2, 2], &Device::Cpu, )?; let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), x); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let results = z.to_vec2::()?; assert_eq!(results, vec![vec![0.0, 1.0], vec![0.0, 3.0]]); Ok(()) } // "Constant" // #[test] // "Cast" // #[test] // "ReduceMean" #[test] fn test_reduce_mean() -> Result<()> { // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-120 default_axes_keepdims test( &[ [[5., 1.], [20., 2.]], [[30., 1.], [40., 2.]], [[55., 1.], [60., 2.]], ], None, 1, &[[[18.25]]], )?; // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-120 do_no_keepdims test( &[ [[5., 1.], [20., 2.]], [[30., 1.], [40., 2.]], [[55., 1.], [60., 2.]], ], Some(vec![1]), 0, &[[12.5, 1.5], [35.0, 1.5], [57.5, 1.5]], )?; // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-120 keepdims test( &[ [[5., 1.], [20., 2.]], [[30., 1.], [40., 2.]], [[55., 1.], [60., 2.]], ], Some(vec![1]), 1, &[[[12.5, 1.5]], [[35.0, 1.5]], [[57.5, 1.5]]], )?; // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-120 negative_axes_keepdims test( &[ [[5., 1.], [20., 2.]], [[30., 1.], [40., 2.]], [[55., 1.], [60., 2.]], ], Some(vec![-2]), 1, &[[[12.5, 1.5]], [[35.0, 1.5]], [[57.5, 1.5]]], )?; // All the test data below was generated based on numpy's np.mean test( &[ [[5., 1.], [20., 2.]], [[30., 1.], [40., 2.]], [[55., 1.], [60., 2.]], ], Some(vec![1, 2]), 0, &[7.0, 18.25, 29.5], )?; test( &[ [[5., 1.], [20., 2.]], [[30., 1.], [40., 2.]], [[55., 1.], [60., 2.]], ], Some(vec![1, 2]), 1, &[[[7.0]], [[18.25]], [[29.5]]], )?; test(&[1., 2., 3.], None, 1, &[2.0])?; fn test( data: impl NdArray, axes: Option>, keepdims: i64, expected: impl NdArray, ) -> Result<()> { let has_axes = axes.is_some(); let att_axes = AttributeProto { name: "axes".to_string(), ref_attr_name: "axes".to_string(), i: 0, doc_string: "axes".to_string(), r#type: 7, f: 0.0, s: vec![], t: None, g: None, sparse_tensor: None, tp: None, floats: vec![], ints: axes.unwrap_or_default(), strings: vec![], tensors: vec![], graphs: vec![], sparse_tensors: vec![], type_protos: vec![], }; let att_keepdims = AttributeProto { name: "keepdims".to_string(), ref_attr_name: "keepdims".to_string(), i: keepdims, doc_string: "keepdims".to_string(), r#type: 2, f: 0.0, s: vec![], t: None, g: None, sparse_tensor: None, tp: None, floats: vec![], ints: vec![], strings: vec![], tensors: vec![], graphs: vec![], sparse_tensors: vec![], type_protos: vec![], }; let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "ReduceMean".to_string(), domain: "".to_string(), attribute: if has_axes { vec![att_axes, att_keepdims] } else { vec![att_keepdims] }, input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let expected = Tensor::new(expected, &Device::Cpu)?; match expected.dims().len() { 0 => assert_eq!(z.to_vec0::()?, expected.to_vec0::()?), 1 => assert_eq!(z.to_vec1::()?, expected.to_vec1::()?), 2 => assert_eq!(z.to_vec2::()?, expected.to_vec2::()?), 3 => assert_eq!(z.to_vec3::()?, expected.to_vec3::()?), _ => unreachable!(), }; Ok(()) } Ok(()) } // "Sqrt" #[test] fn test_sqrt() -> Result<()> { // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-155 test(&[1., 4., 9.], &[1., 2., 3.])?; fn test(data: impl NdArray, expected: impl NdArray) -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Sqrt".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let expected = Tensor::new(expected, &Device::Cpu)?; match expected.dims().len() { 0 => assert_eq!(z.to_vec0::()?, expected.to_vec0::()?), 1 => assert_eq!(z.to_vec1::()?, expected.to_vec1::()?), 2 => assert_eq!(z.to_vec2::()?, expected.to_vec2::()?), 3 => assert_eq!(z.to_vec3::()?, expected.to_vec3::()?), _ => unreachable!(), }; Ok(()) } Ok(()) } // "RandomUniform" #[test] fn test_random_uniform() -> Result<()> { test(vec![3, 2, 1, 4], None, None)?; test(vec![2, 2, 2, 2], Some(-10.0), None)?; test(vec![2, 2, 2, 2], None, Some(10.0))?; test(vec![1, 2, 3, 4], Some(-10.0), Some(10.0))?; fn test(shape: Vec, low: Option, high: Option) -> Result<()> { let att_low = AttributeProto { name: "low".to_string(), ref_attr_name: "low".to_string(), i: 0, doc_string: "low".to_string(), r#type: 1, // FLOAT f: low.unwrap_or(0.0), s: vec![], t: None, g: None, sparse_tensor: None, tp: None, floats: vec![], ints: vec![], strings: vec![], tensors: vec![], graphs: vec![], sparse_tensors: vec![], type_protos: vec![], }; let att_high = AttributeProto { name: "high".to_string(), ref_attr_name: "high".to_string(), i: 0, doc_string: "high".to_string(), r#type: 1, // FLOAT f: high.unwrap_or(1.0), s: vec![], t: None, g: None, sparse_tensor: None, tp: None, floats: vec![], ints: vec![], strings: vec![], tensors: vec![], graphs: vec![], sparse_tensors: vec![], type_protos: vec![], }; let att_shape = AttributeProto { name: "shape".to_string(), ref_attr_name: "shape".to_string(), i: 0, doc_string: "shape".to_string(), r#type: 7, // INTS f: 0.0, s: vec![], t: None, g: None, sparse_tensor: None, tp: None, floats: vec![], ints: shape, strings: vec![], tensors: vec![], graphs: vec![], sparse_tensors: vec![], type_protos: vec![], }; let att_dtype = AttributeProto { name: "dtype".to_string(), ref_attr_name: "dtype".to_string(), i: 11, // DOUBLE doc_string: "dtype".to_string(), r#type: 2, // INT f: 0.0, s: vec![], t: None, g: None, sparse_tensor: None, tp: None, floats: vec![], ints: vec![], strings: vec![], tensors: vec![], graphs: vec![], sparse_tensors: vec![], type_protos: vec![], }; let attrs = { let mut mut_attrs = vec![att_shape, att_dtype]; if low.is_some() { mut_attrs.push(att_low); } if high.is_some() { mut_attrs.push(att_high); } mut_attrs }; let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "RandomUniform".to_string(), domain: "".to_string(), attribute: attrs, input: vec![], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let eval = candle_onnx::simple_eval(&manual_graph, HashMap::new())?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let min = z .flatten_all()? .to_vec1()? .into_iter() .reduce(f64::min) .unwrap(); let max = z .flatten_all()? .to_vec1()? .into_iter() .reduce(f64::max) .unwrap(); assert!(min >= low.unwrap_or(0.0).into()); assert!(max <= high.unwrap_or(1.0).into()); assert_ne!(min, max); Ok(()) } Ok(()) } // "RandomNormal" #[test] fn test_random_normal() -> Result<()> { test(vec![3, 2, 1, 4], None, None)?; test(vec![2, 2, 2, 2], Some(-10.0), None)?; test(vec![2, 2, 2, 2], None, Some(10.0))?; test(vec![1, 2, 3, 4], Some(-10.0), Some(10.0))?; fn test(shape: Vec, mean: Option, scale: Option) -> Result<()> { let att_mean = AttributeProto { name: "mean".to_string(), ref_attr_name: "mean".to_string(), i: 0, doc_string: "mean".to_string(), r#type: 1, // FLOAT f: mean.unwrap_or(0.0), s: vec![], t: None, g: None, sparse_tensor: None, tp: None, floats: vec![], ints: vec![], strings: vec![], tensors: vec![], graphs: vec![], sparse_tensors: vec![], type_protos: vec![], }; let att_scale = AttributeProto { name: "scale".to_string(), ref_attr_name: "scale".to_string(), i: 0, doc_string: "scale".to_string(), r#type: 1, // FLOAT f: scale.unwrap_or(1.0), s: vec![], t: None, g: None, sparse_tensor: None, tp: None, floats: vec![], ints: vec![], strings: vec![], tensors: vec![], graphs: vec![], sparse_tensors: vec![], type_protos: vec![], }; let att_shape = AttributeProto { name: "shape".to_string(), ref_attr_name: "shape".to_string(), i: 0, doc_string: "shape".to_string(), r#type: 7, // INTS f: 0.0, s: vec![], t: None, g: None, sparse_tensor: None, tp: None, floats: vec![], ints: shape, strings: vec![], tensors: vec![], graphs: vec![], sparse_tensors: vec![], type_protos: vec![], }; let att_dtype = AttributeProto { name: "dtype".to_string(), ref_attr_name: "dtype".to_string(), i: 11, // DOUBLE doc_string: "dtype".to_string(), r#type: 2, // INT f: 0.0, s: vec![], t: None, g: None, sparse_tensor: None, tp: None, floats: vec![], ints: vec![], strings: vec![], tensors: vec![], graphs: vec![], sparse_tensors: vec![], type_protos: vec![], }; let attrs = { let mut mut_attrs = vec![att_shape, att_dtype]; if mean.is_some() { mut_attrs.push(att_mean); } if scale.is_some() { mut_attrs.push(att_scale); } mut_attrs }; let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "RandomNormal".to_string(), domain: "".to_string(), attribute: attrs, input: vec![], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let eval = candle_onnx::simple_eval(&manual_graph, HashMap::new())?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let data = z.flatten_all()?.to_vec1::()?; // test if values are unique for (i, a) in data.iter().enumerate() { for (j, b) in data.iter().enumerate() { if i == j { continue; }; assert_ne!(a, b); } } Ok(()) } Ok(()) } // "Range" #[test] fn test_range() -> Result<()> { // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-113 test(1., 5., 2., &[1., 3.])?; // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-113 test(10i64, 6i64, -3i64, &[10i64, 7i64])?; fn test( start: impl NdArray, limit: impl NdArray, delta: impl NdArray, expected: impl NdArray, ) -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Range".to_string(), domain: "".to_string(), attribute: vec![], input: vec![ INPUT_X.to_string(), INPUT_Y.to_string(), INPUT_A.to_string(), ], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), Tensor::new(start, &Device::Cpu)?); inputs.insert(INPUT_Y.to_string(), Tensor::new(limit, &Device::Cpu)?); inputs.insert(INPUT_A.to_string(), Tensor::new(delta, &Device::Cpu)?); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval .get(OUTPUT_Z) .expect("Output 'z' not found") .to_dtype(DType::F64)?; let expected = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?; match expected.dims().len() { 0 => assert_eq!(z.to_vec0::()?, expected.to_vec0::()?), 1 => assert_eq!(z.to_vec1::()?, expected.to_vec1::()?), 2 => assert_eq!(z.to_vec2::()?, expected.to_vec2::()?), 3 => assert_eq!(z.to_vec3::()?, expected.to_vec3::()?), _ => unreachable!(), }; Ok(()) } Ok(()) } // "Greater" #[test] fn test_greater() -> Result<()> { // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-63 test(&[1., 2., 3.], &[3., 2., 1.], &[0u8, 0, 1])?; // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-63 test(&[1., 2., 3.], 2., &[0u8, 0, 1])?; fn test(a: impl NdArray, b: impl NdArray, expected: impl NdArray) -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Greater".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string(), INPUT_Y.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), Tensor::new(a, &Device::Cpu)?); inputs.insert(INPUT_Y.to_string(), Tensor::new(b, &Device::Cpu)?); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval .get(OUTPUT_Z) .expect("Output 'z' not found") .to_dtype(DType::F64)?; let expected = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?; match expected.dims().len() { 0 => assert_eq!(z.to_vec0::()?, expected.to_vec0::()?), 1 => assert_eq!(z.to_vec1::()?, expected.to_vec1::()?), 2 => assert_eq!(z.to_vec2::()?, expected.to_vec2::()?), 3 => assert_eq!(z.to_vec3::()?, expected.to_vec3::()?), _ => unreachable!(), }; Ok(()) } Ok(()) } // "Less" #[test] fn test_less() -> Result<()> { // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-81 test(&[1., 2., 3.], &[3., 2., 1.], &[1u8, 0, 0])?; // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-81 test(&[1., 2., 3.], 2., &[1u8, 0, 0])?; fn test(a: impl NdArray, b: impl NdArray, expected: impl NdArray) -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Less".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string(), INPUT_Y.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), Tensor::new(a, &Device::Cpu)?); inputs.insert(INPUT_Y.to_string(), Tensor::new(b, &Device::Cpu)?); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval .get(OUTPUT_Z) .expect("Output 'z' not found") .to_dtype(DType::F64)?; let expected = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?; match expected.dims().len() { 0 => assert_eq!(z.to_vec0::()?, expected.to_vec0::()?), 1 => assert_eq!(z.to_vec1::()?, expected.to_vec1::()?), 2 => assert_eq!(z.to_vec2::()?, expected.to_vec2::()?), 3 => assert_eq!(z.to_vec3::()?, expected.to_vec3::()?), _ => unreachable!(), }; Ok(()) } Ok(()) } // "Log" #[test] fn test_log() -> Result<()> { // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-82 test(&[1., 10.], &[0., std::f64::consts::LN_10])?; fn test(data: impl NdArray, expected: impl NdArray) -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Log".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let expected = Tensor::new(expected, &Device::Cpu)?; match expected.dims().len() { 0 => assert_eq!(z.to_vec0::()?, expected.to_vec0::()?), 1 => assert_eq!(z.to_vec1::()?, expected.to_vec1::()?), 2 => assert_eq!(z.to_vec2::()?, expected.to_vec2::()?), 3 => assert_eq!(z.to_vec3::()?, expected.to_vec3::()?), _ => unreachable!(), }; Ok(()) } Ok(()) } // "Min" #[test] fn test_min() -> Result<()> { // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-94 test(&[3., 2., 1.], &[1., 4., 4.], &[2., 5., 0.], &[1., 2., 0.])?; fn test( a: impl NdArray, b: impl NdArray, c: impl NdArray, expected: impl NdArray, ) -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Min".to_string(), domain: "".to_string(), attribute: vec![], input: vec![ INPUT_X.to_string(), INPUT_Y.to_string(), INPUT_A.to_string(), ], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), Tensor::new(a, &Device::Cpu)?); inputs.insert(INPUT_Y.to_string(), Tensor::new(b, &Device::Cpu)?); inputs.insert(INPUT_A.to_string(), Tensor::new(c, &Device::Cpu)?); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let expected = Tensor::new(expected, &Device::Cpu)?; match expected.dims().len() { 0 => assert_eq!(z.to_vec0::()?, expected.to_vec0::()?), 1 => assert_eq!(z.to_vec1::()?, expected.to_vec1::()?), 2 => assert_eq!(z.to_vec2::()?, expected.to_vec2::()?), 3 => assert_eq!(z.to_vec3::()?, expected.to_vec3::()?), _ => unreachable!(), }; Ok(()) } Ok(()) } // "Where" #[test] fn test_where() -> Result<()> { // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-173 test( &[[1u8, 0], [1, 1]], &[[1i64, 2], [3, 4]], &[[9i64, 8], [7, 6]], &[[1i64, 8], [3, 4]], )?; // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-173 test( &[[1u8, 0], [1, 1]], &[[1., 2.], [3., 4.]], &[[9., 8.], [7., 6.]], &[[1., 8.], [3., 4.]], )?; fn test( condition: impl NdArray, x: impl NdArray, y: impl NdArray, expected: impl NdArray, ) -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Where".to_string(), domain: "".to_string(), attribute: vec![], input: vec![ INPUT_X.to_string(), INPUT_Y.to_string(), INPUT_A.to_string(), ], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), Tensor::new(condition, &Device::Cpu)?); inputs.insert(INPUT_Y.to_string(), Tensor::new(x, &Device::Cpu)?); inputs.insert(INPUT_A.to_string(), Tensor::new(y, &Device::Cpu)?); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval .get(OUTPUT_Z) .expect("Output 'z' not found") .to_dtype(DType::F64)?; let expected = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?; match expected.dims().len() { 0 => assert_eq!(z.to_vec0::()?, expected.to_vec0::()?), 1 => assert_eq!(z.to_vec1::()?, expected.to_vec1::()?), 2 => assert_eq!(z.to_vec2::()?, expected.to_vec2::()?), 3 => assert_eq!(z.to_vec3::()?, expected.to_vec3::()?), _ => unreachable!(), }; Ok(()) } Ok(()) } #[test] fn test_floor() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Floor".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![ValueInfoProto { name: INPUT_X.to_string(), doc_string: "".to_string(), r#type: None, }], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let x = Tensor::from_vec( // some values taken from https://numpy.org/doc/stable/reference/generated/numpy.floor.html vec![ f64::NAN, f64::INFINITY, f64::NEG_INFINITY, -1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0, ], &[10], &Device::Cpu, )?; let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), x); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let results = z.to_vec1::()?; assert!(results[0].is_nan()); assert_eq!( results[1..], vec![ f64::INFINITY, f64::NEG_INFINITY, -2., -2., -1., 0., 1., 1., 2. ] ); Ok(()) } #[test] fn test_ceil() -> Result<()> { let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "Ceil".to_string(), domain: "".to_string(), attribute: vec![], input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![ValueInfoProto { name: INPUT_X.to_string(), doc_string: "".to_string(), r#type: None, }], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let x = Tensor::from_vec( // some values taken from https://numpy.org/doc/stable/reference/generated/numpy.ceil.html vec![ f64::NAN, f64::INFINITY, f64::NEG_INFINITY, -1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0, ], &[10], &Device::Cpu, )?; let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), x); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; assert_eq!(eval.len(), 1); let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let results = z.to_vec1::()?; assert!(results[0].is_nan()); assert_eq!( results[1..], vec![ f64::INFINITY, f64::NEG_INFINITY, -1., -1., -0., 1., 2., 2., 2. ] ); Ok(()) } // "ArgMin" #[test] fn test_argmin() -> Result<()> { // tests from https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-7 // default_axes_keepdims test( &[ [2u32, 1u32], [3u32, 10u32] ], None, Some(1), None, &[ [0i64, 0i64], ], )?; // keepdims test( &[ [2u32, 1u32], [3u32, 10u32] ], Some(1), Some(1), None, &[ [1i64], [0i64] ], )?; // // negative_axis_keepdims test( &[ [2u32, 1u32], [3u32, 10u32] ], Some(-1), Some(1), None, &[ [1i64], [0i64] ], )?; // no_keepdims test( &[ [2u32, 1u32], [3u32, 10u32] ], None, Some(0), None, &[0i64, 0i64], )?; // tests from https://pytorch.org/docs/stable/generated/torch.argmin.html#torch.argmin test( &[ [0.1139, 0.2254, -0.1381, 0.3687], [1.0100, -1.1975, -0.0102, -0.4732], [-0.9240, 0.1207, -0.7506, -1.0213], [1.7809, -1.2960, 0.9384, 0.1438] ], Some(1), Some(0), None, &[2i64, 1i64, 3i64, 1i64], )?; test( &[ [0.1139, 0.2254, -0.1381, 0.3687], [1.0100, -1.1975, -0.0102, -0.4732], [-0.9240, 0.1207, -0.7506, -1.0213], [1.7809, -1.2960, 0.9384, 0.1438] ], Some(1), None, None, &[[2i64], [1i64], [3i64], [1i64]], )?; fn test(data: impl NdArray, axis: Option, keepdims: Option, select_last_index: Option, expected: impl NdArray) -> Result<()> { let att_axis = AttributeProto { name: "axis".to_string(), ref_attr_name: "axis".to_string(), i: axis.unwrap_or(0), doc_string: "axis".to_string(), r#type: 2, // INT f: 0.0, s: vec![], t: None, g: None, sparse_tensor: None, tp: None, floats: vec![], ints: vec![], strings: vec![], tensors: vec![], graphs: vec![], sparse_tensors: vec![], type_protos: vec![], }; let att_keepdims = AttributeProto { name: "keepdims".to_string(), ref_attr_name: "keepdims".to_string(), i: keepdims.unwrap_or(1), doc_string: "keepdims".to_string(), r#type: 2, // INT f: 0.0, s: vec![], t: None, g: None, sparse_tensor: None, tp: None, floats: vec![], ints: vec![], strings: vec![], tensors: vec![], graphs: vec![], sparse_tensors: vec![], type_protos: vec![], }; let att_select_last_index = AttributeProto { name: "select_last_index".to_string(), ref_attr_name: "select_last_index".to_string(), i: select_last_index.unwrap_or(0), doc_string: "select_last_index".to_string(), r#type: 2, // INT f: 0.0, s: vec![], t: None, g: None, sparse_tensor: None, tp: None, floats: vec![], ints: vec![], strings: vec![], tensors: vec![], graphs: vec![], sparse_tensors: vec![], type_protos: vec![], }; let attrs = { let mut mut_attrs = vec![]; if axis.is_some() { mut_attrs.push(att_axis); } if keepdims.is_some() { mut_attrs.push(att_keepdims); } if select_last_index.is_some() { mut_attrs.push(att_select_last_index); } mut_attrs }; let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "ArgMin".to_string(), domain: "".to_string(), attribute: attrs, input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let expected = Tensor::new(expected, &Device::Cpu)?; match expected.dims().len() { 1 => assert_eq!(z.to_vec1::()?, expected.to_vec1::()?), 2 => assert_eq!(z.to_vec2::()?, expected.to_vec2::()?), _ => unreachable!(), }; Ok(()) } Ok(()) } // "ArgMax" #[test] fn test_argmax() -> Result<()> { // tests from https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-6 // default_axes_keepdims test( &[ [2u32, 1u32], [3u32, 10u32] ], None, Some(1), None, &[ [1i64, 1i64], ], )?; // keepdims test( &[ [2u32, 1u32], [3u32, 10u32] ], Some(1), Some(1), None, &[ [0i64], [1i64] ], )?; // // negative_axis_keepdims test( &[ [2u32, 1u32], [3u32, 10u32] ], Some(-1), Some(1), None, &[ [0i64], [1i64] ], )?; // no_keepdims test( &[ [2u32, 1u32], [3u32, 10u32] ], None, Some(0), None, &[1i64, 1i64], )?; // tests from https://pytorch.org/docs/stable/generated/torch.argmax.html test( &[ [1.3398, 0.2663, -0.2686, 0.2450], [-0.7401, -0.8805, -0.3402, -1.1936], [0.4907, -1.3948, -1.0691, -0.3132], [-1.6092, 0.5419, -0.2993, 0.3195] ], Some(1), Some(0), None, &[0i64, 2i64, 0i64, 1i64], )?; test( &[ [1.3398, 0.2663, -0.2686, 0.2450], [-0.7401, -0.8805, -0.3402, -1.1936], [0.4907, -1.3948, -1.0691, -0.3132], [-1.6092, 0.5419, -0.2993, 0.3195] ], Some(1), None, None, &[[0i64], [2i64], [0i64], [1i64]], )?; fn test(data: impl NdArray, axis: Option, keepdims: Option, select_last_index: Option, expected: impl NdArray) -> Result<()> { let att_axis = AttributeProto { name: "axis".to_string(), ref_attr_name: "axis".to_string(), i: axis.unwrap_or(0), doc_string: "axis".to_string(), r#type: 2, // INT f: 0.0, s: vec![], t: None, g: None, sparse_tensor: None, tp: None, floats: vec![], ints: vec![], strings: vec![], tensors: vec![], graphs: vec![], sparse_tensors: vec![], type_protos: vec![], }; let att_keepdims = AttributeProto { name: "keepdims".to_string(), ref_attr_name: "keepdims".to_string(), i: keepdims.unwrap_or(1), doc_string: "keepdims".to_string(), r#type: 2, // INT f: 0.0, s: vec![], t: None, g: None, sparse_tensor: None, tp: None, floats: vec![], ints: vec![], strings: vec![], tensors: vec![], graphs: vec![], sparse_tensors: vec![], type_protos: vec![], }; let att_select_last_index = AttributeProto { name: "select_last_index".to_string(), ref_attr_name: "select_last_index".to_string(), i: select_last_index.unwrap_or(0), doc_string: "select_last_index".to_string(), r#type: 2, // INT f: 0.0, s: vec![], t: None, g: None, sparse_tensor: None, tp: None, floats: vec![], ints: vec![], strings: vec![], tensors: vec![], graphs: vec![], sparse_tensors: vec![], type_protos: vec![], }; let attrs = { let mut mut_attrs = vec![]; if axis.is_some() { mut_attrs.push(att_axis); } if keepdims.is_some() { mut_attrs.push(att_keepdims); } if select_last_index.is_some() { mut_attrs.push(att_select_last_index); } mut_attrs }; let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "ArgMax".to_string(), domain: "".to_string(), attribute: attrs, input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let expected = Tensor::new(expected, &Device::Cpu)?; match expected.dims().len() { 1 => assert_eq!(z.to_vec1::()?, expected.to_vec1::()?), 2 => assert_eq!(z.to_vec2::()?, expected.to_vec2::()?), _ => unreachable!(), }; Ok(()) } Ok(()) } // "LeakyRelu" #[test] fn test_leakyrelu() -> Result<()> { // tests from https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-80 // leakyrelu test( &[-1.0, 0.0, 1.0], Some(0.1), &[-0.1, 0.0, 1.0] )?; fn test(data: impl NdArray, alpha: Option, expected: impl NdArray) -> Result<()> { let att_alpha = AttributeProto { name: "alpha".to_string(), ref_attr_name: "alpha".to_string(), i: 0, doc_string: "alpha".to_string(), r#type: 1, // FLOAT f: alpha.unwrap_or(0.01), s: vec![], t: None, g: None, sparse_tensor: None, tp: None, floats: vec![], ints: vec![], strings: vec![], tensors: vec![], graphs: vec![], sparse_tensors: vec![], type_protos: vec![], }; let attrs = { let mut mut_attrs = vec![]; if alpha.is_some() { mut_attrs.push(att_alpha); } mut_attrs }; let manual_graph = create_model_proto_with_graph(Some(GraphProto { node: vec![NodeProto { op_type: "LeakyRelu".to_string(), domain: "".to_string(), attribute: attrs, input: vec![INPUT_X.to_string()], output: vec![OUTPUT_Z.to_string()], name: "".to_string(), doc_string: "".to_string(), }], name: "".to_string(), initializer: vec![], input: vec![], output: vec![ValueInfoProto { name: OUTPUT_Z.to_string(), doc_string: "".to_string(), r#type: None, }], value_info: vec![], doc_string: "".to_string(), sparse_initializer: vec![], quantization_annotation: vec![], })); let mut inputs: HashMap = HashMap::new(); inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?); let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); let expected = Tensor::new(expected, &Device::Cpu)?; for both in z.to_vec1::()?.iter().zip(expected.to_vec1::()?.iter()) { let (act, exp) = both; assert!(f64::abs(act - exp) < f32::EPSILON.into()); } Ok(()) } Ok(()) }