mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Handle more tensor shapes in onnx "Gather" operation (#2026)
* Handle more tensor shapes in onnx "Gather" operation
* Add more tests
* Add comment
* Fix typo
This commit is contained in:
@ -508,17 +508,33 @@ pub fn simple_eval(
|
|||||||
values.insert(node.output[0].clone(), xs);
|
values.insert(node.output[0].clone(), xs);
|
||||||
}
|
}
|
||||||
"Gather" => {
|
"Gather" => {
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gather
|
||||||
let xs = get(&node.input[0])?;
|
let xs = get(&node.input[0])?;
|
||||||
let indices = get(&node.input[1])?;
|
let indices = get(&node.input[1])?;
|
||||||
let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
|
let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
|
||||||
let axis = xs.normalize_axis(axis)?;
|
let axis = xs.normalize_axis(axis)?;
|
||||||
// TODO: Provide an op to handle the ONNX generalized gather op ideally in a
|
|
||||||
// differentiable way.
|
// In Pytorch or Numpy this can be done by indexing the xs tensor using the indices
|
||||||
let xs = if indices.rank() == 0 {
|
// tensor directly, but candle does not support tensor indexing at the moment, so
|
||||||
|
// some workarounds must be done.
|
||||||
|
let xs = match indices.dims() {
|
||||||
|
[] => {
|
||||||
let index = indices.to_vec0::<i64>()? as usize;
|
let index = indices.to_vec0::<i64>()? as usize;
|
||||||
xs.narrow(axis, index, 1)?.squeeze(axis)?
|
xs.narrow(axis, index, 1)?.squeeze(axis)?
|
||||||
} else {
|
}
|
||||||
|
[_] => xs.index_select(indices, axis)?,
|
||||||
|
[first, _] => {
|
||||||
|
let mut v = Vec::with_capacity(*first);
|
||||||
|
for i in 0..*first {
|
||||||
|
v.push(xs.index_select(&indices.get(i)?, axis)?)
|
||||||
|
}
|
||||||
|
Tensor::stack(&v, axis)?
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
// TODO: Provide an op to handle the ONNX generalized gather op ideally in a
|
||||||
|
// differentiable way.
|
||||||
todo!("implement gather for {xs:?} {indices:?} axis {axis}")
|
todo!("implement gather for {xs:?} {indices:?} axis {axis}")
|
||||||
|
}
|
||||||
};
|
};
|
||||||
values.insert(node.output[0].clone(), xs);
|
values.insert(node.output[0].clone(), xs);
|
||||||
}
|
}
|
||||||
|
@ -4,7 +4,7 @@ extern crate intel_mkl_src;
|
|||||||
#[cfg(feature = "accelerate")]
|
#[cfg(feature = "accelerate")]
|
||||||
extern crate accelerate_src;
|
extern crate accelerate_src;
|
||||||
|
|
||||||
use candle::{Device, Result, Tensor};
|
use candle::{Device, NdArray, Result, Tensor};
|
||||||
use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto};
|
use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
@ -829,7 +829,134 @@ fn test_flatten_operation() -> Result<()> {
|
|||||||
// #[test]
|
// #[test]
|
||||||
|
|
||||||
// "Gather"
|
// "Gather"
|
||||||
// #[test]
|
#[test]
fn test_gather_operation() -> Result<()> {
    // Evaluates a graph made of a single "Gather" node on `data` and `indices`
    // with the given `axis` attribute, then checks that the graph's only output
    // is element-wise equal to `expected`.
    //
    // `expected` must have rank 0 to 3; any other rank hits `unreachable!()`.
    fn test(
        data: impl NdArray,
        indices: impl NdArray,
        axis: i64,
        expected: impl NdArray,
    ) -> Result<()> {
        // Only `i` carries the payload; every other value slot is left empty.
        let att_axis = AttributeProto {
            name: String::from("axis"),
            ref_attr_name: String::from("axis"),
            i: axis,
            doc_string: String::from("axis"),
            // NOTE(review): 2 is presumably the INT attribute type, matching the
            // `i` field above -- confirm against the ONNX AttributeType enum.
            r#type: 2,
            f: 0.0,
            s: vec![],
            t: None,
            g: None,
            sparse_tensor: None,
            tp: None,
            floats: vec![],
            ints: vec![],
            strings: vec![],
            tensors: vec![],
            graphs: vec![],
            sparse_tensors: vec![],
            type_protos: vec![],
        };

        // The node under test: Gather(INPUT_X, INPUT_Y) -> OUTPUT_Z.
        let gather_node = NodeProto {
            op_type: String::from("Gather"),
            domain: String::new(),
            attribute: vec![att_axis],
            input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            name: String::new(),
            doc_string: String::new(),
        };

        // The graph exposes the node's output as its single output.
        let graph_output = ValueInfoProto {
            name: OUTPUT_Z.to_string(),
            doc_string: String::new(),
            r#type: None,
        };

        let graph = GraphProto {
            node: vec![gather_node],
            name: String::new(),
            initializer: vec![],
            input: vec![],
            output: vec![graph_output],
            value_info: vec![],
            doc_string: String::new(),
            sparse_initializer: vec![],
            quantization_annotation: vec![],
        };
        let manual_graph = create_model_proto_with_graph(Some(graph));

        // Both operands are fed as runtime inputs rather than initializers.
        let data_tensor = Tensor::new(data, &Device::Cpu)?;
        let indices_tensor = Tensor::new(indices, &Device::Cpu)?;
        let inputs: HashMap<String, Tensor> = HashMap::from([
            (INPUT_X.to_string(), data_tensor),
            (INPUT_Y.to_string(), indices_tensor),
        ]);

        let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
        assert_eq!(eval.len(), 1);

        let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");

        // Compare through the `to_vecN` accessor that matches the expected rank.
        let expected = Tensor::new(expected, &Device::Cpu)?;
        match expected.rank() {
            0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
            1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
            2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
            3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
            _ => unreachable!(),
        }

        Ok(())
    }

    // Rank-2 indices along axis 0.
    // Test taken from https://onnx.ai/onnx/operators/onnx__Gather.html#summary.
    test(
        &[[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]],
        &[[0i64, 1], [1, 2]],
        0,
        &[[[1.0, 1.2], [2.3, 3.4]], [[2.3, 3.4], [4.5, 5.7]]],
    )?;

    // Rank-2 indices along axis 1.
    // Test taken from https://onnx.ai/onnx/operators/onnx__Gather.html#summary.
    test(
        &[[1.0, 1.2, 1.9], [2.3, 3.4, 3.9], [4.5, 5.7, 5.9]],
        &[[0i64, 2]],
        1,
        &[[[1.0, 1.9]], [[2.3, 3.9]], [[4.5, 5.9]]],
    )?;

    // All the cases below were generated with numpy.take, which works like
    // onnx's Gather operation.

    // Scalar index into a rank-1 tensor yields a scalar.
    test(&[1.0, 2.0, 3.0, 4.0], 3i64, 0, 4.0)?;

    // Scalar index along axis 1 drops that axis.
    test(&[[1.0, 2.0, 3.0, 4.0]], 3i64, 1, &[4.0])?;

    // Rank-1 indices along axis 0 of a rank-2 tensor.
    test(
        &[[1.0], [2.0], [3.0], [4.0]],
        &[3i64, 2],
        0,
        &[[4.0], [3.0]],
    )?;

    // Scalar index into a rank-3 tensor.
    test(
        &[
            [[1.0, 2.0], [3.0, 4.0]],
            [[5.0, 6.0], [7.0, 8.0]],
            [[9.0, 10.0], [11.0, 12.0]],
            [[13.0, 14.0], [15.0, 16.0]],
        ],
        1i64,
        0,
        &[[5.0, 6.0], [7.0, 8.0]],
    )?;

    // Rank-1 indices into a rank-3 tensor.
    test(
        &[
            [[1.0, 2.0], [3.0, 4.0]],
            [[5.0, 6.0], [7.0, 8.0]],
            [[9.0, 10.0], [11.0, 12.0]],
            [[13.0, 14.0], [15.0, 16.0]],
        ],
        &[1i64, 0],
        0,
        &[[[5.0, 6.0], [7.0, 8.0]], [[1.0, 2.0], [3.0, 4.0]]],
    )?;

    Ok(())
}
|
||||||
|
|
||||||
// "Shape"
|
// "Shape"
|
||||||
#[test]
|
#[test]
|
||||||
|
Reference in New Issue
Block a user