Handle more tensor shapes in onnx "Gather" operation (#2026)

* Handle more tensor shapes in onnx "Gather" operation

* Add more tests

* Add comment

* Fix typo
This commit is contained in:
Gabriel
2024-04-08 14:06:14 +02:00
committed by GitHub
parent 718671a0d5
commit 798e0335cd
2 changed files with 152 additions and 9 deletions

View File

@ -508,17 +508,33 @@ pub fn simple_eval(
values.insert(node.output[0].clone(), xs);
}
"Gather" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gather
let xs = get(&node.input[0])?;
let indices = get(&node.input[1])?;
let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
let axis = xs.normalize_axis(axis)?;
// TODO: Provide an op to handle the ONNX generalized gather op ideally in a
// differentiable way.
let xs = if indices.rank() == 0 {
let index = indices.to_vec0::<i64>()? as usize;
xs.narrow(axis, index, 1)?.squeeze(axis)?
} else {
todo!("implement gather for {xs:?} {indices:?} axis {axis}")
// In Pytorch or Numpy this can be done by indexing the xs tensor using the indices
// tensor directly, but candle does not support tensor indexing at the moment, so
// some workarounds must be done.
let xs = match indices.dims() {
[] => {
let index = indices.to_vec0::<i64>()? as usize;
xs.narrow(axis, index, 1)?.squeeze(axis)?
}
[_] => xs.index_select(indices, axis)?,
[first, _] => {
let mut v = Vec::with_capacity(*first);
for i in 0..*first {
v.push(xs.index_select(&indices.get(i)?, axis)?)
}
Tensor::stack(&v, axis)?
}
_ => {
// TODO: Provide an op to handle the ONNX generalized gather op ideally in a
// differentiable way.
todo!("implement gather for {xs:?} {indices:?} axis {axis}")
}
};
values.insert(node.output[0].clone(), xs);
}