mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Add missing onnx operations (#2096)
* Add missing onnx operations * Add tests and fix errors * Run rustfmt
This commit is contained in:
@ -23,6 +23,11 @@ trait Attr {
|
|||||||
fn get(attr: &onnx::AttributeProto) -> Result<&Self>;
|
fn get(attr: &onnx::AttributeProto) -> Result<&Self>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
trait AttrOwned: Sized {
|
||||||
|
const TYPE: AttributeType;
|
||||||
|
fn get(attr: &onnx::AttributeProto) -> Result<Self>;
|
||||||
|
}
|
||||||
|
|
||||||
impl Attr for i64 {
|
impl Attr for i64 {
|
||||||
const TYPE: AttributeType = AttributeType::Int;
|
const TYPE: AttributeType = AttributeType::Int;
|
||||||
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
|
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
|
||||||
@ -51,6 +56,50 @@ impl Attr for str {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl AttrOwned for Tensor {
|
||||||
|
const TYPE: AttributeType = AttributeType::Tensor;
|
||||||
|
fn get(attr: &onnx::AttributeProto) -> Result<Self> {
|
||||||
|
let tensor_proto = match &attr.t {
|
||||||
|
Some(value) => value,
|
||||||
|
None => bail!(
|
||||||
|
"attribute {} was of type TENSOR, but no tensor was found",
|
||||||
|
attr.name
|
||||||
|
),
|
||||||
|
};
|
||||||
|
|
||||||
|
let data_type = match DataType::try_from(tensor_proto.data_type) {
|
||||||
|
Ok(value) => value,
|
||||||
|
Err(_) => bail!(
|
||||||
|
"attribute {} of type TENSOR was an invalid data_type number {}",
|
||||||
|
attr.name,
|
||||||
|
tensor_proto.data_type
|
||||||
|
),
|
||||||
|
};
|
||||||
|
|
||||||
|
let dtype = match dtype(data_type) {
|
||||||
|
Some(value) => value,
|
||||||
|
None => bail!(
|
||||||
|
"attribute {} of type TENSOR has an unsupported data_type {}",
|
||||||
|
attr.name,
|
||||||
|
data_type.as_str_name()
|
||||||
|
),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut dims = Vec::with_capacity(tensor_proto.dims.len());
|
||||||
|
for dim in &tensor_proto.dims {
|
||||||
|
if dim < &0 {
|
||||||
|
bail!(
|
||||||
|
"attribute {} of type TENSOR has a negative dimension, which is unsupported",
|
||||||
|
attr.name
|
||||||
|
)
|
||||||
|
}
|
||||||
|
dims.push(*dim as usize)
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor::from_raw_buffer(&tensor_proto.raw_data, dtype, &dims, &Device::Cpu)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn get_attr_<'a>(node: &'a onnx::NodeProto, name: &str) -> Result<&'a onnx::AttributeProto> {
|
fn get_attr_<'a>(node: &'a onnx::NodeProto, name: &str) -> Result<&'a onnx::AttributeProto> {
|
||||||
match node.attribute.iter().find(|attr| attr.name == name) {
|
match node.attribute.iter().find(|attr| attr.name == name) {
|
||||||
None => {
|
None => {
|
||||||
@ -98,6 +147,24 @@ fn get_attr_opt<'a, T: Attr + ?Sized>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_attr_opt_owned<T: AttrOwned>(node: &onnx::NodeProto, name: &str) -> Result<Option<T>> {
|
||||||
|
match node.attribute.iter().find(|attr| attr.name == name) {
|
||||||
|
None => Ok(None),
|
||||||
|
Some(attr) => {
|
||||||
|
if attr.r#type() != T::TYPE {
|
||||||
|
bail!(
|
||||||
|
"unsupported type {:?} for '{name}' attribute in '{}' for {}",
|
||||||
|
attr.r#type,
|
||||||
|
node.op_type,
|
||||||
|
node.name
|
||||||
|
)
|
||||||
|
}
|
||||||
|
let val = T::get(attr)?;
|
||||||
|
Ok(Some(val))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
|
pub fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
|
||||||
let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect();
|
let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect();
|
||||||
match DataType::try_from(t.data_type) {
|
match DataType::try_from(t.data_type) {
|
||||||
@ -458,14 +525,17 @@ pub fn simple_eval(
|
|||||||
}
|
}
|
||||||
values.insert(node.output[0].clone(), xs);
|
values.insert(node.output[0].clone(), xs);
|
||||||
}
|
}
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#ConstantOfShape
|
||||||
"ConstantOfShape" => {
|
"ConstantOfShape" => {
|
||||||
let dims = get(&node.input[0])?;
|
let input = get(&node.input[0])?;
|
||||||
let shape = dims
|
let value = get_attr_opt_owned::<Tensor>(node, "value")?.unwrap_or(Tensor::zeros(
|
||||||
.to_vec1::<i64>()?
|
(),
|
||||||
.into_iter()
|
DType::F32,
|
||||||
.map(|v| v as usize)
|
&Device::Cpu,
|
||||||
.collect::<Vec<_>>();
|
)?);
|
||||||
let xs = Tensor::zeros(shape, DType::F32, dims.device())?;
|
|
||||||
|
let xs = Tensor::ones(input.shape(), value.dtype(), input.device())?
|
||||||
|
.broadcast_mul(&value)?;
|
||||||
values.insert(node.output[0].clone(), xs);
|
values.insert(node.output[0].clone(), xs);
|
||||||
}
|
}
|
||||||
"Unsqueeze" => {
|
"Unsqueeze" => {
|
||||||
@ -552,6 +622,82 @@ pub fn simple_eval(
|
|||||||
let dims = Tensor::from_vec(dims, xs.rank(), xs.device())?;
|
let dims = Tensor::from_vec(dims, xs.rank(), xs.device())?;
|
||||||
values.insert(node.output[0].clone(), dims);
|
values.insert(node.output[0].clone(), dims);
|
||||||
}
|
}
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sqrt
|
||||||
|
"Sqrt" => {
|
||||||
|
let xs = get(&node.input[0])?;
|
||||||
|
let output = xs.sqrt()?;
|
||||||
|
values.insert(node.output[0].clone(), output);
|
||||||
|
}
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Range
|
||||||
|
"Range" => {
|
||||||
|
let start = get(&node.input[0])?;
|
||||||
|
let limit = get(&node.input[1])?;
|
||||||
|
let delta = get(&node.input[2])?;
|
||||||
|
|
||||||
|
macro_rules! arange_step {
|
||||||
|
($t: ty) => {
|
||||||
|
Tensor::arange_step(
|
||||||
|
start.to_vec0::<$t>()?,
|
||||||
|
limit.to_vec0::<$t>()?,
|
||||||
|
delta.to_vec0::<$t>()?,
|
||||||
|
&Device::Cpu,
|
||||||
|
)?
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
let output = match start.dtype() {
|
||||||
|
DType::U8 => arange_step!(u8),
|
||||||
|
DType::U32 => arange_step!(u32),
|
||||||
|
DType::I64 => arange_step!(i64),
|
||||||
|
DType::BF16 => arange_step!(f32),
|
||||||
|
DType::F16 => arange_step!(f32),
|
||||||
|
DType::F32 => arange_step!(f32),
|
||||||
|
DType::F64 => arange_step!(f64),
|
||||||
|
};
|
||||||
|
|
||||||
|
values.insert(node.output[0].clone(), output);
|
||||||
|
}
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Greater
|
||||||
|
"Greater" => {
|
||||||
|
let a = get(&node.input[0])?;
|
||||||
|
let b = get(&node.input[1])?;
|
||||||
|
|
||||||
|
let output = a.broadcast_gt(b)?;
|
||||||
|
values.insert(node.output[0].clone(), output);
|
||||||
|
}
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Less
|
||||||
|
"Less" => {
|
||||||
|
let a = get(&node.input[0])?;
|
||||||
|
let b = get(&node.input[1])?;
|
||||||
|
|
||||||
|
let output = a.broadcast_lt(b)?;
|
||||||
|
values.insert(node.output[0].clone(), output);
|
||||||
|
}
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Log
|
||||||
|
"Log" => {
|
||||||
|
let a = get(&node.input[0])?;
|
||||||
|
|
||||||
|
let output = a.log()?;
|
||||||
|
values.insert(node.output[0].clone(), output);
|
||||||
|
}
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Min
|
||||||
|
"Min" => {
|
||||||
|
let mut output = get(&node.input[0])?.clone();
|
||||||
|
for input in node.input.iter() {
|
||||||
|
let input = get(input)?;
|
||||||
|
output = output.broadcast_minimum(input)?
|
||||||
|
}
|
||||||
|
|
||||||
|
values.insert(node.output[0].clone(), output);
|
||||||
|
}
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Where
|
||||||
|
"Where" => {
|
||||||
|
let cond = get(&node.input[0])?;
|
||||||
|
let a = get(&node.input[1])?;
|
||||||
|
let b = get(&node.input[2])?;
|
||||||
|
let output = cond.where_cond(a, b)?;
|
||||||
|
values.insert(node.output[0].clone(), output);
|
||||||
|
}
|
||||||
"Conv" => {
|
"Conv" => {
|
||||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv
|
||||||
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
|
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
|
||||||
|
@ -4,12 +4,16 @@ extern crate intel_mkl_src;
|
|||||||
#[cfg(feature = "accelerate")]
|
#[cfg(feature = "accelerate")]
|
||||||
extern crate accelerate_src;
|
extern crate accelerate_src;
|
||||||
|
|
||||||
use candle::{Device, NdArray, Result, Tensor};
|
use candle::{DType, Device, NdArray, Result, Tensor};
|
||||||
|
use candle_onnx::onnx;
|
||||||
|
use candle_onnx::onnx::attribute_proto::AttributeType;
|
||||||
|
use candle_onnx::onnx::tensor_proto::DataType;
|
||||||
use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto};
|
use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
const INPUT_X: &str = "x";
|
const INPUT_X: &str = "x";
|
||||||
const INPUT_Y: &str = "y";
|
const INPUT_Y: &str = "y";
|
||||||
|
const INPUT_A: &str = "a";
|
||||||
const OUTPUT_Z: &str = "z";
|
const OUTPUT_Z: &str = "z";
|
||||||
|
|
||||||
fn create_model_proto_with_graph(graph: Option<GraphProto>) -> ModelProto {
|
fn create_model_proto_with_graph(graph: Option<GraphProto>) -> ModelProto {
|
||||||
@ -820,7 +824,137 @@ fn test_flatten_operation() -> Result<()> {
|
|||||||
// #[test]
|
// #[test]
|
||||||
|
|
||||||
// "ConstantOfShape"
|
// "ConstantOfShape"
|
||||||
// #[test]
|
#[test]
|
||||||
|
fn test_constant_of_shape() -> Result<()> {
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-31
|
||||||
|
test(&[4i64, 3, 2], Some(1.), &[1., 1., 1.])?;
|
||||||
|
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-31
|
||||||
|
test(&[0.], Some(0i64), &[0i64])?;
|
||||||
|
|
||||||
|
// "value" defaults to 0 f32
|
||||||
|
test(&[1i64, 2, 3, 4], None as Option<i64>, &[0., 0., 0., 0.])?;
|
||||||
|
|
||||||
|
fn test(
|
||||||
|
input: impl NdArray,
|
||||||
|
value: Option<impl NdArray>,
|
||||||
|
expected: impl NdArray,
|
||||||
|
) -> Result<()> {
|
||||||
|
let mut attribute = vec![];
|
||||||
|
|
||||||
|
if let Some(value) = value {
|
||||||
|
let tensor = Tensor::new(value, &Device::Cpu)?;
|
||||||
|
|
||||||
|
let (value, data_type) = match tensor.dtype() {
|
||||||
|
DType::U8 => (
|
||||||
|
tensor.to_vec0::<u8>()?.to_le_bytes().to_vec(),
|
||||||
|
DataType::Uint8,
|
||||||
|
),
|
||||||
|
DType::U32 => (
|
||||||
|
tensor.to_vec0::<u32>()?.to_le_bytes().to_vec(),
|
||||||
|
DataType::Uint32,
|
||||||
|
),
|
||||||
|
DType::I64 => (
|
||||||
|
tensor.to_vec0::<i64>()?.to_le_bytes().to_vec(),
|
||||||
|
DataType::Int64,
|
||||||
|
),
|
||||||
|
DType::F32 => (
|
||||||
|
tensor.to_vec0::<f32>()?.to_le_bytes().to_vec(),
|
||||||
|
DataType::Float,
|
||||||
|
),
|
||||||
|
DType::F64 => (
|
||||||
|
tensor.to_vec0::<f64>()?.to_le_bytes().to_vec(),
|
||||||
|
DataType::Double,
|
||||||
|
),
|
||||||
|
_ => panic!("unsupported DType in test"),
|
||||||
|
};
|
||||||
|
let tensor = onnx::TensorProto {
|
||||||
|
data_type: data_type.into(),
|
||||||
|
dims: tensor.dims().iter().map(|v| *v as i64).collect(),
|
||||||
|
raw_data: value,
|
||||||
|
segment: None,
|
||||||
|
float_data: vec![],
|
||||||
|
int32_data: vec![],
|
||||||
|
string_data: vec![],
|
||||||
|
int64_data: vec![],
|
||||||
|
name: "".to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
external_data: vec![],
|
||||||
|
data_location: 0,
|
||||||
|
double_data: vec![],
|
||||||
|
uint64_data: vec![],
|
||||||
|
};
|
||||||
|
|
||||||
|
attribute.push(AttributeProto {
|
||||||
|
name: "value".to_string(),
|
||||||
|
ref_attr_name: "value".to_string(),
|
||||||
|
i: 0,
|
||||||
|
doc_string: "value".to_string(),
|
||||||
|
r#type: AttributeType::Tensor.into(),
|
||||||
|
f: 0.0,
|
||||||
|
s: vec![],
|
||||||
|
t: Some(tensor),
|
||||||
|
g: None,
|
||||||
|
sparse_tensor: None,
|
||||||
|
tp: None,
|
||||||
|
floats: vec![],
|
||||||
|
ints: vec![],
|
||||||
|
strings: vec![],
|
||||||
|
tensors: vec![],
|
||||||
|
graphs: vec![],
|
||||||
|
sparse_tensors: vec![],
|
||||||
|
type_protos: vec![],
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: "ConstantOfShape".to_string(),
|
||||||
|
domain: "".to_string(),
|
||||||
|
attribute,
|
||||||
|
input: vec![INPUT_X.to_string()],
|
||||||
|
output: vec![OUTPUT_Z.to_string()],
|
||||||
|
name: "".to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
}],
|
||||||
|
name: "".to_string(),
|
||||||
|
initializer: vec![],
|
||||||
|
input: vec![],
|
||||||
|
output: vec![ValueInfoProto {
|
||||||
|
name: OUTPUT_Z.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
}],
|
||||||
|
value_info: vec![],
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
sparse_initializer: vec![],
|
||||||
|
quantization_annotation: vec![],
|
||||||
|
}));
|
||||||
|
|
||||||
|
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||||
|
inputs.insert(INPUT_X.to_string(), Tensor::new(input, &Device::Cpu)?);
|
||||||
|
|
||||||
|
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||||
|
assert_eq!(eval.len(), 1);
|
||||||
|
|
||||||
|
let z = eval
|
||||||
|
.get(OUTPUT_Z)
|
||||||
|
.expect("Output 'z' not found")
|
||||||
|
.to_dtype(DType::F64)?;
|
||||||
|
|
||||||
|
let expected = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?;
|
||||||
|
match expected.dims().len() {
|
||||||
|
0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
|
||||||
|
1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
|
||||||
|
2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
|
||||||
|
3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
// "Unsqueeze"
|
// "Unsqueeze"
|
||||||
// #[test]
|
// #[test]
|
||||||
@ -1639,3 +1773,450 @@ fn test_reduce_mean() -> Result<()> {
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// "Sqrt"
|
||||||
|
#[test]
|
||||||
|
fn test_sqrt() -> Result<()> {
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-155
|
||||||
|
test(&[1., 4., 9.], &[1., 2., 3.])?;
|
||||||
|
|
||||||
|
fn test(data: impl NdArray, expected: impl NdArray) -> Result<()> {
|
||||||
|
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: "Sqrt".to_string(),
|
||||||
|
domain: "".to_string(),
|
||||||
|
attribute: vec![],
|
||||||
|
input: vec![INPUT_X.to_string()],
|
||||||
|
output: vec![OUTPUT_Z.to_string()],
|
||||||
|
name: "".to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
}],
|
||||||
|
name: "".to_string(),
|
||||||
|
initializer: vec![],
|
||||||
|
input: vec![],
|
||||||
|
output: vec![ValueInfoProto {
|
||||||
|
name: OUTPUT_Z.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
}],
|
||||||
|
value_info: vec![],
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
sparse_initializer: vec![],
|
||||||
|
quantization_annotation: vec![],
|
||||||
|
}));
|
||||||
|
|
||||||
|
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||||
|
inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
|
||||||
|
|
||||||
|
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||||
|
assert_eq!(eval.len(), 1);
|
||||||
|
|
||||||
|
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||||
|
|
||||||
|
let expected = Tensor::new(expected, &Device::Cpu)?;
|
||||||
|
match expected.dims().len() {
|
||||||
|
0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
|
||||||
|
1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
|
||||||
|
2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
|
||||||
|
3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// "Range"
|
||||||
|
#[test]
|
||||||
|
fn test_range() -> Result<()> {
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-113
|
||||||
|
test(1., 5., 2., &[1., 3.])?;
|
||||||
|
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-113
|
||||||
|
test(10i64, 6i64, -3i64, &[10i64, 7i64])?;
|
||||||
|
|
||||||
|
fn test(
|
||||||
|
start: impl NdArray,
|
||||||
|
limit: impl NdArray,
|
||||||
|
delta: impl NdArray,
|
||||||
|
expected: impl NdArray,
|
||||||
|
) -> Result<()> {
|
||||||
|
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: "Range".to_string(),
|
||||||
|
domain: "".to_string(),
|
||||||
|
attribute: vec![],
|
||||||
|
input: vec![
|
||||||
|
INPUT_X.to_string(),
|
||||||
|
INPUT_Y.to_string(),
|
||||||
|
INPUT_A.to_string(),
|
||||||
|
],
|
||||||
|
output: vec![OUTPUT_Z.to_string()],
|
||||||
|
name: "".to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
}],
|
||||||
|
name: "".to_string(),
|
||||||
|
initializer: vec![],
|
||||||
|
input: vec![],
|
||||||
|
output: vec![ValueInfoProto {
|
||||||
|
name: OUTPUT_Z.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
}],
|
||||||
|
value_info: vec![],
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
sparse_initializer: vec![],
|
||||||
|
quantization_annotation: vec![],
|
||||||
|
}));
|
||||||
|
|
||||||
|
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||||
|
inputs.insert(INPUT_X.to_string(), Tensor::new(start, &Device::Cpu)?);
|
||||||
|
inputs.insert(INPUT_Y.to_string(), Tensor::new(limit, &Device::Cpu)?);
|
||||||
|
inputs.insert(INPUT_A.to_string(), Tensor::new(delta, &Device::Cpu)?);
|
||||||
|
|
||||||
|
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||||
|
assert_eq!(eval.len(), 1);
|
||||||
|
|
||||||
|
let z = eval
|
||||||
|
.get(OUTPUT_Z)
|
||||||
|
.expect("Output 'z' not found")
|
||||||
|
.to_dtype(DType::F64)?;
|
||||||
|
|
||||||
|
let expected = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?;
|
||||||
|
match expected.dims().len() {
|
||||||
|
0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
|
||||||
|
1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
|
||||||
|
2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
|
||||||
|
3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// "Greater"
|
||||||
|
#[test]
|
||||||
|
fn test_greater() -> Result<()> {
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-63
|
||||||
|
test(&[1., 2., 3.], &[3., 2., 1.], &[0u8, 0, 1])?;
|
||||||
|
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-63
|
||||||
|
test(&[1., 2., 3.], 2., &[0u8, 0, 1])?;
|
||||||
|
|
||||||
|
fn test(a: impl NdArray, b: impl NdArray, expected: impl NdArray) -> Result<()> {
|
||||||
|
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: "Greater".to_string(),
|
||||||
|
domain: "".to_string(),
|
||||||
|
attribute: vec![],
|
||||||
|
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
|
||||||
|
output: vec![OUTPUT_Z.to_string()],
|
||||||
|
name: "".to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
}],
|
||||||
|
name: "".to_string(),
|
||||||
|
initializer: vec![],
|
||||||
|
input: vec![],
|
||||||
|
output: vec![ValueInfoProto {
|
||||||
|
name: OUTPUT_Z.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
}],
|
||||||
|
value_info: vec![],
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
sparse_initializer: vec![],
|
||||||
|
quantization_annotation: vec![],
|
||||||
|
}));
|
||||||
|
|
||||||
|
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||||
|
inputs.insert(INPUT_X.to_string(), Tensor::new(a, &Device::Cpu)?);
|
||||||
|
inputs.insert(INPUT_Y.to_string(), Tensor::new(b, &Device::Cpu)?);
|
||||||
|
|
||||||
|
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||||
|
assert_eq!(eval.len(), 1);
|
||||||
|
|
||||||
|
let z = eval
|
||||||
|
.get(OUTPUT_Z)
|
||||||
|
.expect("Output 'z' not found")
|
||||||
|
.to_dtype(DType::F64)?;
|
||||||
|
|
||||||
|
let expected = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?;
|
||||||
|
match expected.dims().len() {
|
||||||
|
0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
|
||||||
|
1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
|
||||||
|
2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
|
||||||
|
3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// "Less"
|
||||||
|
#[test]
|
||||||
|
fn test_less() -> Result<()> {
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-81
|
||||||
|
test(&[1., 2., 3.], &[3., 2., 1.], &[1u8, 0, 0])?;
|
||||||
|
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-81
|
||||||
|
test(&[1., 2., 3.], 2., &[1u8, 0, 0])?;
|
||||||
|
|
||||||
|
fn test(a: impl NdArray, b: impl NdArray, expected: impl NdArray) -> Result<()> {
|
||||||
|
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: "Less".to_string(),
|
||||||
|
domain: "".to_string(),
|
||||||
|
attribute: vec![],
|
||||||
|
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
|
||||||
|
output: vec![OUTPUT_Z.to_string()],
|
||||||
|
name: "".to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
}],
|
||||||
|
name: "".to_string(),
|
||||||
|
initializer: vec![],
|
||||||
|
input: vec![],
|
||||||
|
output: vec![ValueInfoProto {
|
||||||
|
name: OUTPUT_Z.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
}],
|
||||||
|
value_info: vec![],
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
sparse_initializer: vec![],
|
||||||
|
quantization_annotation: vec![],
|
||||||
|
}));
|
||||||
|
|
||||||
|
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||||
|
inputs.insert(INPUT_X.to_string(), Tensor::new(a, &Device::Cpu)?);
|
||||||
|
inputs.insert(INPUT_Y.to_string(), Tensor::new(b, &Device::Cpu)?);
|
||||||
|
|
||||||
|
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||||
|
assert_eq!(eval.len(), 1);
|
||||||
|
|
||||||
|
let z = eval
|
||||||
|
.get(OUTPUT_Z)
|
||||||
|
.expect("Output 'z' not found")
|
||||||
|
.to_dtype(DType::F64)?;
|
||||||
|
|
||||||
|
let expected = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?;
|
||||||
|
match expected.dims().len() {
|
||||||
|
0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
|
||||||
|
1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
|
||||||
|
2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
|
||||||
|
3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// "Log"
|
||||||
|
#[test]
|
||||||
|
fn test_log() -> Result<()> {
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-82
|
||||||
|
test(&[1., 10.], &[0., std::f64::consts::LN_10])?;
|
||||||
|
|
||||||
|
fn test(data: impl NdArray, expected: impl NdArray) -> Result<()> {
|
||||||
|
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: "Log".to_string(),
|
||||||
|
domain: "".to_string(),
|
||||||
|
attribute: vec![],
|
||||||
|
input: vec![INPUT_X.to_string()],
|
||||||
|
output: vec![OUTPUT_Z.to_string()],
|
||||||
|
name: "".to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
}],
|
||||||
|
name: "".to_string(),
|
||||||
|
initializer: vec![],
|
||||||
|
input: vec![],
|
||||||
|
output: vec![ValueInfoProto {
|
||||||
|
name: OUTPUT_Z.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
}],
|
||||||
|
value_info: vec![],
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
sparse_initializer: vec![],
|
||||||
|
quantization_annotation: vec![],
|
||||||
|
}));
|
||||||
|
|
||||||
|
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||||
|
inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
|
||||||
|
|
||||||
|
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||||
|
assert_eq!(eval.len(), 1);
|
||||||
|
|
||||||
|
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||||
|
|
||||||
|
let expected = Tensor::new(expected, &Device::Cpu)?;
|
||||||
|
match expected.dims().len() {
|
||||||
|
0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
|
||||||
|
1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
|
||||||
|
2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
|
||||||
|
3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// "Min"
|
||||||
|
#[test]
|
||||||
|
fn test_min() -> Result<()> {
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-94
|
||||||
|
test(&[3., 2., 1.], &[1., 4., 4.], &[2., 5., 0.], &[1., 2., 0.])?;
|
||||||
|
|
||||||
|
fn test(
|
||||||
|
a: impl NdArray,
|
||||||
|
b: impl NdArray,
|
||||||
|
c: impl NdArray,
|
||||||
|
expected: impl NdArray,
|
||||||
|
) -> Result<()> {
|
||||||
|
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: "Min".to_string(),
|
||||||
|
domain: "".to_string(),
|
||||||
|
attribute: vec![],
|
||||||
|
input: vec![
|
||||||
|
INPUT_X.to_string(),
|
||||||
|
INPUT_Y.to_string(),
|
||||||
|
INPUT_A.to_string(),
|
||||||
|
],
|
||||||
|
output: vec![OUTPUT_Z.to_string()],
|
||||||
|
name: "".to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
}],
|
||||||
|
name: "".to_string(),
|
||||||
|
initializer: vec![],
|
||||||
|
input: vec![],
|
||||||
|
output: vec![ValueInfoProto {
|
||||||
|
name: OUTPUT_Z.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
}],
|
||||||
|
value_info: vec![],
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
sparse_initializer: vec![],
|
||||||
|
quantization_annotation: vec![],
|
||||||
|
}));
|
||||||
|
|
||||||
|
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||||
|
inputs.insert(INPUT_X.to_string(), Tensor::new(a, &Device::Cpu)?);
|
||||||
|
inputs.insert(INPUT_Y.to_string(), Tensor::new(b, &Device::Cpu)?);
|
||||||
|
inputs.insert(INPUT_A.to_string(), Tensor::new(c, &Device::Cpu)?);
|
||||||
|
|
||||||
|
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||||
|
assert_eq!(eval.len(), 1);
|
||||||
|
|
||||||
|
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||||
|
|
||||||
|
let expected = Tensor::new(expected, &Device::Cpu)?;
|
||||||
|
match expected.dims().len() {
|
||||||
|
0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
|
||||||
|
1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
|
||||||
|
2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
|
||||||
|
3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// "Where"
|
||||||
|
#[test]
|
||||||
|
fn test_where() -> Result<()> {
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-173
|
||||||
|
test(
|
||||||
|
&[[1u8, 0], [1, 1]],
|
||||||
|
&[[1i64, 2], [3, 4]],
|
||||||
|
&[[9i64, 8], [7, 6]],
|
||||||
|
&[[1i64, 8], [3, 4]],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-173
|
||||||
|
test(
|
||||||
|
&[[1u8, 0], [1, 1]],
|
||||||
|
&[[1., 2.], [3., 4.]],
|
||||||
|
&[[9., 8.], [7., 6.]],
|
||||||
|
&[[1., 8.], [3., 4.]],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
fn test(
|
||||||
|
condition: impl NdArray,
|
||||||
|
x: impl NdArray,
|
||||||
|
y: impl NdArray,
|
||||||
|
expected: impl NdArray,
|
||||||
|
) -> Result<()> {
|
||||||
|
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: "Where".to_string(),
|
||||||
|
domain: "".to_string(),
|
||||||
|
attribute: vec![],
|
||||||
|
input: vec![
|
||||||
|
INPUT_X.to_string(),
|
||||||
|
INPUT_Y.to_string(),
|
||||||
|
INPUT_A.to_string(),
|
||||||
|
],
|
||||||
|
output: vec![OUTPUT_Z.to_string()],
|
||||||
|
name: "".to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
}],
|
||||||
|
name: "".to_string(),
|
||||||
|
initializer: vec![],
|
||||||
|
input: vec![],
|
||||||
|
output: vec![ValueInfoProto {
|
||||||
|
name: OUTPUT_Z.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
}],
|
||||||
|
value_info: vec![],
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
sparse_initializer: vec![],
|
||||||
|
quantization_annotation: vec![],
|
||||||
|
}));
|
||||||
|
|
||||||
|
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||||
|
inputs.insert(INPUT_X.to_string(), Tensor::new(condition, &Device::Cpu)?);
|
||||||
|
inputs.insert(INPUT_Y.to_string(), Tensor::new(x, &Device::Cpu)?);
|
||||||
|
inputs.insert(INPUT_A.to_string(), Tensor::new(y, &Device::Cpu)?);
|
||||||
|
|
||||||
|
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||||
|
assert_eq!(eval.len(), 1);
|
||||||
|
|
||||||
|
let z = eval
|
||||||
|
.get(OUTPUT_Z)
|
||||||
|
.expect("Output 'z' not found")
|
||||||
|
.to_dtype(DType::F64)?;
|
||||||
|
|
||||||
|
let expected = Tensor::new(expected, &Device::Cpu)?.to_dtype(DType::F64)?;
|
||||||
|
match expected.dims().len() {
|
||||||
|
0 => assert_eq!(z.to_vec0::<f64>()?, expected.to_vec0::<f64>()?),
|
||||||
|
1 => assert_eq!(z.to_vec1::<f64>()?, expected.to_vec1::<f64>()?),
|
||||||
|
2 => assert_eq!(z.to_vec2::<f64>()?, expected.to_vec2::<f64>()?),
|
||||||
|
3 => assert_eq!(z.to_vec3::<f64>()?, expected.to_vec3::<f64>()?),
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user