From 46073c5f733cbdf1135c36f695d32382d1bdaa51 Mon Sep 17 00:00:00 2001
From: Mateusz Okulus
Date: Fri, 19 Apr 2024 16:06:43 +0200
Subject: [PATCH 01/10] Add basic RandomUniform implementation

---
 candle-onnx/src/eval.rs  |  43 ++++++++++++
 candle-onnx/tests/ops.rs | 145 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 188 insertions(+)

diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index 75927822..33040e15 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -820,6 +820,49 @@ pub fn simple_eval(
             };
             values.insert(node.output[0].clone(), output);
         }
+        "RandomUniform" => {
+            let dt: i64 = get_attr_opt(node, "dtype")?.copied().unwrap_or(1); // 1 is float
+                                                                              // type by
+                                                                              // default
+            let dtype = match DataType::try_from(dt as i32) {
+                Ok(dt) => match dtype(dt) {
+                    Some(DType::U8 | DType::U32 | DType::I64) => {
+                        bail!(
+                            "unsupported 'dtype' value {dt:?}, only floats are allowed, for RandomUniform {}",
+                            node.name
+                        )
+                    }
+                    Some(dt) => dt,
+                    None => {
+                        bail!(
+                            "unsupported 'dtype' value {dt:?} for RandomUniform {}",
+                            node.name
+                        )
+                    }
+                },
+                Err(_) => {
+                    bail!(
+                        "unsupported 'dtype' value {dt:?} for RandomUniform {}",
+                        node.name
+                    )
+                }
+            };
+            let low: f32 = get_attr_opt(node, "low")?.copied().unwrap_or(0.0);
+            let high: f32 = get_attr_opt(node, "high")?.copied().unwrap_or(1.0);
+            let seed: Option<f32> = get_attr_opt(node, "seed")?.copied();
+            match seed {
+                Some(_) => {
+                    bail!("seed for RandomUniform is currently not supported")
+                }
+                None => {}
+            };
+            let shape: Vec<usize> = get_attr::<[i64]>(node, "shape")?
+                .iter()
+                .map(|x| *x as usize)
+                .collect();
+            let output = Tensor::rand(low, high, shape, &Device::Cpu)?.to_dtype(dtype)?;
+            values.insert(node.output[0].clone(), output);
+        }
         op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
     }
 }
diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs
index fda76ec2..a4675115 100644
--- a/candle-onnx/tests/ops.rs
+++ b/candle-onnx/tests/ops.rs
@@ -1639,3 +1639,148 @@ fn test_reduce_mean() -> Result<()> {
 
     Ok(())
 }
+
+#[test]
+fn test_random_uniform() -> Result<()> {
+    test(vec![3, 2, 1, 4], None, None)?;
+    test(vec![2, 2, 2, 2], Some(-10.0), None)?;
+    test(vec![2, 2, 2, 2], None, Some(10.0))?;
+    test(vec![1, 2, 3, 4], Some(-10.0), Some(10.0))?;
+
+    fn test(shape: Vec<i64>, low: Option<f32>, high: Option<f32>) -> Result<()> {
+        let att_low = AttributeProto {
+            name: "low".to_string(),
+            ref_attr_name: "low".to_string(),
+            i: 0,
+            doc_string: "low".to_string(),
+            r#type: 1, // FLOAT
+            f: low.unwrap_or(0.0),
+            s: vec![],
+            t: None,
+            g: None,
+            sparse_tensor: None,
+            tp: None,
+            floats: vec![],
+            ints: vec![],
+            strings: vec![],
+            tensors: vec![],
+            graphs: vec![],
+            sparse_tensors: vec![],
+            type_protos: vec![],
+        };
+        let att_high = AttributeProto {
+            name: "high".to_string(),
+            ref_attr_name: "high".to_string(),
+            i: 0,
+            doc_string: "high".to_string(),
+            r#type: 1, // FLOAT
+            f: high.unwrap_or(1.0),
+            s: vec![],
+            t: None,
+            g: None,
+            sparse_tensor: None,
+            tp: None,
+            floats: vec![],
+            ints: vec![],
+            strings: vec![],
+            tensors: vec![],
+            graphs: vec![],
+            sparse_tensors: vec![],
+            type_protos: vec![],
+        };
+        let att_shape = AttributeProto {
+            name: "shape".to_string(),
+            ref_attr_name: "shape".to_string(),
+            i: 0,
+            doc_string: "shape".to_string(),
+            r#type: 7, // INTS
+            f: 0.0,
+            s: vec![],
+            t: None,
+            g: None,
+            sparse_tensor: None,
+            tp: None,
+            floats: vec![],
+            ints: shape,
+            strings: vec![],
+            tensors: vec![],
+            graphs: vec![],
+            sparse_tensors: vec![],
+            type_protos: vec![],
+        };
+        let att_dtype = AttributeProto {
+            name: "dtype".to_string(),
+            ref_attr_name: "dtype".to_string(),
+            i: 11, // DOUBLE
+            doc_string: "dtype".to_string(),
+            r#type: 2, // INT
+            f: 0.0,
+            s: vec![],
+            t: None,
+            g: None,
+            sparse_tensor: None,
+            tp: None,
+            floats: vec![],
+            ints: vec![],
+            strings: vec![],
+            tensors: vec![],
+            graphs: vec![],
+            sparse_tensors: vec![],
+            type_protos: vec![],
+        };
+        let attrs = {
+            let mut mut_attrs = vec![att_shape, att_dtype];
+            if low.is_some() {
+                mut_attrs.push(att_low);
+            }
+            if high.is_some() {
+                mut_attrs.push(att_high);
+            }
+            mut_attrs
+        };
+        let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+            node: vec![NodeProto {
+                op_type: "RandomUniform".to_string(),
+                domain: "".to_string(),
+                attribute: attrs,
+                input: vec![],
+                output: vec![OUTPUT_Z.to_string()],
+                name: "".to_string(),
+                doc_string: "".to_string(),
+            }],
+            name: "".to_string(),
+            initializer: vec![],
+            input: vec![],
+            output: vec![ValueInfoProto {
+                name: OUTPUT_Z.to_string(),
+                doc_string: "".to_string(),
+                r#type: None,
+            }],
+            value_info: vec![],
+            doc_string: "".to_string(),
+            sparse_initializer: vec![],
+            quantization_annotation: vec![],
+        }));
+        let eval = candle_onnx::simple_eval(&manual_graph, HashMap::new())?;
+        assert_eq!(eval.len(), 1);
+        let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+        let min = z
+            .flatten_all()?
+            .to_vec1()?
+            .into_iter()
+            .reduce(f64::min)
+            .unwrap();
+        let max = z
+            .flatten_all()?
+            .to_vec1()?
+            .into_iter()
+            .reduce(f64::max)
+            .unwrap();
+        assert!(min >= low.unwrap_or(0.0).into());
+        assert!(max <= high.unwrap_or(1.0).into());
+        assert_ne!(min, max);
+        Ok(())
+    }
+
+    Ok(())
+}

From 0fa41a791f63c2a74ae0d1d753a476dd0abc3cb0 Mon Sep 17 00:00:00 2001
From: Mateusz Okulus
Date: Fri, 19 Apr 2024 16:09:45 +0200
Subject: [PATCH 02/10] Use is_some to check if seed is present

---
 candle-onnx/src/eval.rs | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index 33040e15..8ff8a8d0 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -850,11 +850,8 @@ pub fn simple_eval(
             let low: f32 = get_attr_opt(node, "low")?.copied().unwrap_or(0.0);
             let high: f32 = get_attr_opt(node, "high")?.copied().unwrap_or(1.0);
             let seed: Option<f32> = get_attr_opt(node, "seed")?.copied();
-            match seed {
-                Some(_) => {
-                    bail!("seed for RandomUniform is currently not supported")
-                }
-                None => {}
+            if seed.is_some() {
+                bail!("seed for RandomUniform is currently not supported")
             };
             let shape: Vec<usize> = get_attr::<[i64]>(node, "shape")?
                 .iter()

From 70388c27b694f421cd6528c6e7e35f6630dfbd21 Mon Sep 17 00:00:00 2001
From: b1rtek <53182944+B1rtek@users.noreply.github.com>
Date: Fri, 19 Apr 2024 22:48:05 +0200
Subject: [PATCH 03/10] Added Exp operator implementation

---
 candle-onnx/src/eval.rs  |  5 ++++
 candle-onnx/tests/ops.rs | 50 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 55 insertions(+)

diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index 75927822..e3f4d09d 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -260,6 +260,11 @@ pub fn simple_eval(
             let output = input0.broadcast_pow(input1)?;
             values.insert(node.output[0].clone(), output);
         }
+        "Exp" => {
+            let xs = get(&node.input[0])?;
+            let output = xs.exp()?;
+            values.insert(node.output[0].clone(), output);
+        }
         "Equal" => {
             let input0 = get(&node.input[0])?;
             let input1 = get(&node.input[1])?;
diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs
index fda76ec2..2711d335 100644
--- a/candle-onnx/tests/ops.rs
+++ b/candle-onnx/tests/ops.rs
@@ -227,6 +227,56 @@ fn test_div_operation() -> Result<()> {
     Ok(())
 }
 
+// "Exp"
+#[test]
+fn test_exp_operation() -> Result<()> {
+    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+        node: vec![NodeProto {
+            op_type: "Exp".to_string(),
+            domain: "".to_string(),
+            attribute: vec![],
+            input: vec![INPUT_X.to_string()],
+            output: vec![OUTPUT_Z.to_string()],
+            name: "".to_string(),
+            doc_string: "".to_string(),
+        }],
+        name: "".to_string(),
+        initializer: vec![],
+        input: vec![],
+        output: vec![ValueInfoProto {
+            name: OUTPUT_Z.to_string(),
+            doc_string: "".to_string(),
+            r#type: None,
+        }],
+        value_info: vec![],
+        doc_string: "".to_string(),
+        sparse_initializer: vec![],
+        quantization_annotation: vec![],
+    }));
+
+    let x = Tensor::from_vec(
+        vec![-1.0f32, 0.0f32, 1.0f32, 2.0f32],
+        &[2, 2],
+        &Device::Cpu,
+    )?;
+
+    let mut inputs: HashMap<String, Tensor> = HashMap::new();
+    inputs.insert(INPUT_X.to_string(), x);
+
+    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+    assert_eq!(eval.len(), 1);
+
+    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+
+    let results = z.to_vec2::<f32>()?;
+
+    assert_eq!(results[0][0], 0.36787944f32);
+    assert_eq!(results[0][1], 1.0f32);
+    assert_eq!(results[1], vec![std::f32::consts::E, 7.38905609f32]);
+
+    Ok(())
+}
+
 // "Equal"
 #[test]
 fn test_equal_operation() -> Result<()> {
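
Note: the "Exp" arm above delegates directly to candle's element-wise
`Tensor::exp`, so the ONNX node inherits its numerics. As a quick sanity
check outside the ONNX graph, the same values can be produced with the
tensor API alone (a minimal sketch, using the same `candle` crate alias
these files already import):

    use candle::{Device, Tensor};

    fn main() -> candle::Result<()> {
        // Same input the test feeds through the ONNX graph.
        let x = Tensor::from_vec(vec![-1.0f32, 0.0, 1.0, 2.0], &[2, 2], &Device::Cpu)?;
        let y = x.exp()?; // element-wise e^x
        assert_eq!(y.to_vec2::<f32>()?[0][1], 1.0); // e^0 == 1
        Ok(())
    }
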
"select_last_index")?.copied().unwrap_or(0); + if select_last_index == 1 { + bail!("select_last_index for ArgMin is currently not supported") + } + let output = if keepdims == 1 { + input.argmin_keepdim(axis)? + } else { + input.argmin(axis)? + }; + values.insert(node.output[0].clone(), output); + } op_type => bail!("unsupported op_type {op_type} for op {node:?}"), } } From 13b88547f7c0f2837d56e46056386b044e358b2b Mon Sep 17 00:00:00 2001 From: b1rtek <53182944+B1rtek@users.noreply.github.com> Date: Thu, 9 May 2024 03:00:22 +0200 Subject: [PATCH 05/10] Added tests for ArgMin --- candle-onnx/tests/ops.rs | 172 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 172 insertions(+) diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs index 294b5511..ed9af9ae 100644 --- a/candle-onnx/tests/ops.rs +++ b/candle-onnx/tests/ops.rs @@ -10,6 +10,7 @@ use candle_onnx::onnx::attribute_proto::AttributeType; use candle_onnx::onnx::tensor_proto::DataType; use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto}; use std::collections::HashMap; +use candle_onnx::eval::Value; const INPUT_X: &str = "x"; const INPUT_Y: &str = "y"; @@ -2416,3 +2417,174 @@ fn test_where() -> Result<()> { Ok(()) } + +// "ArgMin" +#[test] +fn test_argmin() -> Result<()> { + // tests from https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-7 + // default_axes_keepdims + test( + &[ + [2u32, 1u32], + [3u32, 10u32] + ], + None, + Some(1), + None, + &[ + [0u32, 0u32], + ], + )?; + // keepdims + test( + &[ + [2u32, 1u32], + [3u32, 10u32] + ], + Some(1), + Some(1), + None, + &[ + [1u32], + [0u32] + ], + )?; + // // negative_axis_keepdims + test( + &[ + [2u32, 1u32], + [3u32, 10u32] + ], + Some(-1), + Some(1), + None, + &[ + [1u32], + [0u32] + ], + )?; + // no_keepdims + test( + &[ + [2u32, 1u32], + [3u32, 10u32] + ], + None, + Some(0), + None, + &[0u32, 0u32], + )?; + fn test(data: impl NdArray, axis: Option, keepdims: Option, select_last_index: Option, expected: impl NdArray) -> Result<()> { + let att_axis = AttributeProto { + name: "axis".to_string(), + ref_attr_name: "axis".to_string(), + i: axis.unwrap_or(0), + doc_string: "axis".to_string(), + r#type: 2, // INT + f: 0.0, + s: vec![], + t: None, + g: None, + sparse_tensor: None, + tp: None, + floats: vec![], + ints: vec![], + strings: vec![], + tensors: vec![], + graphs: vec![], + sparse_tensors: vec![], + type_protos: vec![], + }; + let att_keepdims = AttributeProto { + name: "keepdims".to_string(), + ref_attr_name: "keepdims".to_string(), + i: keepdims.unwrap_or(1), + doc_string: "keepdims".to_string(), + r#type: 2, // INT + f: 0.0, + s: vec![], + t: None, + g: None, + sparse_tensor: None, + tp: None, + floats: vec![], + ints: vec![], + strings: vec![], + tensors: vec![], + graphs: vec![], + sparse_tensors: vec![], + type_protos: vec![], + }; + let att_select_last_index = AttributeProto { + name: "select_last_index".to_string(), + ref_attr_name: "select_last_index".to_string(), + i: select_last_index.unwrap_or(0), + doc_string: "select_last_index".to_string(), + r#type: 2, // INT + f: 0.0, + s: vec![], + t: None, + g: None, + sparse_tensor: None, + tp: None, + floats: vec![], + ints: vec![], + strings: vec![], + tensors: vec![], + graphs: vec![], + sparse_tensors: vec![], + type_protos: vec![], + }; + let attrs = { + let mut mut_attrs = vec![]; + if axis.is_some() { + mut_attrs.push(att_axis); + } + if keepdims.is_some() { + mut_attrs.push(att_keepdims); + } + if select_last_index.is_some() { + 
From 13b88547f7c0f2837d56e46056386b044e358b2b Mon Sep 17 00:00:00 2001
From: b1rtek <53182944+B1rtek@users.noreply.github.com>
Date: Thu, 9 May 2024 03:00:22 +0200
Subject: [PATCH 05/10] Added tests for ArgMin

---
 candle-onnx/tests/ops.rs | 172 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 172 insertions(+)

diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs
index 294b5511..ed9af9ae 100644
--- a/candle-onnx/tests/ops.rs
+++ b/candle-onnx/tests/ops.rs
@@ -10,6 +10,7 @@ use candle_onnx::onnx::attribute_proto::AttributeType;
 use candle_onnx::onnx::tensor_proto::DataType;
 use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto};
 use std::collections::HashMap;
+use candle_onnx::eval::Value;
 
 const INPUT_X: &str = "x";
 const INPUT_Y: &str = "y";
@@ -2416,3 +2417,174 @@ fn test_where() -> Result<()> {
 
     Ok(())
 }
+
+// "ArgMin"
+#[test]
+fn test_argmin() -> Result<()> {
+    // tests from https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-7
+    // default_axes_keepdims
+    test(
+        &[
+            [2u32, 1u32],
+            [3u32, 10u32]
+        ],
+        None,
+        Some(1),
+        None,
+        &[
+            [0u32, 0u32],
+        ],
+    )?;
+    // keepdims
+    test(
+        &[
+            [2u32, 1u32],
+            [3u32, 10u32]
+        ],
+        Some(1),
+        Some(1),
+        None,
+        &[
+            [1u32],
+            [0u32]
+        ],
+    )?;
+    // negative_axis_keepdims
+    test(
+        &[
+            [2u32, 1u32],
+            [3u32, 10u32]
+        ],
+        Some(-1),
+        Some(1),
+        None,
+        &[
+            [1u32],
+            [0u32]
+        ],
+    )?;
+    // no_keepdims
+    test(
+        &[
+            [2u32, 1u32],
+            [3u32, 10u32]
+        ],
+        None,
+        Some(0),
+        None,
+        &[0u32, 0u32],
+    )?;
+    fn test(data: impl NdArray, axis: Option<i64>, keepdims: Option<i64>, select_last_index: Option<i64>, expected: impl NdArray) -> Result<()> {
+        let att_axis = AttributeProto {
+            name: "axis".to_string(),
+            ref_attr_name: "axis".to_string(),
+            i: axis.unwrap_or(0),
+            doc_string: "axis".to_string(),
+            r#type: 2, // INT
+            f: 0.0,
+            s: vec![],
+            t: None,
+            g: None,
+            sparse_tensor: None,
+            tp: None,
+            floats: vec![],
+            ints: vec![],
+            strings: vec![],
+            tensors: vec![],
+            graphs: vec![],
+            sparse_tensors: vec![],
+            type_protos: vec![],
+        };
+        let att_keepdims = AttributeProto {
+            name: "keepdims".to_string(),
+            ref_attr_name: "keepdims".to_string(),
+            i: keepdims.unwrap_or(1),
+            doc_string: "keepdims".to_string(),
+            r#type: 2, // INT
+            f: 0.0,
+            s: vec![],
+            t: None,
+            g: None,
+            sparse_tensor: None,
+            tp: None,
+            floats: vec![],
+            ints: vec![],
+            strings: vec![],
+            tensors: vec![],
+            graphs: vec![],
+            sparse_tensors: vec![],
+            type_protos: vec![],
+        };
+        let att_select_last_index = AttributeProto {
+            name: "select_last_index".to_string(),
+            ref_attr_name: "select_last_index".to_string(),
+            i: select_last_index.unwrap_or(0),
+            doc_string: "select_last_index".to_string(),
+            r#type: 2, // INT
+            f: 0.0,
+            s: vec![],
+            t: None,
+            g: None,
+            sparse_tensor: None,
+            tp: None,
+            floats: vec![],
+            ints: vec![],
+            strings: vec![],
+            tensors: vec![],
+            graphs: vec![],
+            sparse_tensors: vec![],
+            type_protos: vec![],
+        };
+        let attrs = {
+            let mut mut_attrs = vec![];
+            if axis.is_some() {
+                mut_attrs.push(att_axis);
+            }
+            if keepdims.is_some() {
+                mut_attrs.push(att_keepdims);
+            }
+            if select_last_index.is_some() {
+                mut_attrs.push(att_select_last_index);
+            }
+            mut_attrs
+        };
+        let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+            node: vec![NodeProto {
+                op_type: "ArgMin".to_string(),
+                domain: "".to_string(),
+                attribute: attrs,
+                input: vec![INPUT_X.to_string()],
+                output: vec![OUTPUT_Z.to_string()],
+                name: "".to_string(),
+                doc_string: "".to_string(),
+            }],
+            name: "".to_string(),
+            initializer: vec![],
+            input: vec![],
+            output: vec![ValueInfoProto {
+                name: OUTPUT_Z.to_string(),
+                doc_string: "".to_string(),
+                r#type: None,
+            }],
+            value_info: vec![],
+            doc_string: "".to_string(),
+            sparse_initializer: vec![],
+            quantization_annotation: vec![],
+        }));
+        let mut inputs: HashMap<String, Tensor> = HashMap::new();
+        inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
+        let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+        let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+
+        let expected = Tensor::new(expected, &Device::Cpu)?;
+        match expected.dims().len() {
+            1 => assert_eq!(z.to_vec1::<u32>()?, expected.to_vec1::<u32>()?),
+            2 => assert_eq!(z.to_vec2::<u32>()?, expected.to_vec2::<u32>()?),
+            _ => unreachable!(),
+        };
+
+        Ok(())
+    }
+
+    Ok(())
+}

From 9a273196b79f9ca71e7ed83f1564abcd9bf17a52 Mon Sep 17 00:00:00 2001
From: b1rtek <53182944+B1rtek@users.noreply.github.com>
Date: Thu, 9 May 2024 20:22:22 +0200
Subject: [PATCH 06/10] ArgMin now returns a tensor with i64

---
 candle-onnx/src/eval.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index 746220ca..558f1161 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -1028,7 +1028,7 @@ pub fn simple_eval(
                 input.argmin_keepdim(axis)?
             } else {
                 input.argmin(axis)?
-            };
+            }.to_dtype(DType::I64)?;
             values.insert(node.output[0].clone(), output);
         }
         op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
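
The conversion above exists because ONNX specifies the ArgMin output as
tensor(int64), while candle's reductions produce u32 indices. In
isolation (a sketch):

    use candle::{DType, Device, Tensor};

    fn main() -> candle::Result<()> {
        let t = Tensor::new(&[[2u32, 1u32], [3u32, 10u32]], &Device::Cpu)?;
        let idx = t.argmin(1)?; // candle yields u32 indices...
        let idx = idx.to_dtype(DType::I64)?; // ...widened to int64 per the ONNX spec
        assert_eq!(idx.to_vec1::<i64>()?, vec![1, 0]);
        Ok(())
    }
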
From c4743aa570d22e755192f2df8ad72b830a19bf04 Mon Sep 17 00:00:00 2001
From: b1rtek <53182944+B1rtek@users.noreply.github.com>
Date: Thu, 9 May 2024 20:22:34 +0200
Subject: [PATCH 07/10] Added tests from pytorch examples

---
 candle-onnx/tests/ops.rs | 42 +++++++++++++++++++++++++++++++---------
 1 file changed, 33 insertions(+), 9 deletions(-)

diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs
index ed9af9ae..09a6edb9 100644
--- a/candle-onnx/tests/ops.rs
+++ b/candle-onnx/tests/ops.rs
@@ -10,7 +10,6 @@ use candle_onnx::onnx::attribute_proto::AttributeType;
 use candle_onnx::onnx::tensor_proto::DataType;
 use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto};
 use std::collections::HashMap;
-use candle_onnx::eval::Value;
 
 const INPUT_X: &str = "x";
 const INPUT_Y: &str = "y";
@@ -2432,7 +2431,7 @@ fn test_argmin() -> Result<()> {
         Some(1),
         None,
         &[
-            [0u32, 0u32],
+            [0i64, 0i64],
         ],
     )?;
     // keepdims
@@ -2445,8 +2444,8 @@ fn test_argmin() -> Result<()> {
         Some(1),
         None,
         &[
-            [1u32],
-            [0u32]
+            [1i64],
+            [0i64]
         ],
     )?;
     // negative_axis_keepdims
@@ -2459,8 +2458,8 @@ fn test_argmin() -> Result<()> {
         Some(-1),
         Some(1),
         None,
         &[
-            [1u32],
-            [0u32]
+            [1i64],
+            [0i64]
         ],
     )?;
     // no_keepdims
@@ -2472,7 +2471,32 @@ fn test_argmin() -> Result<()> {
         None,
         Some(0),
         None,
-        &[0u32, 0u32],
+        &[0i64, 0i64],
     )?;
+    // tests from https://pytorch.org/docs/stable/generated/torch.argmin.html#torch.argmin
+    test(
+        &[
+            [0.1139, 0.2254, -0.1381, 0.3687],
+            [1.0100, -1.1975, -0.0102, -0.4732],
+            [-0.9240, 0.1207, -0.7506, -1.0213],
+            [1.7809, -1.2960, 0.9384, 0.1438]
+        ],
+        Some(1),
+        Some(0),
+        None,
+        &[2i64, 1i64, 3i64, 1i64],
+    )?;
+    test(
+        &[
+            [0.1139, 0.2254, -0.1381, 0.3687],
+            [1.0100, -1.1975, -0.0102, -0.4732],
+            [-0.9240, 0.1207, -0.7506, -1.0213],
+            [1.7809, -1.2960, 0.9384, 0.1438]
+        ],
+        Some(1),
+        None,
+        None,
+        &[[2i64], [1i64], [3i64], [1i64]],
+    )?;
     fn test(data: impl NdArray, axis: Option<i64>, keepdims: Option<i64>, select_last_index: Option<i64>, expected: impl NdArray) -> Result<()> {
         let att_axis = AttributeProto {
@@ -2578,8 +2602,8 @@ fn test_argmin() -> Result<()> {
 
         let expected = Tensor::new(expected, &Device::Cpu)?;
         match expected.dims().len() {
-            1 => assert_eq!(z.to_vec1::<u32>()?, expected.to_vec1::<u32>()?),
-            2 => assert_eq!(z.to_vec2::<u32>()?, expected.to_vec2::<u32>()?),
+            1 => assert_eq!(z.to_vec1::<i64>()?, expected.to_vec1::<i64>()?),
+            2 => assert_eq!(z.to_vec2::<i64>()?, expected.to_vec2::<i64>()?),
             _ => unreachable!(),
         };
 
From 8f1119b3e0dc5da794d52947881203693468fb44 Mon Sep 17 00:00:00 2001
From: b1rtek <53182944+B1rtek@users.noreply.github.com>
Date: Thu, 9 May 2024 20:45:41 +0200
Subject: [PATCH 08/10] Added ArgMax operator implementation

---
 candle-onnx/src/eval.rs | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index 558f1161..4d3f3ee4 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -1031,6 +1031,26 @@ pub fn simple_eval(
             }.to_dtype(DType::I64)?;
             values.insert(node.output[0].clone(), output);
         }
+        "ArgMax" => {
+            let input = get(&node.input[0])?;
+            let axis_i64: i64 = get_attr_opt(node, "axis")?.copied().unwrap_or(0);
+            let rank_i64: i64 = input.rank().try_into().unwrap();
+            if axis_i64 < -rank_i64 || axis_i64 >= rank_i64 {
+                bail!("axis ({}) out of accepted range [-rank, rank-1] which was [{}, {}]", axis_i64, -rank_i64, rank_i64 - 1)
+            }
+            let axis = input.normalize_axis(axis_i64)?;
+            let keepdims: i64 = get_attr_opt(node, "keepdims")?.copied().unwrap_or(1);
+            let select_last_index: i64 = get_attr_opt(node, "select_last_index")?.copied().unwrap_or(0);
+            if select_last_index == 1 {
+                bail!("select_last_index for ArgMax is currently not supported")
+            }
+            let output = if keepdims == 1 {
+                input.argmax_keepdim(axis)?
+            } else {
+                input.argmax(axis)?
+            }.to_dtype(DType::I64)?;
+            values.insert(node.output[0].clone(), output);
+        }
         op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
     }
 }
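
The ArgMax arm mirrors ArgMin with candle's `argmax`/`argmax_keepdim`.
For instance, the default_axes_keepdims example from the ONNX operator
docs (a sketch):

    use candle::{DType, Device, Tensor};

    fn main() -> candle::Result<()> {
        let t = Tensor::new(&[[2u32, 1u32], [3u32, 10u32]], &Device::Cpu)?;
        // argmax over axis 0 with keepdims: shape [1, 2], both maxima in row 1.
        let idx = t.argmax_keepdim(0)?.to_dtype(DType::I64)?;
        assert_eq!(idx.to_vec2::<i64>()?, vec![vec![1, 1]]);
        Ok(())
    }
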
From 4de76b89a25869b02ef5a83a22aed0b97eab579e Mon Sep 17 00:00:00 2001
From: b1rtek <53182944+B1rtek@users.noreply.github.com>
Date: Thu, 9 May 2024 20:45:53 +0200
Subject: [PATCH 09/10] Added tests for ArgMax

---
 candle-onnx/tests/ops.rs | 196 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 196 insertions(+)

diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs
index 09a6edb9..47f75949 100644
--- a/candle-onnx/tests/ops.rs
+++ b/candle-onnx/tests/ops.rs
@@ -2612,3 +2612,199 @@ fn test_argmin() -> Result<()> {
 
     Ok(())
 }
+
+// "ArgMin"
+#[test]
+fn test_argmax() -> Result<()> {
+    // tests from https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-6
+    // default_axes_keepdims
+    test(
+        &[
+            [2u32, 1u32],
+            [3u32, 10u32]
+        ],
+        None,
+        Some(1),
+        None,
+        &[
+            [1i64, 1i64],
+        ],
+    )?;
+    // keepdims
+    test(
+        &[
+            [2u32, 1u32],
+            [3u32, 10u32]
+        ],
+        Some(1),
+        Some(1),
+        None,
+        &[
+            [0i64],
+            [1i64]
+        ],
+    )?;
+    // negative_axis_keepdims
+    test(
+        &[
+            [2u32, 1u32],
+            [3u32, 10u32]
+        ],
+        Some(-1),
+        Some(1),
+        None,
+        &[
+            [0i64],
+            [1i64]
+        ],
+    )?;
+    // no_keepdims
+    test(
+        &[
+            [2u32, 1u32],
+            [3u32, 10u32]
+        ],
+        None,
+        Some(0),
+        None,
+        &[1i64, 1i64],
+    )?;
+    // tests from https://pytorch.org/docs/stable/generated/torch.argmax.html
+    test(
+        &[
+            [1.3398, 0.2663, -0.2686, 0.2450],
+            [-0.7401, -0.8805, -0.3402, -1.1936],
+            [0.4907, -1.3948, -1.0691, -0.3132],
+            [-1.6092, 0.5419, -0.2993, 0.3195]
+        ],
+        Some(1),
+        Some(0),
+        None,
+        &[0i64, 2i64, 0i64, 1i64],
+    )?;
+    test(
+        &[
+            [1.3398, 0.2663, -0.2686, 0.2450],
+            [-0.7401, -0.8805, -0.3402, -1.1936],
+            [0.4907, -1.3948, -1.0691, -0.3132],
+            [-1.6092, 0.5419, -0.2993, 0.3195]
+        ],
+        Some(1),
+        None,
+        None,
+        &[[0i64], [2i64], [0i64], [1i64]],
+    )?;
+    fn test(data: impl NdArray, axis: Option<i64>, keepdims: Option<i64>, select_last_index: Option<i64>, expected: impl NdArray) -> Result<()> {
+        let att_axis = AttributeProto {
+            name: "axis".to_string(),
+            ref_attr_name: "axis".to_string(),
+            i: axis.unwrap_or(0),
+            doc_string: "axis".to_string(),
+            r#type: 2, // INT
+            f: 0.0,
+            s: vec![],
+            t: None,
+            g: None,
+            sparse_tensor: None,
+            tp: None,
+            floats: vec![],
+            ints: vec![],
+            strings: vec![],
+            tensors: vec![],
+            graphs: vec![],
+            sparse_tensors: vec![],
+            type_protos: vec![],
+        };
+        let att_keepdims = AttributeProto {
+            name: "keepdims".to_string(),
+            ref_attr_name: "keepdims".to_string(),
+            i: keepdims.unwrap_or(1),
+            doc_string: "keepdims".to_string(),
+            r#type: 2, // INT
+            f: 0.0,
+            s: vec![],
+            t: None,
+            g: None,
+            sparse_tensor: None,
+            tp: None,
+            floats: vec![],
+            ints: vec![],
+            strings: vec![],
+            tensors: vec![],
+            graphs: vec![],
+            sparse_tensors: vec![],
+            type_protos: vec![],
+        };
+        let att_select_last_index = AttributeProto {
+            name: "select_last_index".to_string(),
+            ref_attr_name: "select_last_index".to_string(),
+            i: select_last_index.unwrap_or(0),
+            doc_string: "select_last_index".to_string(),
+            r#type: 2, // INT
+            f: 0.0,
+            s: vec![],
+            t: None,
+            g: None,
+            sparse_tensor: None,
+            tp: None,
+            floats: vec![],
+            ints: vec![],
+            strings: vec![],
+            tensors: vec![],
+            graphs: vec![],
+            sparse_tensors: vec![],
+            type_protos: vec![],
+        };
+        let attrs = {
+            let mut mut_attrs = vec![];
+            if axis.is_some() {
+                mut_attrs.push(att_axis);
+            }
+            if keepdims.is_some() {
+                mut_attrs.push(att_keepdims);
+            }
+            if select_last_index.is_some() {
+                mut_attrs.push(att_select_last_index);
+            }
+            mut_attrs
+        };
+        let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+            node: vec![NodeProto {
+                op_type: "ArgMax".to_string(),
+                domain: "".to_string(),
+                attribute: attrs,
+                input: vec![INPUT_X.to_string()],
+                output: vec![OUTPUT_Z.to_string()],
+                name: "".to_string(),
+                doc_string: "".to_string(),
+            }],
+            name: "".to_string(),
+            initializer: vec![],
+            input: vec![],
+            output: vec![ValueInfoProto {
+                name: OUTPUT_Z.to_string(),
+                doc_string: "".to_string(),
+                r#type: None,
+            }],
+            value_info: vec![],
+            doc_string: "".to_string(),
+            sparse_initializer: vec![],
+            quantization_annotation: vec![],
+        }));
+        let mut inputs: HashMap<String, Tensor> = HashMap::new();
+        inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
+        let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+        let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+
+        let expected = Tensor::new(expected, &Device::Cpu)?;
+        match expected.dims().len() {
+            1 => assert_eq!(z.to_vec1::<i64>()?, expected.to_vec1::<i64>()?),
+            2 => assert_eq!(z.to_vec2::<i64>()?, expected.to_vec2::<i64>()?),
+            _ => unreachable!(),
+        };
+
+        Ok(())
+    }
+
+    Ok(())
+}

From 08fd7f7119a66598d681e3b168f8140ffec9788b Mon Sep 17 00:00:00 2001
From: b1rtek <53182944+B1rtek@users.noreply.github.com>
Date: Fri, 10 May 2024 00:51:01 +0200
Subject: [PATCH 10/10] Typo fix

---
 candle-onnx/tests/ops.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs
index 47f75949..a69c9083 100644
--- a/candle-onnx/tests/ops.rs
+++ b/candle-onnx/tests/ops.rs
@@ -2613,7 +2613,7 @@ fn test_argmin() -> Result<()> {
     Ok(())
 }
 
-// "ArgMin"
+// "ArgMax"
 #[test]
 fn test_argmax() -> Result<()> {
     // tests from https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-6
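
With the series applied, a graph containing any of the new ops runs
through the crate's usual entry points. Roughly (a sketch; the model
path and input name are hypothetical):

    use std::collections::HashMap;

    use candle::{Device, Tensor};

    fn main() -> candle::Result<()> {
        // Hypothetical model exercising the new ops; the path is illustrative only.
        let model = candle_onnx::read_file("model.onnx")?;
        let x = Tensor::from_vec(vec![-1.0f32, 0.0, 1.0, 2.0], &[2, 2], &Device::Cpu)?;
        let mut inputs: HashMap<String, Tensor> = HashMap::new();
        inputs.insert("x".to_string(), x);
        let outputs = candle_onnx::simple_eval(&model, inputs)?;
        for (name, value) in outputs {
            println!("{name}: {value}");
        }
        Ok(())
    }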