Improve the ONNX basic example + bugfixes (#1266)
* Generate some zeros tensor in the onnx simple-eval example.
* Fix the casting operation.
* Support more ops.
* Handle reshape.
* Concat.
* Softmax.
@@ -41,9 +41,39 @@ pub fn main() -> Result<()> {
         .unwrap()
         .input
         .iter()
-        .map(|name| {
-            let value = Tensor::new(&[-3.2, 2.7], &Device::Cpu)?;
-            Ok((name.name.clone(), value))
+        .map(|input| {
+            use candle_onnx::onnx::tensor_proto::DataType;
+            let type_ = input.r#type.as_ref().expect("no type for input");
+            let type_ = type_.value.as_ref().expect("no type.value for input");
+            let value = match type_ {
+                candle_onnx::onnx::type_proto::Value::TensorType(tt) => {
+                    let dt = match DataType::try_from(tt.elem_type) {
+                        Ok(dt) => match candle_onnx::dtype(dt) {
+                            Some(dt) => dt,
+                            None => {
+                                anyhow::bail!(
+                                    "unsupported 'value' data-type {dt:?} for {}",
+                                    input.name
+                                )
+                            }
+                        },
+                        type_ => anyhow::bail!("unsupported input type {type_:?}"),
+                    };
+                    let shape = tt.shape.as_ref().expect("no tensortype.shape for input");
+                    let dims = shape
+                        .dim
+                        .iter()
+                        .map(|dim| match dim.value.as_ref().expect("no dim value") {
+                            candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimValue(v) => Ok(*v as usize),
+                            candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimParam(_) => anyhow::bail!("DimParam is unsupported for input {}", input.name),
+                        })
+                        .collect::<Result<Vec<usize>>>()?;
+                    Tensor::zeros(dims, dt, &Device::Cpu)?
+                }
+                type_ => anyhow::bail!("unsupported input type {type_:?}"),
+            };
+            Ok::<_, anyhow::Error>((input.name.clone(), value))
         })
         .collect::<Result<_>>()?;
     let outputs = candle_onnx::simple_eval(&model, inputs)?;
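To make the effect of this hunk concrete: for a graph input declared in the proto as float32 with static shape [1, 3], the rewritten closure now feeds an all-zeros tensor of the declared shape and dtype instead of the previous hard-coded two-element tensor. A standalone sketch (the shape and dtype are illustrative assumptions, not taken from a real model):

    use candle::{DType, Device, Tensor};

    fn main() -> candle::Result<()> {
        // What the updated .map() closure would produce for a [1, 3] f32 graph input.
        let value = Tensor::zeros((1, 3), DType::F32, &Device::Cpu)?;
        println!("{value}");
        Ok(())
    }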
@@ -1,9 +1,22 @@
 use crate::onnx;
+use crate::onnx::tensor_proto::DataType;
 use candle::{bail, DType, Device, Result, Tensor};
 use std::collections::HashMap;
 
 pub type Value = Tensor;
 
+pub fn dtype(dt: DataType) -> Option<DType> {
+    match dt {
+        DataType::Uint8 => Some(DType::U8),
+        DataType::Uint32 => Some(DType::U32),
+        DataType::Int64 => Some(DType::I64),
+        DataType::Float16 => Some(DType::F16),
+        DataType::Float => Some(DType::F32),
+        DataType::Double => Some(DType::F64),
+        _ => None,
+    }
+}
+
 // This function provides a direct evaluation of the proto.
 // Longer-term, we should first convert the proto to an intermediate representation of the compute
 // graph so as to make multiple evaluations more efficient.
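The new `dtype` helper centralizes the mapping from ONNX element types to candle dtypes. A minimal sketch of how it composes with the prost-generated `DataType` enum (the `elem_dtype` wrapper below is hypothetical, not part of the crate):

    use candle_onnx::onnx::tensor_proto::DataType;

    // Hypothetical wrapper: resolve a raw ONNX elem_type integer to a candle DType,
    // returning None for enum values the helper does not map.
    fn elem_dtype(elem_type: i32) -> Option<candle::DType> {
        match DataType::try_from(elem_type) {
            Ok(dt) => candle_onnx::dtype(dt),
            Err(_) => None,
        }
    }

    fn main() {
        assert_eq!(elem_dtype(1), Some(candle::DType::F32)); // 1 == FLOAT in the ONNX enum
        assert_eq!(elem_dtype(8), None); // 8 == STRING, not mapped by the helper
    }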
@@ -26,6 +39,26 @@ pub fn simple_eval(
         Some(value) => Ok(value),
         None => bail!("cannot find {input_name} for op {}", node.name),
     };
+    let get_attr_i = |name: &str| match node.attribute.iter().find(|attr| attr.name == name) {
+        None => {
+            bail!(
+                "cannot find the '{name}' attribute in '{}' for {}",
+                node.op_type,
+                node.name
+            )
+        }
+        Some(dt) => {
+            match dt.r#type() {
+                AttributeType::Int => (),
+                rtype => bail!(
+                    "unsupported type {rtype:?} for '{name}' attribute in '{}' for {}",
+                    node.op_type,
+                    node.name
+                ),
+            }
+            Ok(dt.i)
+        }
+    };
     // TODO: Validate node.input for each operator.
     match node.op_type.as_str() {
         "Add" => {
@@ -52,12 +85,114 @@ pub fn simple_eval(
             let output = input0.broadcast_div(input1)?;
             values.insert(node.output[0].clone(), output);
         }
+        "Equal" => {
+            let input0 = get(&node.input[0])?;
+            let input1 = get(&node.input[1])?;
+            let output = input0.eq(input1)?;
+            values.insert(node.output[0].clone(), output);
+        }
         "MatMul" => {
             let input0 = get(&node.input[0])?;
             let input1 = get(&node.input[1])?;
             let output = input0.broadcast_matmul(input1)?;
             values.insert(node.output[0].clone(), output);
         }
+        "Reshape" => {
+            let input0 = get(&node.input[0])?;
+            let input1 = get(&node.input[1])?.to_vec1::<i64>()?;
+            // TODO: Check that there is at most a single -1, handle other neg values.
+            let input1 = input1
+                .iter()
+                .map(|&v| {
+                    if v == -1 {
+                        input0.elem_count()
+                    } else {
+                        v as usize
+                    }
+                })
+                .collect::<Vec<usize>>();
+            let output = input0.reshape(input1)?;
+            values.insert(node.output[0].clone(), output);
+        }
+        "Softmax" => {
+            let input = get(&node.input[0])?;
+            let output = match get_attr_i("axis") {
+                Err(_) => candle_nn::ops::softmax_last_dim(input)?,
+                Ok(axis) => {
+                    let num_axis = input.rank() as i64;
+                    let axis = if axis >= 0 {
+                        axis as usize
+                    } else if axis < -num_axis {
+                        bail!("wrong axis in concat {axis} for shape {:?}", input.shape())
+                    } else {
+                        (num_axis - axis) as usize
+                    };
+                    candle_nn::ops::softmax(input, axis)?
+                }
+            };
+            values.insert(node.output[0].clone(), output);
+        }
+        "Concat" => {
+            // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Concat
+            let inputs = node
+                .input
+                .iter()
+                .map(|n| Ok(get(n.as_str())?.clone()))
+                .collect::<Result<Vec<Value>>>()?;
+            let axis = get_attr_i("axis")?;
+            let num_axis = if inputs.is_empty() {
+                bail!("empty concat")
+            } else {
+                inputs[0].rank() as i64
+            };
+            let axis = if axis >= 0 {
+                axis as usize
+            } else if axis < -num_axis {
+                bail!(
+                    "wrong axis in concat {axis} for shape {:?}",
+                    inputs[0].shape()
+                )
+            } else {
+                (num_axis - axis) as usize
+            };
+            let output = Tensor::cat(&inputs, axis)?;
+            values.insert(node.output[0].clone(), output);
+        }
+        "Abs" => {
+            let input = get(&node.input[0])?;
+            let output = input.abs()?;
+            values.insert(node.output[0].clone(), output);
+        }
+        "Cos" => {
+            let input = get(&node.input[0])?;
+            let output = input.cos()?;
+            values.insert(node.output[0].clone(), output);
+        }
+        "Sin" => {
+            let input = get(&node.input[0])?;
+            let output = input.sin()?;
+            values.insert(node.output[0].clone(), output);
+        }
+        "Neg" => {
+            let input = get(&node.input[0])?;
+            let output = input.neg()?;
+            values.insert(node.output[0].clone(), output);
+        }
+        "Erf" => {
+            let input = get(&node.input[0])?;
+            let output = input.erf()?;
+            values.insert(node.output[0].clone(), output);
+        }
+        "Tanh" => {
+            let input = get(&node.input[0])?;
+            let output = input.tanh()?;
+            values.insert(node.output[0].clone(), output);
+        }
+        "Sigmoid" => {
+            let input = get(&node.input[0])?;
+            let output = candle_nn::ops::sigmoid(input)?;
+            values.insert(node.output[0].clone(), output);
+        }
         "Gelu" => {
             let input = get(&node.input[0])?;
             let output = input.gelu_erf()?;
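For reference, the ONNX spec resolves a negative `axis` attribute (used by Softmax and Concat above) by counting dimensions from the back, i.e. an axis in [-rank, -1] maps to rank + axis. A standalone sketch of that convention, with a hypothetical helper name:

    // Hypothetical helper illustrating the ONNX axis convention: a negative axis
    // counts back from the last dimension, so the resolved index is rank + axis.
    fn normalize_axis(axis: i64, rank: usize) -> Option<usize> {
        let rank = rank as i64;
        if (0..rank).contains(&axis) {
            Some(axis as usize)
        } else if (-rank..0).contains(&axis) {
            Some((rank + axis) as usize)
        } else {
            None // outside the [-rank, rank - 1] range allowed by the spec
        }
    }

    fn main() {
        assert_eq!(normalize_axis(1, 3), Some(1));
        assert_eq!(normalize_axis(-1, 3), Some(2)); // last dimension
        assert_eq!(normalize_axis(3, 3), None);
    }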
@@ -79,49 +214,20 @@ pub fn simple_eval(
             };
             let output = match value.r#type() {
                 AttributeType::Tensor => {
-                    use crate::onnx::tensor_proto::DataType;
                     let t = value.t.as_ref().unwrap();
                     let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect();
                     match DataType::try_from(t.data_type) {
-                        Ok(DataType::Uint8) => Tensor::from_raw_buffer(
-                            t.raw_data.as_slice(),
-                            DType::U8,
-                            dims.as_slice(),
-                            &Device::Cpu,
-                        )?,
-                        Ok(DataType::Uint32) => Tensor::from_raw_buffer(
-                            t.raw_data.as_slice(),
-                            DType::U32,
-                            dims.as_slice(),
-                            &Device::Cpu,
-                        )?,
-                        Ok(DataType::Int64) => Tensor::from_raw_buffer(
-                            t.raw_data.as_slice(),
-                            DType::I64,
-                            dims.as_slice(),
-                            &Device::Cpu,
-                        )?,
-                        Ok(DataType::Float16) => Tensor::from_raw_buffer(
-                            t.raw_data.as_slice(),
-                            DType::F16,
-                            dims.as_slice(),
-                            &Device::Cpu,
-                        )?,
-                        Ok(DataType::Float) => Tensor::from_raw_buffer(
-                            t.raw_data.as_slice(),
-                            DType::F32,
-                            dims.as_slice(),
-                            &Device::Cpu,
-                        )?,
-                        Ok(DataType::Double) => Tensor::from_raw_buffer(
-                            t.raw_data.as_slice(),
-                            DType::F64,
-                            dims.as_slice(),
-                            &Device::Cpu,
-                        )?,
-                        Ok(dt) => {
-                            bail!("unsupported 'value' data-type {dt:?} for {}", node.name)
-                        }
+                        Ok(dt) => match dtype(dt) {
+                            Some(dt) => Tensor::from_raw_buffer(
+                                t.raw_data.as_slice(),
+                                dt,
+                                dims.as_slice(),
+                                &Device::Cpu,
+                            )?,
+                            None => {
+                                bail!("unsupported 'value' data-type {dt:?} for {}", node.name)
+                            }
+                        },
                         Err(_) => {
                             bail!(
                                 "unsupported 'value' data-type {} for {}",
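The consolidated Constant branch above routes every supported element type through the same `Tensor::from_raw_buffer` call. A small self-contained sketch of that call, with a hand-built little-endian byte buffer standing in for a TensorProto's raw_data:

    use candle::{DType, Device, Tensor};

    fn main() -> candle::Result<()> {
        // Two little-endian f32 values, as they would appear in raw_data.
        let mut raw = Vec::new();
        raw.extend_from_slice(&1.0f32.to_le_bytes());
        raw.extend_from_slice(&2.0f32.to_le_bytes());
        let t = Tensor::from_raw_buffer(&raw, DType::F32, &[2], &Device::Cpu)?;
        println!("{t}");
        Ok(())
    }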
@@ -138,15 +244,17 @@ pub fn simple_eval(
         // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast
         "Cast" => {
             let input = get(&node.input[0])?;
-            let dtype = match node.attribute.iter().find(|attr| attr.name == "to") {
-                None => {
-                    bail!("cannot find the 'to' attribute in 'Cast' for {}", node.name)
-                }
-                Some(dtype) => match dtype.r#type() {
-                    AttributeType::Floats => candle::DType::F32,
-                    AttributeType::Int => candle::DType::I64,
-                    rtype => bail!("unsupported 'to' type {rtype:?} for {}", node.name),
+            let dt = get_attr_i("to")?;
+            let dtype = match DataType::try_from(dt as i32) {
+                Ok(dt) => match dtype(dt) {
+                    Some(dt) => dt,
+                    None => {
+                        bail!("unsupported 'to' value {dt:?} for cast {}", node.name)
+                    }
                 },
+                Err(_) => {
+                    bail!("unsupported 'to' value {dt:?} for cast {}", node.name)
+                }
             };
             let output = input.to_dtype(dtype)?;
             values.insert(node.output[0].clone(), output);
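With the 'to' attribute now read as a plain integer via `get_attr_i` and resolved through `DataType::try_from` plus `dtype`, the mapping can be sketched in isolation (the expected values assume the standard ONNX enum numbering, e.g. 1 = FLOAT, 7 = INT64):

    use candle_onnx::onnx::tensor_proto::DataType;

    fn main() {
        // 'to' as it arrives from the integer attribute of a Cast node.
        for (to, expected) in [(1i64, candle::DType::F32), (7, candle::DType::I64)] {
            let resolved = DataType::try_from(to as i32)
                .ok()
                .and_then(candle_onnx::dtype)
                .expect("unsupported 'to' value");
            assert_eq!(resolved, expected);
        }
    }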
@@ -6,7 +6,7 @@ pub mod onnx {
 }
 
 mod eval;
-pub use eval::simple_eval;
+pub use eval::{dtype, simple_eval};
 
 pub fn read_file<P: AsRef<std::path::Path>>(p: P) -> Result<onnx::ModelProto> {
     let buf = std::fs::read(p)?;
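Putting the re-exported pieces together, a rough end-to-end sketch of the public surface this commit touches. The model path, input name, and input shape are placeholders, and it assumes `simple_eval` returns the map of named output tensors, as its use in the example suggests:

    use candle::{DType, Device, Tensor};
    use std::collections::HashMap;

    fn main() -> anyhow::Result<()> {
        let model = candle_onnx::read_file("model.onnx")?;
        let mut inputs = HashMap::new();
        // Feed an all-zeros tensor, mirroring what the updated example now does automatically.
        inputs.insert("input".to_string(), Tensor::zeros((1, 3), DType::F32, &Device::Cpu)?);
        let outputs = candle_onnx::simple_eval(&model, inputs)?;
        for (name, value) in outputs.iter() {
            println!("{name}: {value:?}");
        }
        Ok(())
    }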