mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
Add missing onnx operations (#2096)
* Add missing onnx operations * Add tests and fix errors * Run rustfmt
This commit is contained in:
@ -23,6 +23,11 @@ trait Attr {
|
||||
fn get(attr: &onnx::AttributeProto) -> Result<&Self>;
|
||||
}
|
||||
|
||||
trait AttrOwned: Sized {
|
||||
const TYPE: AttributeType;
|
||||
fn get(attr: &onnx::AttributeProto) -> Result<Self>;
|
||||
}
|
||||
|
||||
impl Attr for i64 {
|
||||
const TYPE: AttributeType = AttributeType::Int;
|
||||
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
|
||||
@ -51,6 +56,50 @@ impl Attr for str {
|
||||
}
|
||||
}
|
||||
|
||||
impl AttrOwned for Tensor {
|
||||
const TYPE: AttributeType = AttributeType::Tensor;
|
||||
fn get(attr: &onnx::AttributeProto) -> Result<Self> {
|
||||
let tensor_proto = match &attr.t {
|
||||
Some(value) => value,
|
||||
None => bail!(
|
||||
"attribute {} was of type TENSOR, but no tensor was found",
|
||||
attr.name
|
||||
),
|
||||
};
|
||||
|
||||
let data_type = match DataType::try_from(tensor_proto.data_type) {
|
||||
Ok(value) => value,
|
||||
Err(_) => bail!(
|
||||
"attribute {} of type TENSOR was an invalid data_type number {}",
|
||||
attr.name,
|
||||
tensor_proto.data_type
|
||||
),
|
||||
};
|
||||
|
||||
let dtype = match dtype(data_type) {
|
||||
Some(value) => value,
|
||||
None => bail!(
|
||||
"attribute {} of type TENSOR has an unsupported data_type {}",
|
||||
attr.name,
|
||||
data_type.as_str_name()
|
||||
),
|
||||
};
|
||||
|
||||
let mut dims = Vec::with_capacity(tensor_proto.dims.len());
|
||||
for dim in &tensor_proto.dims {
|
||||
if dim < &0 {
|
||||
bail!(
|
||||
"attribute {} of type TENSOR has a negative dimension, which is unsupported",
|
||||
attr.name
|
||||
)
|
||||
}
|
||||
dims.push(*dim as usize)
|
||||
}
|
||||
|
||||
Tensor::from_raw_buffer(&tensor_proto.raw_data, dtype, &dims, &Device::Cpu)
|
||||
}
|
||||
}
|
||||
|
||||
fn get_attr_<'a>(node: &'a onnx::NodeProto, name: &str) -> Result<&'a onnx::AttributeProto> {
|
||||
match node.attribute.iter().find(|attr| attr.name == name) {
|
||||
None => {
|
||||
@ -98,6 +147,24 @@ fn get_attr_opt<'a, T: Attr + ?Sized>(
|
||||
}
|
||||
}
|
||||
|
||||
fn get_attr_opt_owned<T: AttrOwned>(node: &onnx::NodeProto, name: &str) -> Result<Option<T>> {
|
||||
match node.attribute.iter().find(|attr| attr.name == name) {
|
||||
None => Ok(None),
|
||||
Some(attr) => {
|
||||
if attr.r#type() != T::TYPE {
|
||||
bail!(
|
||||
"unsupported type {:?} for '{name}' attribute in '{}' for {}",
|
||||
attr.r#type,
|
||||
node.op_type,
|
||||
node.name
|
||||
)
|
||||
}
|
||||
let val = T::get(attr)?;
|
||||
Ok(Some(val))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
|
||||
let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect();
|
||||
match DataType::try_from(t.data_type) {
|
||||
@ -458,14 +525,17 @@ pub fn simple_eval(
|
||||
}
|
||||
values.insert(node.output[0].clone(), xs);
|
||||
}
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#ConstantOfShape
|
||||
"ConstantOfShape" => {
|
||||
let dims = get(&node.input[0])?;
|
||||
let shape = dims
|
||||
.to_vec1::<i64>()?
|
||||
.into_iter()
|
||||
.map(|v| v as usize)
|
||||
.collect::<Vec<_>>();
|
||||
let xs = Tensor::zeros(shape, DType::F32, dims.device())?;
|
||||
let input = get(&node.input[0])?;
|
||||
let value = get_attr_opt_owned::<Tensor>(node, "value")?.unwrap_or(Tensor::zeros(
|
||||
(),
|
||||
DType::F32,
|
||||
&Device::Cpu,
|
||||
)?);
|
||||
|
||||
let xs = Tensor::ones(input.shape(), value.dtype(), input.device())?
|
||||
.broadcast_mul(&value)?;
|
||||
values.insert(node.output[0].clone(), xs);
|
||||
}
|
||||
"Unsqueeze" => {
|
||||
@ -552,6 +622,82 @@ pub fn simple_eval(
|
||||
let dims = Tensor::from_vec(dims, xs.rank(), xs.device())?;
|
||||
values.insert(node.output[0].clone(), dims);
|
||||
}
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sqrt
|
||||
"Sqrt" => {
|
||||
let xs = get(&node.input[0])?;
|
||||
let output = xs.sqrt()?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Range
|
||||
"Range" => {
|
||||
let start = get(&node.input[0])?;
|
||||
let limit = get(&node.input[1])?;
|
||||
let delta = get(&node.input[2])?;
|
||||
|
||||
macro_rules! arange_step {
|
||||
($t: ty) => {
|
||||
Tensor::arange_step(
|
||||
start.to_vec0::<$t>()?,
|
||||
limit.to_vec0::<$t>()?,
|
||||
delta.to_vec0::<$t>()?,
|
||||
&Device::Cpu,
|
||||
)?
|
||||
};
|
||||
}
|
||||
|
||||
let output = match start.dtype() {
|
||||
DType::U8 => arange_step!(u8),
|
||||
DType::U32 => arange_step!(u32),
|
||||
DType::I64 => arange_step!(i64),
|
||||
DType::BF16 => arange_step!(f32),
|
||||
DType::F16 => arange_step!(f32),
|
||||
DType::F32 => arange_step!(f32),
|
||||
DType::F64 => arange_step!(f64),
|
||||
};
|
||||
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Greater
|
||||
"Greater" => {
|
||||
let a = get(&node.input[0])?;
|
||||
let b = get(&node.input[1])?;
|
||||
|
||||
let output = a.broadcast_gt(b)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Less
|
||||
"Less" => {
|
||||
let a = get(&node.input[0])?;
|
||||
let b = get(&node.input[1])?;
|
||||
|
||||
let output = a.broadcast_lt(b)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Log
|
||||
"Log" => {
|
||||
let a = get(&node.input[0])?;
|
||||
|
||||
let output = a.log()?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Min
|
||||
"Min" => {
|
||||
let mut output = get(&node.input[0])?.clone();
|
||||
for input in node.input.iter() {
|
||||
let input = get(input)?;
|
||||
output = output.broadcast_minimum(input)?
|
||||
}
|
||||
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Where
|
||||
"Where" => {
|
||||
let cond = get(&node.input[0])?;
|
||||
let a = get(&node.input[1])?;
|
||||
let b = get(&node.input[2])?;
|
||||
let output = cond.where_cond(a, b)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Conv" => {
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv
|
||||
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
|
||||
|
Reference in New Issue
Block a user