mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Handle more tensor shapes in onnx "Gather" operation (#2026)
* Handle more tensor shapes in onnx "Gather" operation * Add more tests * Add comment * Fix typo
This commit is contained in:
@ -508,17 +508,33 @@ pub fn simple_eval(
|
||||
values.insert(node.output[0].clone(), xs);
|
||||
}
|
||||
"Gather" => {
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gather
|
||||
let xs = get(&node.input[0])?;
|
||||
let indices = get(&node.input[1])?;
|
||||
let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
|
||||
let axis = xs.normalize_axis(axis)?;
|
||||
// TODO: Provide an op to handle the ONNX generalized gather op ideally in a
|
||||
// differentiable way.
|
||||
let xs = if indices.rank() == 0 {
|
||||
let index = indices.to_vec0::<i64>()? as usize;
|
||||
xs.narrow(axis, index, 1)?.squeeze(axis)?
|
||||
} else {
|
||||
todo!("implement gather for {xs:?} {indices:?} axis {axis}")
|
||||
|
||||
// In Pytorch or Numpy this can be done by indexing the xs tensor using the indices
|
||||
// tensor directly, but candle does not support tensor indexing at the moment, so
|
||||
// some workarounds must be done.
|
||||
let xs = match indices.dims() {
|
||||
[] => {
|
||||
let index = indices.to_vec0::<i64>()? as usize;
|
||||
xs.narrow(axis, index, 1)?.squeeze(axis)?
|
||||
}
|
||||
[_] => xs.index_select(indices, axis)?,
|
||||
[first, _] => {
|
||||
let mut v = Vec::with_capacity(*first);
|
||||
for i in 0..*first {
|
||||
v.push(xs.index_select(&indices.get(i)?, axis)?)
|
||||
}
|
||||
Tensor::stack(&v, axis)?
|
||||
}
|
||||
_ => {
|
||||
// TODO: Provide an op to handle the ONNX generalized gather op ideally in a
|
||||
// differentiable way.
|
||||
todo!("implement gather for {xs:?} {indices:?} axis {axis}")
|
||||
}
|
||||
};
|
||||
values.insert(node.output[0].clone(), xs);
|
||||
}
|
||||
|
Reference in New Issue
Block a user