diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index 75927822..e3f4d09d 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -260,6 +260,11 @@ pub fn simple_eval( let output = input0.broadcast_pow(input1)?; values.insert(node.output[0].clone(), output); } + "Exp" => { + let xs = get(&node.input[0])?; + let output = xs.exp()?; + values.insert(node.output[0].clone(), output); + } "Equal" => { let input0 = get(&node.input[0])?; let input1 = get(&node.input[1])?; diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs index fda76ec2..2711d335 100644 --- a/candle-onnx/tests/ops.rs +++ b/candle-onnx/tests/ops.rs @@ -227,6 +227,56 @@ fn test_div_operation() -> Result<()> { Ok(()) } +// "Exp" +#[test] +fn test_exp_operation() -> Result<()> { + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "Exp".to_string(), + domain: "".to_string(), + attribute: vec![], + input: vec![INPUT_X.to_string()], + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + name: "".to_string(), + initializer: vec![], + input: vec![], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![], + doc_string: "".to_string(), + sparse_initializer: vec![], + quantization_annotation: vec![], + })); + + let x = Tensor::from_vec( + vec![-1.0f32, 0.0f32, 1.0f32, 2.0f32], + &[2, 2], + &Device::Cpu, + )?; + + let mut inputs: HashMap = HashMap::new(); + inputs.insert(INPUT_X.to_string(), x); + + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); + + let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); + + let results = z.to_vec2::()?; + + assert_eq!(results[0][0], 0.36787944f32); + assert_eq!(results[0][1], 1.0f32); + assert_eq!(results[1], vec![std::f32::consts::E, 7.38905609f32]); + + Ok(()) +} + // "Equal" #[test] fn test_equal_operation() -> Result<()> {