Improve the ONNX basic example + bugfixes (#1266)

* Generate zero-valued input tensors in the onnx simple-eval example.

* Fix the casting operation.

* Support more ops.

* Handle reshape.

* Concat.

* Softmax.
Author: Laurent Mazare
Date: 2023-11-04 10:02:47 +01:00
Committed by: GitHub
Parent: f7c957d64f
Commit: bc9a1bf239
3 changed files with 190 additions and 52 deletions
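
Before the diff itself, here is a minimal sketch of how the pieces fit together, using only the read_file and simple_eval entry points that appear below. The "model.onnx" path, the input name "x", and the sample values are placeholders; the real example instead derives zero-filled inputs from the graph, as the first file shows.

use candle::{Device, Tensor};
use std::collections::HashMap;

fn run() -> anyhow::Result<()> {
    // Load the ModelProto from disk and evaluate it with a hand-built input map.
    let model = candle_onnx::read_file("model.onnx")?;
    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert("x".to_string(), Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?);
    let outputs = candle_onnx::simple_eval(&model, inputs)?;
    for (name, value) in outputs.iter() {
        println!("{name}: {value:?}");
    }
    Ok(())
}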

File: candle-onnx simple-eval example

@@ -41,9 +41,39 @@ pub fn main() -> Result<()> {
.unwrap()
.input
.iter()
- .map(|name| {
- let value = Tensor::new(&[-3.2, 2.7], &Device::Cpu)?;
- Ok((name.name.clone(), value))
+ .map(|input| {
+ use candle_onnx::onnx::tensor_proto::DataType;
+ let type_ = input.r#type.as_ref().expect("no type for input");
+ let type_ = type_.value.as_ref().expect("no type.value for input");
+ let value = match type_ {
+ candle_onnx::onnx::type_proto::Value::TensorType(tt) => {
+ let dt = match DataType::try_from(tt.elem_type) {
+ Ok(dt) => match candle_onnx::dtype(dt) {
+ Some(dt) => dt,
+ None => {
+ anyhow::bail!(
+ "unsupported data-type {dt:?} for input {}",
+ input.name
+ )
+ }
+ },
+ Err(_) => anyhow::bail!("unsupported elem_type {} for input {}", tt.elem_type, input.name),
+ };
+ let shape = tt.shape.as_ref().expect("no tensortype.shape for input");
+ let dims = shape
+ .dim
+ .iter()
+ .map(|dim| match dim.value.as_ref().expect("no dim value") {
+ candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimValue(v) => Ok(*v as usize),
+ candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimParam(_) => anyhow::bail!("DimParam is unsupported for input {}", input.name),
+ })
+ .collect::<Result<Vec<usize>>>()?;
+ Tensor::zeros(dims, dt, &Device::Cpu)?
+ }
+ type_ => anyhow::bail!("unsupported input type {type_:?}"),
+ };
+ Ok::<_, anyhow::Error>((input.name.clone(), value))
})
.collect::<Result<_>>()?;
let outputs = candle_onnx::simple_eval(&model, inputs)?;
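
Shaping zero-filled placeholders from the graph's declared inputs keeps the example runnable against arbitrary models without bundling real data; symbolic dimensions (DimParam) are rejected up front because they provide no concrete size to allocate.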

File: candle-onnx/src/eval.rs

@@ -1,9 +1,22 @@
use crate::onnx;
use crate::onnx::tensor_proto::DataType;
use candle::{bail, DType, Device, Result, Tensor};
use std::collections::HashMap;
pub type Value = Tensor;
pub fn dtype(dt: DataType) -> Option<DType> {
match dt {
DataType::Uint8 => Some(DType::U8),
DataType::Uint32 => Some(DType::U32),
DataType::Int64 => Some(DType::I64),
DataType::Float16 => Some(DType::F16),
DataType::Float => Some(DType::F32),
DataType::Double => Some(DType::F64),
_ => None,
}
}
// This function provides a direct evaluation of the proto.
// Longer-term, we should first convert the proto to an intermediate representation of the compute
// graph so as to make multiple evaluations more efficient.
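
As a usage note for the new helper: a raw ONNX element-type id, such as tt.elem_type in the example above, resolves in two steps, prost's generated try_from followed by this dtype mapping. A hypothetical wrapper (not part of the commit):

use candle_onnx::onnx::tensor_proto::DataType;

// Resolve a raw ONNX element-type id to a candle DType, if supported.
// ONNX FLOAT has id 1, so elem_dtype(1) yields Some(DType::F32).
fn elem_dtype(elem_type: i32) -> Option<candle::DType> {
    DataType::try_from(elem_type).ok().and_then(candle_onnx::dtype)
}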
@@ -26,6 +39,26 @@ pub fn simple_eval(
Some(value) => Ok(value),
None => bail!("cannot find {input_name} for op {}", node.name),
};
let get_attr_i = |name: &str| match node.attribute.iter().find(|attr| attr.name == name) {
None => {
bail!(
"cannot find the '{name}' attribute in '{}' for {}",
node.op_type,
node.name
)
}
Some(dt) => {
match dt.r#type() {
AttributeType::Int => (),
rtype => bail!(
"unsupported type {rtype:?} for '{name}' attribute in '{}' for {}",
node.op_type,
node.name
),
}
Ok(dt.i)
}
};
// TODO: Validate node.input for each operator.
match node.op_type.as_str() {
"Add" => {
@@ -52,12 +85,114 @@ pub fn simple_eval(
let output = input0.broadcast_div(input1)?;
values.insert(node.output[0].clone(), output);
}
"Equal" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
let output = input0.broadcast_eq(input1)?;
values.insert(node.output[0].clone(), output);
}
"MatMul" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
let output = input0.broadcast_matmul(input1)?;
values.insert(node.output[0].clone(), output);
}
"Reshape" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?.to_vec1::<i64>()?;
// TODO: Check that there is at most a single -1 and reject other negative or zero values.
let el_count = input0.elem_count();
let known_count: usize = input1.iter().filter(|&&v| v != -1).map(|&v| v as usize).product();
let input1 = input1
.iter()
.map(|&v| {
if v == -1 {
// Infer the -1 dimension from the remaining element count.
el_count / known_count
} else {
v as usize
}
})
.collect::<Vec<usize>>();
let output = input0.reshape(input1)?;
values.insert(node.output[0].clone(), output);
}
"Softmax" => {
let input = get(&node.input[0])?;
let output = match get_attr_i("axis") {
Err(_) => candle_nn::ops::softmax_last_dim(input)?,
Ok(axis) => {
let num_axis = input.rank() as i64;
let axis = if axis >= 0 {
axis as usize
} else if axis < -num_axis {
bail!("wrong axis in concat {axis} for shape {:?}", input.shape())
} else {
(num_axis + axis) as usize
};
candle_nn::ops::softmax(input, axis)?
}
};
values.insert(node.output[0].clone(), output);
}
"Concat" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Concat
let inputs = node
.input
.iter()
.map(|n| Ok(get(n.as_str())?.clone()))
.collect::<Result<Vec<Value>>>()?;
let axis = get_attr_i("axis")?;
let num_axis = if inputs.is_empty() {
bail!("empty concat")
} else {
inputs[0].rank() as i64
};
let axis = if axis >= 0 {
axis as usize
} else if axis < -num_axis {
bail!(
"wrong axis in concat {axis} for shape {:?}",
inputs[0].shape()
)
} else {
(num_axis + axis) as usize
};
let output = Tensor::cat(&inputs, axis)?;
values.insert(node.output[0].clone(), output);
}
"Abs" => {
let input = get(&node.input[0])?;
let output = input.abs()?;
values.insert(node.output[0].clone(), output);
}
"Cos" => {
let input = get(&node.input[0])?;
let output = input.cos()?;
values.insert(node.output[0].clone(), output);
}
"Sin" => {
let input = get(&node.input[0])?;
let output = input.sin()?;
values.insert(node.output[0].clone(), output);
}
"Neg" => {
let input = get(&node.input[0])?;
let output = input.neg()?;
values.insert(node.output[0].clone(), output);
}
"Erf" => {
let input = get(&node.input[0])?;
let output = input.erf()?;
values.insert(node.output[0].clone(), output);
}
"Tanh" => {
let input = get(&node.input[0])?;
let output = input.tanh()?;
values.insert(node.output[0].clone(), output);
}
"Sigmoid" => {
let input = get(&node.input[0])?;
let output = candle_nn::ops::sigmoid(input)?;
values.insert(node.output[0].clone(), output);
}
"Gelu" => {
let input = get(&node.input[0])?;
let output = input.gelu_erf()?;
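
The Softmax and Concat arms above duplicate the same negative-axis normalization: per the ONNX spec, an axis in [-rank, rank) counts from the back when negative. A small helper along these lines (hypothetical, reusing the Result and bail! imports at the top of eval.rs) would factor the logic out:

// Map a possibly-negative ONNX axis to a concrete dimension index,
// bailing out when it is out of range for the given rank.
fn normalize_axis(axis: i64, rank: usize, op: &str) -> Result<usize> {
    let rank = rank as i64;
    if (0..rank).contains(&axis) {
        Ok(axis as usize)
    } else if (-rank..0).contains(&axis) {
        Ok((rank + axis) as usize)
    } else {
        bail!("wrong axis {axis} for rank {rank} in {op}")
    }
}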
@@ -79,49 +214,20 @@ pub fn simple_eval(
};
let output = match value.r#type() {
AttributeType::Tensor => {
- use crate::onnx::tensor_proto::DataType;
let t = value.t.as_ref().unwrap();
let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect();
match DataType::try_from(t.data_type) {
- Ok(DataType::Uint8) => Tensor::from_raw_buffer(
- t.raw_data.as_slice(),
- DType::U8,
- dims.as_slice(),
- &Device::Cpu,
- )?,
- Ok(DataType::Uint32) => Tensor::from_raw_buffer(
- t.raw_data.as_slice(),
- DType::U32,
- dims.as_slice(),
- &Device::Cpu,
- )?,
- Ok(DataType::Int64) => Tensor::from_raw_buffer(
- t.raw_data.as_slice(),
- DType::I64,
- dims.as_slice(),
- &Device::Cpu,
- )?,
- Ok(DataType::Float16) => Tensor::from_raw_buffer(
- t.raw_data.as_slice(),
- DType::F16,
- dims.as_slice(),
- &Device::Cpu,
- )?,
- Ok(DataType::Float) => Tensor::from_raw_buffer(
- t.raw_data.as_slice(),
- DType::F32,
- dims.as_slice(),
- &Device::Cpu,
- )?,
- Ok(DataType::Double) => Tensor::from_raw_buffer(
- t.raw_data.as_slice(),
- DType::F64,
- dims.as_slice(),
- &Device::Cpu,
- )?,
- Ok(dt) => {
- bail!("unsupported 'value' data-type {dt:?} for {}", node.name)
- }
+ Ok(dt) => match dtype(dt) {
+ Some(dt) => Tensor::from_raw_buffer(
+ t.raw_data.as_slice(),
+ dt,
+ dims.as_slice(),
+ &Device::Cpu,
+ )?,
+ None => {
+ bail!("unsupported 'value' data-type {dt:?} for {}", node.name)
+ }
+ },
Err(_) => {
bail!(
"unsupported 'value' data-type {} for {}",
@@ -138,15 +244,17 @@ pub fn simple_eval(
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast
"Cast" => {
let input = get(&node.input[0])?;
- let dtype = match node.attribute.iter().find(|attr| attr.name == "to") {
- None => {
- bail!("cannot find the 'to' attribute in 'Cast' for {}", node.name)
- }
- Some(dtype) => match dtype.r#type() {
- AttributeType::Floats => candle::DType::F32,
- AttributeType::Int => candle::DType::I64,
- rtype => bail!("unsupported 'to' type {rtype:?} for {}", node.name),
- },
- };
+ let dt = get_attr_i("to")?;
+ let dtype = match DataType::try_from(dt as i32) {
+ Ok(dt) => match dtype(dt) {
+ Some(dt) => dt,
+ None => {
+ bail!("unsupported 'to' value {dt:?} for cast {}", node.name)
+ }
+ },
+ Err(_) => {
+ bail!("unsupported 'to' value {dt:?} for cast {}", node.name)
+ }
+ };
let output = input.to_dtype(dtype)?;
values.insert(node.output[0].clone(), output);
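
In ONNX, the to attribute of Cast is an integer carrying a TensorProto.DataType id, hence the i64 -> i32 -> DataType -> DType chain above. The same resolution as a standalone sketch (hypothetical helper, going through the dtype mapping re-exported below):

// Resolve a Cast node's 'to' attribute into a candle DType.
fn cast_dtype(to: i64) -> candle::Result<candle::DType> {
    use candle_onnx::onnx::tensor_proto::DataType;
    match DataType::try_from(to as i32).ok().and_then(candle_onnx::dtype) {
        Some(dt) => Ok(dt),
        None => candle::bail!("unsupported 'to' value {to} for cast"),
    }
}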

File: candle-onnx/src/lib.rs

@@ -6,7 +6,7 @@ pub mod onnx {
}
mod eval;
- pub use eval::simple_eval;
+ pub use eval::{dtype, simple_eval};
pub fn read_file<P: AsRef<std::path::Path>>(p: P) -> Result<onnx::ModelProto> {
let buf = std::fs::read(p)?;
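
With the re-export above in place, downstream code can bring both helpers in from the crate root, e.g. use candle_onnx::{dtype, simple_eval};, matching the fully qualified candle_onnx::dtype calls used in the sketches earlier on.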