Improve the ONNX basic example + bugfixes (#1266)

* Generate some zeros tensor in the onnx simple-eval example.

* Fix the casting operation.

* Support more ops.

* Handle reshape.

* Concat.

* Softmax.
This commit is contained in:
Laurent Mazare
2023-11-04 10:02:47 +01:00
committed by GitHub
parent f7c957d64f
commit bc9a1bf239
3 changed files with 190 additions and 52 deletions

View File

@ -41,9 +41,39 @@ pub fn main() -> Result<()> {
.unwrap()
.input
.iter()
.map(|name| {
let value = Tensor::new(&[-3.2, 2.7], &Device::Cpu)?;
Ok((name.name.clone(), value))
.map(|input| {
use candle_onnx::onnx::tensor_proto::DataType;
let type_ = input.r#type.as_ref().expect("no type for input");
let type_ = type_.value.as_ref().expect("no type.value for input");
let value = match type_ {
candle_onnx::onnx::type_proto::Value::TensorType(tt) => {
let dt = match DataType::try_from(tt.elem_type) {
Ok(dt) => match candle_onnx::dtype(dt) {
Some(dt) => dt,
None => {
anyhow::bail!(
"unsupported 'value' data-type {dt:?} for {}",
input.name
)
}
},
type_ => anyhow::bail!("unsupported input type {type_:?}"),
};
let shape = tt.shape.as_ref().expect("no tensortype.shape for input");
let dims = shape
.dim
.iter()
.map(|dim| match dim.value.as_ref().expect("no dim value") {
candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimValue(v) => Ok(*v as usize),
candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimParam(_) => anyhow::bail!("DimParam is unsupported for input {}", input.name),
})
.collect::<Result<Vec<usize>>>()?;
Tensor::zeros(dims, dt, &Device::Cpu)?
}
type_ => anyhow::bail!("unsupported input type {type_:?}"),
};
Ok::<_, anyhow::Error>((input.name.clone(), value))
})
.collect::<Result<_>>()?;
let outputs = candle_onnx::simple_eval(&model, inputs)?;