diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index f7203b36..f0679d5b 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -643,6 +643,13 @@ fn simple_eval_( let dims = Tensor::from_vec(dims, xs.rank(), xs.device())?; values.insert(node.output[0].clone(), dims); } + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Size + "Size" => { + let data = get(&node.input[0])?; + let size: usize = data.dims().iter().product(); + let output = Tensor::from_slice(&[size as i64], (), data.device())?; + values.insert(node.output[0].clone(), output); + } // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sqrt "Sqrt" => { let xs = get(&node.input[0])?; diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs index 82d38aa4..bf459d5d 100644 --- a/candle-onnx/tests/ops.rs +++ b/candle-onnx/tests/ops.rs @@ -1112,6 +1112,51 @@ fn test_gather_operation() -> Result<()> { Ok(()) } +// "Size" +#[test] +fn test_size_operation() -> Result<()> { + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "Size".to_string(), + domain: "".to_string(), + attribute: vec![], + input: vec![INPUT_X.to_string()], + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + name: "".to_string(), + initializer: vec![], + input: vec![], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![ValueInfoProto { + name: INPUT_X.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + doc_string: "".to_string(), + sparse_initializer: vec![], + quantization_annotation: vec![], + })); + let x = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu)?; + + let mut inputs: HashMap = HashMap::new(); + inputs.insert(INPUT_X.to_string(), x); + + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); + + let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); + let results = z.to_scalar::()?; + assert_eq!(results, 4); + + Ok(()) +} + // "Shape" #[test] fn test_shape_operation() -> Result<()> {