Improve the ONNX basic example + bugfixes (#1266)

* Generate zero-filled input tensors in the ONNX simple-eval example.

* Fix the casting operation.

* Support more ops.

* Handle reshape.

* Concat.

* Softmax.
This commit is contained in:
Laurent Mazare
2023-11-04 10:02:47 +01:00
committed by GitHub
parent f7c957d64f
commit bc9a1bf239
3 changed files with 190 additions and 52 deletions

View File

@ -41,9 +41,39 @@ pub fn main() -> Result<()> {
.unwrap() .unwrap()
.input .input
.iter() .iter()
.map(|name| { .map(|input| {
let value = Tensor::new(&[-3.2, 2.7], &Device::Cpu)?; use candle_onnx::onnx::tensor_proto::DataType;
Ok((name.name.clone(), value))
let type_ = input.r#type.as_ref().expect("no type for input");
let type_ = type_.value.as_ref().expect("no type.value for input");
let value = match type_ {
candle_onnx::onnx::type_proto::Value::TensorType(tt) => {
let dt = match DataType::try_from(tt.elem_type) {
Ok(dt) => match candle_onnx::dtype(dt) {
Some(dt) => dt,
None => {
anyhow::bail!(
"unsupported 'value' data-type {dt:?} for {}",
input.name
)
}
},
type_ => anyhow::bail!("unsupported input type {type_:?}"),
};
let shape = tt.shape.as_ref().expect("no tensortype.shape for input");
let dims = shape
.dim
.iter()
.map(|dim| match dim.value.as_ref().expect("no dim value") {
candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimValue(v) => Ok(*v as usize),
candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimParam(_) => anyhow::bail!("DimParam is unsupported for input {}", input.name),
})
.collect::<Result<Vec<usize>>>()?;
Tensor::zeros(dims, dt, &Device::Cpu)?
}
type_ => anyhow::bail!("unsupported input type {type_:?}"),
};
Ok::<_, anyhow::Error>((input.name.clone(), value))
}) })
.collect::<Result<_>>()?; .collect::<Result<_>>()?;
let outputs = candle_onnx::simple_eval(&model, inputs)?; let outputs = candle_onnx::simple_eval(&model, inputs)?;

View File

@ -1,9 +1,22 @@
use crate::onnx; use crate::onnx;
use crate::onnx::tensor_proto::DataType;
use candle::{bail, DType, Device, Result, Tensor}; use candle::{bail, DType, Device, Result, Tensor};
use std::collections::HashMap; use std::collections::HashMap;
pub type Value = Tensor; pub type Value = Tensor;
/// Translates an ONNX tensor element type into the corresponding candle `DType`.
///
/// Only the element types listed in the match below have a counterpart; every
/// other ONNX `DataType` yields `None` so that callers can report it as
/// unsupported.
pub fn dtype(dt: DataType) -> Option<DType> {
    let mapped = match dt {
        DataType::Uint8 => DType::U8,
        DataType::Uint32 => DType::U32,
        DataType::Int64 => DType::I64,
        DataType::Float16 => DType::F16,
        DataType::Float => DType::F32,
        DataType::Double => DType::F64,
        // No candle equivalent is wired up for the remaining ONNX types.
        _ => return None,
    };
    Some(mapped)
}
// This function provides a direct evaluation of the proto. // This function provides a direct evaluation of the proto.
// Longer-term, we should first convert the proto to an intermediate representation of the compute // Longer-term, we should first convert the proto to an intermediate representation of the compute
// graph so as to make multiple evaluations more efficient. // graph so as to make multiple evaluations more efficient.
@ -26,6 +39,26 @@ pub fn simple_eval(
Some(value) => Ok(value), Some(value) => Ok(value),
None => bail!("cannot find {input_name} for op {}", node.name), None => bail!("cannot find {input_name} for op {}", node.name),
}; };
// Fetches the attribute called `name` from the current node and returns its
// integer payload. Fails when the attribute is absent or when its declared
// type is anything other than `Int`.
let get_attr_i = |name: &str| {
    let found = node.attribute.iter().find(|attr| attr.name == name);
    let attr = match found {
        Some(attr) => attr,
        None => bail!(
            "cannot find the '{name}' attribute in '{}' for {}",
            node.op_type,
            node.name
        ),
    };
    match attr.r#type() {
        AttributeType::Int => Ok(attr.i),
        rtype => bail!(
            "unsupported type {rtype:?} for '{name}' attribute in '{}' for {}",
            node.op_type,
            node.name
        ),
    }
};
// TODO: Validate node.input for each operator. // TODO: Validate node.input for each operator.
match node.op_type.as_str() { match node.op_type.as_str() {
"Add" => { "Add" => {
@ -52,12 +85,114 @@ pub fn simple_eval(
let output = input0.broadcast_div(input1)?; let output = input0.broadcast_div(input1)?;
values.insert(node.output[0].clone(), output); values.insert(node.output[0].clone(), output);
} }
"Equal" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
let output = input0.eq(input1)?;
values.insert(node.output[0].clone(), output);
}
"MatMul" => { "MatMul" => {
let input0 = get(&node.input[0])?; let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?; let input1 = get(&node.input[1])?;
let output = input0.broadcast_matmul(input1)?; let output = input0.broadcast_matmul(input1)?;
values.insert(node.output[0].clone(), output); values.insert(node.output[0].clone(), output);
} }
"Reshape" => {
    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Reshape
    let input0 = get(&node.input[0])?;
    let input1 = get(&node.input[1])?.to_vec1::<i64>()?;
    // A single -1 entry stands for "whatever size makes the element count
    // match". Its value is the input element count divided by the product
    // of all explicitly given dimensions, not the full element count.
    let mut known = 1usize;
    let mut num_inferred = 0usize;
    for &v in input1.iter() {
        match v {
            -1 => num_inferred += 1,
            v if v < 0 => bail!(
                "unsupported negative dimension {v} in reshape {input1:?} for {}",
                node.name
            ),
            v => known *= v as usize,
        }
    }
    if num_inferred > 1 {
        bail!(
            "at most one -1 dimension is allowed in reshape {input1:?} for {}",
            node.name
        )
    }
    // NOTE(review): a 0 entry is passed through literally; the ONNX
    // "copy the dimension from the input" meaning of 0 is not implemented
    // here. With a 0 present the -1 entry cannot be inferred.
    if num_inferred == 1 && known == 0 {
        bail!(
            "cannot infer the -1 dimension in reshape {input1:?} for {}",
            node.name
        )
    }
    let input1 = input1
        .iter()
        .map(|&v| {
            if v == -1 {
                input0.elem_count() / known
            } else {
                v as usize
            }
        })
        .collect::<Vec<usize>>();
    // A non-divisible element count is caught by `reshape` itself.
    let output = input0.reshape(input1)?;
    values.insert(node.output[0].clone(), output);
}
"Softmax" => {
    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Softmax
    let input = get(&node.input[0])?;
    let output = match get_attr_i("axis") {
        // No usable integer `axis` attribute: apply softmax over the last dimension.
        Err(_) => candle_nn::ops::softmax_last_dim(input)?,
        Ok(axis) => {
            let num_axis = input.rank() as i64;
            // Accepted axes lie in [-rank, rank). A negative axis counts
            // from the end, so it is normalised by adding the rank.
            let axis = if (0..num_axis).contains(&axis) {
                axis as usize
            } else if (-num_axis..0).contains(&axis) {
                (num_axis + axis) as usize
            } else {
                bail!("wrong axis in softmax {axis} for shape {:?}", input.shape())
            };
            candle_nn::ops::softmax(input, axis)?
        }
    };
    values.insert(node.output[0].clone(), output);
}
"Concat" => {
    // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Concat
    // Every input of the node is looked up and concatenated along `axis`.
    let inputs = node
        .input
        .iter()
        .map(|n| Ok(get(n.as_str())?.clone()))
        .collect::<Result<Vec<Value>>>()?;
    let axis = get_attr_i("axis")?;
    // The rank of the first input is used to normalise the axis.
    let num_axis = match inputs.first() {
        Some(first) => first.rank() as i64,
        None => bail!("empty concat"),
    };
    // Accepted axes lie in [-rank, rank). A negative axis counts from the
    // end, so it is normalised by adding the rank.
    let axis = if (0..num_axis).contains(&axis) {
        axis as usize
    } else if (-num_axis..0).contains(&axis) {
        (num_axis + axis) as usize
    } else {
        bail!(
            "wrong axis in concat {axis} for shape {:?}",
            inputs[0].shape()
        )
    };
    let output = Tensor::cat(&inputs, axis)?;
    values.insert(node.output[0].clone(), output);
}
"Abs" => {
let input = get(&node.input[0])?;
let output = input.abs()?;
values.insert(node.output[0].clone(), output);
}
"Cos" => {
let input = get(&node.input[0])?;
let output = input.cos()?;
values.insert(node.output[0].clone(), output);
}
"Sin" => {
let input = get(&node.input[0])?;
let output = input.sin()?;
values.insert(node.output[0].clone(), output);
}
"Neg" => {
let input = get(&node.input[0])?;
let output = input.neg()?;
values.insert(node.output[0].clone(), output);
}
"Erf" => {
let input = get(&node.input[0])?;
let output = input.erf()?;
values.insert(node.output[0].clone(), output);
}
"Tanh" => {
let input = get(&node.input[0])?;
let output = input.tanh()?;
values.insert(node.output[0].clone(), output);
}
"Sigmoid" => {
let input = get(&node.input[0])?;
let output = candle_nn::ops::sigmoid(input)?;
values.insert(node.output[0].clone(), output);
}
"Gelu" => { "Gelu" => {
let input = get(&node.input[0])?; let input = get(&node.input[0])?;
let output = input.gelu_erf()?; let output = input.gelu_erf()?;
@ -79,49 +214,20 @@ pub fn simple_eval(
}; };
let output = match value.r#type() { let output = match value.r#type() {
AttributeType::Tensor => { AttributeType::Tensor => {
use crate::onnx::tensor_proto::DataType;
let t = value.t.as_ref().unwrap(); let t = value.t.as_ref().unwrap();
let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect(); let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect();
match DataType::try_from(t.data_type) { match DataType::try_from(t.data_type) {
Ok(DataType::Uint8) => Tensor::from_raw_buffer( Ok(dt) => match dtype(dt) {
t.raw_data.as_slice(), Some(dt) => Tensor::from_raw_buffer(
DType::U8, t.raw_data.as_slice(),
dims.as_slice(), dt,
&Device::Cpu, dims.as_slice(),
)?, &Device::Cpu,
Ok(DataType::Uint32) => Tensor::from_raw_buffer( )?,
t.raw_data.as_slice(), None => {
DType::U32, bail!("unsupported 'value' data-type {dt:?} for {}", node.name)
dims.as_slice(), }
&Device::Cpu, },
)?,
Ok(DataType::Int64) => Tensor::from_raw_buffer(
t.raw_data.as_slice(),
DType::I64,
dims.as_slice(),
&Device::Cpu,
)?,
Ok(DataType::Float16) => Tensor::from_raw_buffer(
t.raw_data.as_slice(),
DType::F16,
dims.as_slice(),
&Device::Cpu,
)?,
Ok(DataType::Float) => Tensor::from_raw_buffer(
t.raw_data.as_slice(),
DType::F32,
dims.as_slice(),
&Device::Cpu,
)?,
Ok(DataType::Double) => Tensor::from_raw_buffer(
t.raw_data.as_slice(),
DType::F64,
dims.as_slice(),
&Device::Cpu,
)?,
Ok(dt) => {
bail!("unsupported 'value' data-type {dt:?} for {}", node.name)
}
Err(_) => { Err(_) => {
bail!( bail!(
"unsupported 'value' data-type {} for {}", "unsupported 'value' data-type {} for {}",
@ -138,15 +244,17 @@ pub fn simple_eval(
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast
"Cast" => { "Cast" => {
let input = get(&node.input[0])?; let input = get(&node.input[0])?;
let dtype = match node.attribute.iter().find(|attr| attr.name == "to") { let dt = get_attr_i("to")?;
None => { let dtype = match DataType::try_from(dt as i32) {
bail!("cannot find the 'to' attribute in 'Cast' for {}", node.name) Ok(dt) => match dtype(dt) {
} Some(dt) => dt,
Some(dtype) => match dtype.r#type() { None => {
AttributeType::Floats => candle::DType::F32, bail!("unsupported 'to' value {dt:?} for cast {}", node.name)
AttributeType::Int => candle::DType::I64, }
rtype => bail!("unsupported 'to' type {rtype:?} for {}", node.name),
}, },
Err(_) => {
bail!("unsupported 'to' value {dt:?} for cast {}", node.name)
}
}; };
let output = input.to_dtype(dtype)?; let output = input.to_dtype(dtype)?;
values.insert(node.output[0].clone(), output); values.insert(node.output[0].clone(), output);

View File

@ -6,7 +6,7 @@ pub mod onnx {
} }
mod eval; mod eval;
pub use eval::simple_eval; pub use eval::{dtype, simple_eval};
pub fn read_file<P: AsRef<std::path::Path>>(p: P) -> Result<onnx::ModelProto> { pub fn read_file<P: AsRef<std::path::Path>>(p: P) -> Result<onnx::ModelProto> {
let buf = std::fs::read(p)?; let buf = std::fs::read(p)?;