Merge branch 'refs/heads/exp' into operators-random-exp

This commit is contained in:
B1rtek
2024-04-23 17:33:05 +02:00
2 changed files with 55 additions and 0 deletions

View File

@ -327,6 +327,11 @@ pub fn simple_eval(
let output = input0.broadcast_pow(input1)?;
values.insert(node.output[0].clone(), output);
}
"Exp" => {
let xs = get(&node.input[0])?;
let output = xs.exp()?;
values.insert(node.output[0].clone(), output);
}
"Equal" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;

View File

@ -231,6 +231,56 @@ fn test_div_operation() -> Result<()> {
Ok(())
}
// "Exp"
#[test]
fn test_exp_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Exp".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
vec![-1.0f32, 0.0f32, 1.0f32, 2.0f32],
&[2, 2],
&Device::Cpu,
)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<f32>()?;
assert_eq!(results[0][0], 0.36787944f32);
assert_eq!(results[0][1], 1.0f32);
assert_eq!(results[1], vec![std::f32::consts::E, 7.38905609f32]);
Ok(())
}
// "Equal"
#[test]
fn test_equal_operation() -> Result<()> {