mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Improve the ONNX basic example + bugfixes (#1266)
* Generate some zeros tensor in the onnx simple-eval example. * Fix the casting operation. * Support more ops. * Handle reshape. * Concat. * Softmax.
This commit is contained in:
@ -41,9 +41,39 @@ pub fn main() -> Result<()> {
|
||||
.unwrap()
|
||||
.input
|
||||
.iter()
|
||||
.map(|name| {
|
||||
let value = Tensor::new(&[-3.2, 2.7], &Device::Cpu)?;
|
||||
Ok((name.name.clone(), value))
|
||||
.map(|input| {
|
||||
use candle_onnx::onnx::tensor_proto::DataType;
|
||||
|
||||
let type_ = input.r#type.as_ref().expect("no type for input");
|
||||
let type_ = type_.value.as_ref().expect("no type.value for input");
|
||||
let value = match type_ {
|
||||
candle_onnx::onnx::type_proto::Value::TensorType(tt) => {
|
||||
let dt = match DataType::try_from(tt.elem_type) {
|
||||
Ok(dt) => match candle_onnx::dtype(dt) {
|
||||
Some(dt) => dt,
|
||||
None => {
|
||||
anyhow::bail!(
|
||||
"unsupported 'value' data-type {dt:?} for {}",
|
||||
input.name
|
||||
)
|
||||
}
|
||||
},
|
||||
type_ => anyhow::bail!("unsupported input type {type_:?}"),
|
||||
};
|
||||
let shape = tt.shape.as_ref().expect("no tensortype.shape for input");
|
||||
let dims = shape
|
||||
.dim
|
||||
.iter()
|
||||
.map(|dim| match dim.value.as_ref().expect("no dim value") {
|
||||
candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimValue(v) => Ok(*v as usize),
|
||||
candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimParam(_) => anyhow::bail!("DimParam is unsupported for input {}", input.name),
|
||||
})
|
||||
.collect::<Result<Vec<usize>>>()?;
|
||||
Tensor::zeros(dims, dt, &Device::Cpu)?
|
||||
}
|
||||
type_ => anyhow::bail!("unsupported input type {type_:?}"),
|
||||
};
|
||||
Ok::<_, anyhow::Error>((input.name.clone(), value))
|
||||
})
|
||||
.collect::<Result<_>>()?;
|
||||
let outputs = candle_onnx::simple_eval(&model, inputs)?;
|
||||
|
@ -1,9 +1,22 @@
|
||||
use crate::onnx;
|
||||
use crate::onnx::tensor_proto::DataType;
|
||||
use candle::{bail, DType, Device, Result, Tensor};
|
||||
use std::collections::HashMap;
|
||||
|
||||
pub type Value = Tensor;
|
||||
|
||||
pub fn dtype(dt: DataType) -> Option<DType> {
|
||||
match dt {
|
||||
DataType::Uint8 => Some(DType::U8),
|
||||
DataType::Uint32 => Some(DType::U32),
|
||||
DataType::Int64 => Some(DType::I64),
|
||||
DataType::Float16 => Some(DType::F16),
|
||||
DataType::Float => Some(DType::F32),
|
||||
DataType::Double => Some(DType::F64),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
// This function provides a direct evaluation of the proto.
|
||||
// Longer-term, we should first convert the proto to an intermediate representation of the compute
|
||||
// graph so as to make multiple evaluations more efficient.
|
||||
@ -26,6 +39,26 @@ pub fn simple_eval(
|
||||
Some(value) => Ok(value),
|
||||
None => bail!("cannot find {input_name} for op {}", node.name),
|
||||
};
|
||||
let get_attr_i = |name: &str| match node.attribute.iter().find(|attr| attr.name == name) {
|
||||
None => {
|
||||
bail!(
|
||||
"cannot find the '{name}' attribute in '{}' for {}",
|
||||
node.op_type,
|
||||
node.name
|
||||
)
|
||||
}
|
||||
Some(dt) => {
|
||||
match dt.r#type() {
|
||||
AttributeType::Int => (),
|
||||
rtype => bail!(
|
||||
"unsupported type {rtype:?} for '{name}' attribute in '{}' for {}",
|
||||
node.op_type,
|
||||
node.name
|
||||
),
|
||||
}
|
||||
Ok(dt.i)
|
||||
}
|
||||
};
|
||||
// TODO: Validate node.input for each operator.
|
||||
match node.op_type.as_str() {
|
||||
"Add" => {
|
||||
@ -52,12 +85,114 @@ pub fn simple_eval(
|
||||
let output = input0.broadcast_div(input1)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Equal" => {
|
||||
let input0 = get(&node.input[0])?;
|
||||
let input1 = get(&node.input[1])?;
|
||||
let output = input0.eq(input1)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"MatMul" => {
|
||||
let input0 = get(&node.input[0])?;
|
||||
let input1 = get(&node.input[1])?;
|
||||
let output = input0.broadcast_matmul(input1)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Reshape" => {
|
||||
let input0 = get(&node.input[0])?;
|
||||
let input1 = get(&node.input[1])?.to_vec1::<i64>()?;
|
||||
// TODO: Check that there is at most a single -1, handle other neg values.
|
||||
let input1 = input1
|
||||
.iter()
|
||||
.map(|&v| {
|
||||
if v == -1 {
|
||||
input0.elem_count()
|
||||
} else {
|
||||
v as usize
|
||||
}
|
||||
})
|
||||
.collect::<Vec<usize>>();
|
||||
let output = input0.reshape(input1)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Softmax" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let output = match get_attr_i("axis") {
|
||||
Err(_) => candle_nn::ops::softmax_last_dim(input)?,
|
||||
Ok(axis) => {
|
||||
let num_axis = input.rank() as i64;
|
||||
let axis = if axis >= 0 {
|
||||
axis as usize
|
||||
} else if axis < -num_axis {
|
||||
bail!("wrong axis in concat {axis} for shape {:?}", input.shape())
|
||||
} else {
|
||||
(num_axis - axis) as usize
|
||||
};
|
||||
candle_nn::ops::softmax(input, axis)?
|
||||
}
|
||||
};
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Concat" => {
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Concat
|
||||
let inputs = node
|
||||
.input
|
||||
.iter()
|
||||
.map(|n| Ok(get(n.as_str())?.clone()))
|
||||
.collect::<Result<Vec<Value>>>()?;
|
||||
let axis = get_attr_i("axis")?;
|
||||
let num_axis = if inputs.is_empty() {
|
||||
bail!("empty concat")
|
||||
} else {
|
||||
inputs[0].rank() as i64
|
||||
};
|
||||
let axis = if axis >= 0 {
|
||||
axis as usize
|
||||
} else if axis < -num_axis {
|
||||
bail!(
|
||||
"wrong axis in concat {axis} for shape {:?}",
|
||||
inputs[0].shape()
|
||||
)
|
||||
} else {
|
||||
(num_axis - axis) as usize
|
||||
};
|
||||
let output = Tensor::cat(&inputs, axis)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Abs" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let output = input.abs()?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Cos" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let output = input.cos()?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Sin" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let output = input.sin()?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Neg" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let output = input.neg()?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Erf" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let output = input.erf()?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Tanh" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let output = input.tanh()?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Sigmoid" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let output = candle_nn::ops::sigmoid(input)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Gelu" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let output = input.gelu_erf()?;
|
||||
@ -79,49 +214,20 @@ pub fn simple_eval(
|
||||
};
|
||||
let output = match value.r#type() {
|
||||
AttributeType::Tensor => {
|
||||
use crate::onnx::tensor_proto::DataType;
|
||||
let t = value.t.as_ref().unwrap();
|
||||
let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect();
|
||||
match DataType::try_from(t.data_type) {
|
||||
Ok(DataType::Uint8) => Tensor::from_raw_buffer(
|
||||
t.raw_data.as_slice(),
|
||||
DType::U8,
|
||||
dims.as_slice(),
|
||||
&Device::Cpu,
|
||||
)?,
|
||||
Ok(DataType::Uint32) => Tensor::from_raw_buffer(
|
||||
t.raw_data.as_slice(),
|
||||
DType::U32,
|
||||
dims.as_slice(),
|
||||
&Device::Cpu,
|
||||
)?,
|
||||
Ok(DataType::Int64) => Tensor::from_raw_buffer(
|
||||
t.raw_data.as_slice(),
|
||||
DType::I64,
|
||||
dims.as_slice(),
|
||||
&Device::Cpu,
|
||||
)?,
|
||||
Ok(DataType::Float16) => Tensor::from_raw_buffer(
|
||||
t.raw_data.as_slice(),
|
||||
DType::F16,
|
||||
dims.as_slice(),
|
||||
&Device::Cpu,
|
||||
)?,
|
||||
Ok(DataType::Float) => Tensor::from_raw_buffer(
|
||||
t.raw_data.as_slice(),
|
||||
DType::F32,
|
||||
dims.as_slice(),
|
||||
&Device::Cpu,
|
||||
)?,
|
||||
Ok(DataType::Double) => Tensor::from_raw_buffer(
|
||||
t.raw_data.as_slice(),
|
||||
DType::F64,
|
||||
dims.as_slice(),
|
||||
&Device::Cpu,
|
||||
)?,
|
||||
Ok(dt) => {
|
||||
bail!("unsupported 'value' data-type {dt:?} for {}", node.name)
|
||||
}
|
||||
Ok(dt) => match dtype(dt) {
|
||||
Some(dt) => Tensor::from_raw_buffer(
|
||||
t.raw_data.as_slice(),
|
||||
dt,
|
||||
dims.as_slice(),
|
||||
&Device::Cpu,
|
||||
)?,
|
||||
None => {
|
||||
bail!("unsupported 'value' data-type {dt:?} for {}", node.name)
|
||||
}
|
||||
},
|
||||
Err(_) => {
|
||||
bail!(
|
||||
"unsupported 'value' data-type {} for {}",
|
||||
@ -138,15 +244,17 @@ pub fn simple_eval(
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast
|
||||
"Cast" => {
|
||||
let input = get(&node.input[0])?;
|
||||
let dtype = match node.attribute.iter().find(|attr| attr.name == "to") {
|
||||
None => {
|
||||
bail!("cannot find the 'to' attribute in 'Cast' for {}", node.name)
|
||||
}
|
||||
Some(dtype) => match dtype.r#type() {
|
||||
AttributeType::Floats => candle::DType::F32,
|
||||
AttributeType::Int => candle::DType::I64,
|
||||
rtype => bail!("unsupported 'to' type {rtype:?} for {}", node.name),
|
||||
let dt = get_attr_i("to")?;
|
||||
let dtype = match DataType::try_from(dt as i32) {
|
||||
Ok(dt) => match dtype(dt) {
|
||||
Some(dt) => dt,
|
||||
None => {
|
||||
bail!("unsupported 'to' value {dt:?} for cast {}", node.name)
|
||||
}
|
||||
},
|
||||
Err(_) => {
|
||||
bail!("unsupported 'to' value {dt:?} for cast {}", node.name)
|
||||
}
|
||||
};
|
||||
let output = input.to_dtype(dtype)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
|
@ -6,7 +6,7 @@ pub mod onnx {
|
||||
}
|
||||
|
||||
mod eval;
|
||||
pub use eval::simple_eval;
|
||||
pub use eval::{dtype, simple_eval};
|
||||
|
||||
pub fn read_file<P: AsRef<std::path::Path>>(p: P) -> Result<onnx::ModelProto> {
|
||||
let buf = std::fs::read(p)?;
|
||||
|
Reference in New Issue
Block a user