Add missing onnx operations (#2096)

* Add missing onnx operations

* Add tests and fix errors

* Run rustfmt
Gabriel
2024-04-20 18:44:22 +02:00
committed by GitHub
parent 52ae332910
commit 9215e9ce8c
2 changed files with 736 additions and 9 deletions

@@ -23,6 +23,11 @@ trait Attr {
fn get(attr: &onnx::AttributeProto) -> Result<&Self>;
}
trait AttrOwned: Sized {
const TYPE: AttributeType;
fn get(attr: &onnx::AttributeProto) -> Result<Self>;
}
impl Attr for i64 {
const TYPE: AttributeType = AttributeType::Int;
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
@@ -51,6 +56,50 @@ impl Attr for str {
}
}
impl AttrOwned for Tensor {
const TYPE: AttributeType = AttributeType::Tensor;
fn get(attr: &onnx::AttributeProto) -> Result<Self> {
let tensor_proto = match &attr.t {
Some(value) => value,
None => bail!(
"attribute {} was of type TENSOR, but no tensor was found",
attr.name
),
};
let data_type = match DataType::try_from(tensor_proto.data_type) {
Ok(value) => value,
Err(_) => bail!(
"attribute {} of type TENSOR was an invalid data_type number {}",
attr.name,
tensor_proto.data_type
),
};
let dtype = match dtype(data_type) {
Some(value) => value,
None => bail!(
"attribute {} of type TENSOR has an unsupported data_type {}",
attr.name,
data_type.as_str_name()
),
};
let mut dims = Vec::with_capacity(tensor_proto.dims.len());
for dim in &tensor_proto.dims {
if dim < &0 {
bail!(
"attribute {} of type TENSOR has a negative dimension, which is unsupported",
attr.name
)
}
dims.push(*dim as usize)
}
Tensor::from_raw_buffer(&tensor_proto.raw_data, dtype, &dims, &Device::Cpu)
}
}
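// Sketch (not part of this patch): decoding a scalar f32 tensor attribute the
// way the impl above does; the proto values and the helper name are illustrative.
fn _tensor_attr_sketch() -> Result<Tensor> {
    let tensor_proto = onnx::TensorProto {
        data_type: DataType::Float as i32,
        dims: vec![],
        raw_data: 5f32.to_le_bytes().to_vec(),
        ..Default::default()
    };
    // A zero-rank tensor built from the little-endian f32 bytes above.
    Tensor::from_raw_buffer(&tensor_proto.raw_data, DType::F32, &[], &Device::Cpu)
}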
fn get_attr_<'a>(node: &'a onnx::NodeProto, name: &str) -> Result<&'a onnx::AttributeProto> {
match node.attribute.iter().find(|attr| attr.name == name) {
None => {
@@ -98,6 +147,24 @@ fn get_attr_opt<'a, T: Attr + ?Sized>(
}
}
fn get_attr_opt_owned<T: AttrOwned>(node: &onnx::NodeProto, name: &str) -> Result<Option<T>> {
match node.attribute.iter().find(|attr| attr.name == name) {
None => Ok(None),
Some(attr) => {
if attr.r#type() != T::TYPE {
bail!(
"unsupported type {:?} for '{name}' attribute in '{}' for {}",
attr.r#type(),
node.op_type,
node.name
)
}
let val = T::get(attr)?;
Ok(Some(val))
}
}
}
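// Sketch (not part of this patch): typical call sites for the attribute
// helpers above; the attribute names "value" and "axis" are illustrative.
fn _attr_helpers_sketch(node: &onnx::NodeProto) -> Result<()> {
    // Tensor attributes are decoded and returned by value.
    let _value: Option<Tensor> = get_attr_opt_owned::<Tensor>(node, "value")?;
    // Plain attributes are returned as borrows into the NodeProto.
    let _axis: Option<&i64> = get_attr_opt::<i64>(node, "axis")?;
    Ok(())
}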
pub fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect();
match DataType::try_from(t.data_type) {
@@ -458,14 +525,17 @@ pub fn simple_eval(
}
values.insert(node.output[0].clone(), xs);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#ConstantOfShape
"ConstantOfShape" => {
let dims = get(&node.input[0])?;
let shape = dims
.to_vec1::<i64>()?
.into_iter()
.map(|v| v as usize)
.collect::<Vec<_>>();
let xs = Tensor::zeros(shape, DType::F32, dims.device())?;
let input = get(&node.input[0])?;
// Per the ONNX spec, the input is a 1D i64 tensor holding the output shape.
let shape = input
    .to_vec1::<i64>()?
    .into_iter()
    .map(|v| v as usize)
    .collect::<Vec<_>>();
let value = get_attr_opt_owned::<Tensor>(node, "value")?.unwrap_or(Tensor::zeros(
    (),
    DType::F32,
    &Device::Cpu,
)?);
let xs = Tensor::ones(shape, value.dtype(), input.device())?.broadcast_mul(&value)?;
values.insert(node.output[0].clone(), xs);
}
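// Sketch (not part of this patch): with a shape input of [3, 2] and a scalar
// fill value of 5f32, ConstantOfShape produces a 3x2 tensor of fives.
fn _constant_of_shape_sketch() -> Result<Tensor> {
    let value = Tensor::new(5f32, &Device::Cpu)?;
    let shape = vec![3usize, 2];
    Tensor::ones(shape, value.dtype(), value.device())?.broadcast_mul(&value)
}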
"Unsqueeze" => {
@@ -552,6 +622,82 @@ pub fn simple_eval(
let dims = Tensor::from_vec(dims, xs.rank(), xs.device())?;
values.insert(node.output[0].clone(), dims);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sqrt
"Sqrt" => {
let xs = get(&node.input[0])?;
let output = xs.sqrt()?;
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Range
"Range" => {
let start = get(&node.input[0])?;
let limit = get(&node.input[1])?;
let delta = get(&node.input[2])?;
macro_rules! arange_step {
($t: ty) => {
Tensor::arange_step(
start.to_vec0::<$t>()?,
limit.to_vec0::<$t>()?,
delta.to_vec0::<$t>()?,
&Device::Cpu,
)?
};
}
let output = match start.dtype() {
DType::U8 => arange_step!(u8),
DType::U32 => arange_step!(u32),
DType::I64 => arange_step!(i64),
// to_vec0 requires an exact dtype match, so half floats are read back via f32.
DType::BF16 | DType::F16 => Tensor::arange_step(
    start.to_dtype(DType::F32)?.to_vec0::<f32>()?,
    limit.to_dtype(DType::F32)?.to_vec0::<f32>()?,
    delta.to_dtype(DType::F32)?.to_vec0::<f32>()?,
    &Device::Cpu,
)?,
DType::F32 => arange_step!(f32),
DType::F64 => arange_step!(f64),
};
values.insert(node.output[0].clone(), output);
}
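// Sketch (not part of this patch): what the Range arm computes for i64 inputs
// start=1, limit=10, delta=3, namely [1, 4, 7].
fn _range_sketch() -> Result<Tensor> {
    Tensor::arange_step(1i64, 10, 3, &Device::Cpu)
}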
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Greater
"Greater" => {
let a = get(&node.input[0])?;
let b = get(&node.input[1])?;
let output = a.broadcast_gt(b)?;
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Less
"Less" => {
let a = get(&node.input[0])?;
let b = get(&node.input[1])?;
let output = a.broadcast_lt(b)?;
values.insert(node.output[0].clone(), output);
}
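// Sketch (not part of this patch): Greater and Less broadcast their inputs and
// return a U8 mask, e.g. [3., 1.] compared against 2. gives [1, 0].
fn _greater_sketch() -> Result<Tensor> {
    let a = Tensor::new(&[3f32, 1.], &Device::Cpu)?;
    let b = Tensor::new(2f32, &Device::Cpu)?;
    a.broadcast_gt(&b)
}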
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Log
"Log" => {
let a = get(&node.input[0])?;
let output = a.log()?;
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Min
"Min" => {
let mut output = get(&node.input[0])?.clone();
for input in node.input.iter() {
let input = get(input)?;
output = output.broadcast_minimum(input)?
}
values.insert(node.output[0].clone(), output);
}
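// Sketch (not part of this patch): Min folds broadcast_minimum across the
// inputs; a vector against a scalar 2. gives [1., 2., 2.].
fn _min_sketch() -> Result<Tensor> {
    let a = Tensor::new(&[1f32, 5., 3.], &Device::Cpu)?;
    let b = Tensor::new(2f32, &Device::Cpu)?;
    a.broadcast_minimum(&b)
}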
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Where
"Where" => {
let cond = get(&node.input[0])?;
let a = get(&node.input[1])?;
let b = get(&node.input[2])?;
let output = cond.where_cond(a, b)?;
values.insert(node.output[0].clone(), output);
}
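// Sketch (not part of this patch): Where picks from the second input where the
// condition is non-zero and from the third elsewhere, giving [1., 8., 3.] here.
fn _where_sketch() -> Result<Tensor> {
    let cond = Tensor::new(&[1u8, 0, 1], &Device::Cpu)?;
    let on_true = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    let on_false = Tensor::new(&[7f32, 8., 9.], &Device::Cpu)?;
    cond.where_cond(&on_true, &on_false)
}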
"Conv" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;