mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 02:16:37 +00:00
onnx: implement Size op (#2316)
This commit is contained in:
@ -643,6 +643,13 @@ fn simple_eval_(
|
||||
let dims = Tensor::from_vec(dims, xs.rank(), xs.device())?;
|
||||
values.insert(node.output[0].clone(), dims);
|
||||
}
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Size
|
||||
"Size" => {
|
||||
let data = get(&node.input[0])?;
|
||||
let size: usize = data.dims().iter().product();
|
||||
let output = Tensor::from_slice(&[size as i64], (), data.device())?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sqrt
|
||||
"Sqrt" => {
|
||||
let xs = get(&node.input[0])?;
|
||||
|
@ -1112,6 +1112,51 @@ fn test_gather_operation() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// "Size"
|
||||
#[test]
|
||||
fn test_size_operation() -> Result<()> {
|
||||
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||
node: vec![NodeProto {
|
||||
op_type: "Size".to_string(),
|
||||
domain: "".to_string(),
|
||||
attribute: vec![],
|
||||
input: vec![INPUT_X.to_string()],
|
||||
output: vec![OUTPUT_Z.to_string()],
|
||||
name: "".to_string(),
|
||||
doc_string: "".to_string(),
|
||||
}],
|
||||
name: "".to_string(),
|
||||
initializer: vec![],
|
||||
input: vec![],
|
||||
output: vec![ValueInfoProto {
|
||||
name: OUTPUT_Z.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
}],
|
||||
value_info: vec![ValueInfoProto {
|
||||
name: INPUT_X.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
}],
|
||||
doc_string: "".to_string(),
|
||||
sparse_initializer: vec![],
|
||||
quantization_annotation: vec![],
|
||||
}));
|
||||
let x = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32, 4.0f32], &[2, 2], &Device::Cpu)?;
|
||||
|
||||
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||
inputs.insert(INPUT_X.to_string(), x);
|
||||
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
let results = z.to_scalar::<i64>()?;
|
||||
assert_eq!(results, 4);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// "Shape"
|
||||
#[test]
|
||||
fn test_shape_operation() -> Result<()> {
|
||||
|
Reference in New Issue
Block a user