mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
* Support for Sign operation #2641 * Apply rustfmt. --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -1944,6 +1944,12 @@ fn simple_eval_(
|
|||||||
|
|
||||||
values.insert(node.output[0].clone(), out);
|
values.insert(node.output[0].clone(), out);
|
||||||
}
|
}
|
||||||
|
// https://onnx.ai/onnx/operators/onnx__Sign.html
|
||||||
|
"Sign" => {
|
||||||
|
let input = get(&node.input[0])?;
|
||||||
|
let output = input.sign()?;
|
||||||
|
values.insert(node.output[0].clone(), output);
|
||||||
|
}
|
||||||
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -5869,3 +5869,44 @@ fn test_xor() -> Result<()> {
|
|||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_sign_operation() -> Result<()> {
|
||||||
|
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: "Sign".to_string(),
|
||||||
|
domain: "".to_string(),
|
||||||
|
attribute: vec![],
|
||||||
|
input: vec![INPUT_X.to_string()],
|
||||||
|
output: vec![OUTPUT_Z.to_string()],
|
||||||
|
name: "".to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
}],
|
||||||
|
name: "".to_string(),
|
||||||
|
initializer: vec![],
|
||||||
|
input: vec![],
|
||||||
|
output: vec![ValueInfoProto {
|
||||||
|
name: OUTPUT_Z.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
}],
|
||||||
|
value_info: vec![],
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
sparse_initializer: vec![],
|
||||||
|
quantization_annotation: vec![],
|
||||||
|
}));
|
||||||
|
|
||||||
|
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||||
|
inputs.insert(
|
||||||
|
INPUT_X.to_string(),
|
||||||
|
Tensor::new(vec![-2f32, -1., 0., 1., 2.], &Device::Cpu)?,
|
||||||
|
);
|
||||||
|
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||||
|
|
||||||
|
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||||
|
assert_eq!(
|
||||||
|
z.to_dtype(candle::DType::I64)?.to_vec1::<i64>()?.to_vec(),
|
||||||
|
vec![-1, -1, 0, 1, 1]
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user