mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
ONNX: GatherElements, Xor (#2568)
This commit is contained in:

committed by
GitHub

parent
dcd83336b6
commit
7c09215ef4
@ -670,6 +670,49 @@ fn simple_eval_(
|
||||
};
|
||||
values.insert(node.output[0].clone(), xs);
|
||||
}
|
||||
// https://onnx.ai/onnx/operators/onnx__GatherElements.html#gatherelements
|
||||
// A Note to fellow lurkers:
|
||||
// The numpy based `gather_elements` implementation in `onnx` tests [here](https://github.com/onnx/onnx/blob/main/onnx/backend/test/case/node/gatherelements.py)
|
||||
// and examples is incorrect.
|
||||
// Use `torch.gather` for the validating/ verifying against the proper behaviour
|
||||
"GatherElements" => {
|
||||
let data = get(&node.input[0])?;
|
||||
let indices = get(&node.input[1])?;
|
||||
|
||||
let rank = data.rank();
|
||||
if rank != indices.rank() {
|
||||
bail!("indices must have same rank as input data. Data rank [{}] != indices rank [{}]", data.rank(), indices.rank());
|
||||
}
|
||||
|
||||
let axis = {
|
||||
let axis_i64 = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
|
||||
let axis = data.normalize_axis(axis_i64)?;
|
||||
|
||||
if axis >= rank {
|
||||
bail!(
|
||||
"axis ({}) out of accepted range [-rank, rank-1] which was [-{rank}, {}]",
|
||||
axis_i64,
|
||||
rank - 1
|
||||
)
|
||||
}
|
||||
|
||||
axis
|
||||
};
|
||||
|
||||
// index_select does not support negative indices, so normalize them
|
||||
// to positive indices.
|
||||
let indices = &{
|
||||
let zeros = Tensor::zeros(indices.shape(), indices.dtype(), indices.device())?;
|
||||
let max = Tensor::new(data.dims()[axis] as i64, indices.device())?
|
||||
.to_dtype(indices.dtype())?;
|
||||
let mask = indices.lt(&zeros)?;
|
||||
mask.to_dtype(indices.dtype())?
|
||||
.broadcast_mul(&max)?
|
||||
.add(indices)?
|
||||
};
|
||||
|
||||
values.insert(node.output[0].clone(), data.gather(indices, axis)?);
|
||||
}
|
||||
"Shape" => {
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shape
|
||||
let xs = get(&node.input[0])?;
|
||||
@ -1891,6 +1934,16 @@ fn simple_eval_(
|
||||
);
|
||||
}
|
||||
}
|
||||
// https://onnx.ai/onnx/operators/onnx__Xor.html
|
||||
"Xor" => {
|
||||
// Since we don't have a `DType::Bool` yet, this ensures that we are working with `0`(False) & `1`(True)
|
||||
let a = get(&node.input[0])?.gt(0_u8)?;
|
||||
let b = get(&node.input[1])?.gt(0_u8)?;
|
||||
|
||||
let out = a.broadcast_add(&b)?.eq(1_u8)?;
|
||||
|
||||
values.insert(node.output[0].clone(), out);
|
||||
}
|
||||
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user