mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
ONNX: Add Floor and Ceil (#2235)
This commit is contained in:
@ -877,6 +877,16 @@ pub fn simple_eval(
|
|||||||
let output = input.relu()?;
|
let output = input.relu()?;
|
||||||
values.insert(node.output[0].clone(), output);
|
values.insert(node.output[0].clone(), output);
|
||||||
}
|
}
|
||||||
|
"Ceil" => {
|
||||||
|
let input = get(&node.input[0])?;
|
||||||
|
let output = input.ceil()?;
|
||||||
|
values.insert(node.output[0].clone(), output);
|
||||||
|
}
|
||||||
|
"Floor" => {
|
||||||
|
let input = get(&node.input[0])?;
|
||||||
|
let output = input.floor()?;
|
||||||
|
values.insert(node.output[0].clone(), output);
|
||||||
|
}
|
||||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Constant
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Constant
|
||||||
"Constant" => {
|
"Constant" => {
|
||||||
let value = match node.attribute.iter().find(|attr| attr.name == "value") {
|
let value = match node.attribute.iter().find(|attr| attr.name == "value") {
|
||||||
|
@ -2556,3 +2556,155 @@ fn test_where() -> Result<()> {
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_floor() -> Result<()> {
|
||||||
|
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: "Floor".to_string(),
|
||||||
|
domain: "".to_string(),
|
||||||
|
attribute: vec![],
|
||||||
|
input: vec![INPUT_X.to_string()],
|
||||||
|
output: vec![OUTPUT_Z.to_string()],
|
||||||
|
name: "".to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
}],
|
||||||
|
name: "".to_string(),
|
||||||
|
initializer: vec![],
|
||||||
|
input: vec![ValueInfoProto {
|
||||||
|
name: INPUT_X.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
}],
|
||||||
|
output: vec![ValueInfoProto {
|
||||||
|
name: OUTPUT_Z.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
}],
|
||||||
|
value_info: vec![],
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
sparse_initializer: vec![],
|
||||||
|
quantization_annotation: vec![],
|
||||||
|
}));
|
||||||
|
let x = Tensor::from_vec(
|
||||||
|
// some values taken from https://numpy.org/doc/stable/reference/generated/numpy.floor.html
|
||||||
|
vec![
|
||||||
|
f64::NAN,
|
||||||
|
f64::INFINITY,
|
||||||
|
f64::NEG_INFINITY,
|
||||||
|
-1.7,
|
||||||
|
-1.5,
|
||||||
|
-0.2,
|
||||||
|
0.2,
|
||||||
|
1.5,
|
||||||
|
1.7,
|
||||||
|
2.0,
|
||||||
|
],
|
||||||
|
&[10],
|
||||||
|
&Device::Cpu,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||||
|
inputs.insert(INPUT_X.to_string(), x);
|
||||||
|
|
||||||
|
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||||
|
assert_eq!(eval.len(), 1);
|
||||||
|
|
||||||
|
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||||
|
|
||||||
|
let results = z.to_vec1::<f64>()?;
|
||||||
|
|
||||||
|
assert!(results[0].is_nan());
|
||||||
|
assert_eq!(
|
||||||
|
results[1..],
|
||||||
|
vec![
|
||||||
|
f64::INFINITY,
|
||||||
|
f64::NEG_INFINITY,
|
||||||
|
-2.,
|
||||||
|
-2.,
|
||||||
|
-1.,
|
||||||
|
0.,
|
||||||
|
1.,
|
||||||
|
1.,
|
||||||
|
2.
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_ceil() -> Result<()> {
|
||||||
|
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: "Ceil".to_string(),
|
||||||
|
domain: "".to_string(),
|
||||||
|
attribute: vec![],
|
||||||
|
input: vec![INPUT_X.to_string()],
|
||||||
|
output: vec![OUTPUT_Z.to_string()],
|
||||||
|
name: "".to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
}],
|
||||||
|
name: "".to_string(),
|
||||||
|
initializer: vec![],
|
||||||
|
input: vec![ValueInfoProto {
|
||||||
|
name: INPUT_X.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
}],
|
||||||
|
output: vec![ValueInfoProto {
|
||||||
|
name: OUTPUT_Z.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
}],
|
||||||
|
value_info: vec![],
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
sparse_initializer: vec![],
|
||||||
|
quantization_annotation: vec![],
|
||||||
|
}));
|
||||||
|
let x = Tensor::from_vec(
|
||||||
|
// some values taken from https://numpy.org/doc/stable/reference/generated/numpy.ceil.html
|
||||||
|
vec![
|
||||||
|
f64::NAN,
|
||||||
|
f64::INFINITY,
|
||||||
|
f64::NEG_INFINITY,
|
||||||
|
-1.7,
|
||||||
|
-1.5,
|
||||||
|
-0.2,
|
||||||
|
0.2,
|
||||||
|
1.5,
|
||||||
|
1.7,
|
||||||
|
2.0,
|
||||||
|
],
|
||||||
|
&[10],
|
||||||
|
&Device::Cpu,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||||
|
inputs.insert(INPUT_X.to_string(), x);
|
||||||
|
|
||||||
|
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||||
|
assert_eq!(eval.len(), 1);
|
||||||
|
|
||||||
|
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||||
|
|
||||||
|
let results = z.to_vec1::<f64>()?;
|
||||||
|
|
||||||
|
assert!(results[0].is_nan());
|
||||||
|
assert_eq!(
|
||||||
|
results[1..],
|
||||||
|
vec![
|
||||||
|
f64::INFINITY,
|
||||||
|
f64::NEG_INFINITY,
|
||||||
|
-1.,
|
||||||
|
-1.,
|
||||||
|
-0.,
|
||||||
|
1.,
|
||||||
|
2.,
|
||||||
|
2.,
|
||||||
|
2.
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user