mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Merge branch 'refs/heads/exp' into operators-random-exp
This commit is contained in:
@ -327,6 +327,11 @@ pub fn simple_eval(
|
||||
let output = input0.broadcast_pow(input1)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Exp" => {
|
||||
let xs = get(&node.input[0])?;
|
||||
let output = xs.exp()?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Equal" => {
|
||||
let input0 = get(&node.input[0])?;
|
||||
let input1 = get(&node.input[1])?;
|
||||
|
@ -231,6 +231,56 @@ fn test_div_operation() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// "Exp"
|
||||
#[test]
|
||||
fn test_exp_operation() -> Result<()> {
|
||||
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||
node: vec![NodeProto {
|
||||
op_type: "Exp".to_string(),
|
||||
domain: "".to_string(),
|
||||
attribute: vec![],
|
||||
input: vec![INPUT_X.to_string()],
|
||||
output: vec![OUTPUT_Z.to_string()],
|
||||
name: "".to_string(),
|
||||
doc_string: "".to_string(),
|
||||
}],
|
||||
name: "".to_string(),
|
||||
initializer: vec![],
|
||||
input: vec![],
|
||||
output: vec![ValueInfoProto {
|
||||
name: OUTPUT_Z.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
}],
|
||||
value_info: vec![],
|
||||
doc_string: "".to_string(),
|
||||
sparse_initializer: vec![],
|
||||
quantization_annotation: vec![],
|
||||
}));
|
||||
|
||||
let x = Tensor::from_vec(
|
||||
vec![-1.0f32, 0.0f32, 1.0f32, 2.0f32],
|
||||
&[2, 2],
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
|
||||
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||
inputs.insert(INPUT_X.to_string(), x);
|
||||
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
|
||||
let results = z.to_vec2::<f32>()?;
|
||||
|
||||
assert_eq!(results[0][0], 0.36787944f32);
|
||||
assert_eq!(results[0][1], 1.0f32);
|
||||
assert_eq!(results[1], vec![std::f32::consts::E, 7.38905609f32]);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// "Equal"
|
||||
#[test]
|
||||
fn test_equal_operation() -> Result<()> {
|
||||
|
Reference in New Issue
Block a user