onnx: implement Size op (#2316)

This commit is contained in:
shua
2024-07-07 19:56:36 +02:00
committed by GitHub
parent 9182f9f5c2
commit eec11ce2ce
2 changed files with 52 additions and 0 deletions

View File

@ -643,6 +643,13 @@ fn simple_eval_(
let dims = Tensor::from_vec(dims, xs.rank(), xs.device())?;
values.insert(node.output[0].clone(), dims);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Size
"Size" => {
let data = get(&node.input[0])?;
let size: usize = data.dims().iter().product();
let output = Tensor::from_slice(&[size as i64], (), data.device())?;
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sqrt
"Sqrt" => {
let xs = get(&node.input[0])?;

View File

@ -1112,6 +1112,51 @@ fn test_gather_operation() -> Result<()> {
Ok(())
}
// "Size"
#[test]
fn test_size_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Size".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_scalar::<i64>()?;
assert_eq!(results, 4);
Ok(())
}
// "Shape"
#[test]
fn test_shape_operation() -> Result<()> {