diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index 417216d7..78e0554a 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -327,6 +327,11 @@ pub fn simple_eval( let output = input0.broadcast_pow(input1)?; values.insert(node.output[0].clone(), output); } + "Exp" => { + let xs = get(&node.input[0])?; + let output = xs.exp()?; + values.insert(node.output[0].clone(), output); + } "Equal" => { let input0 = get(&node.input[0])?; let input1 = get(&node.input[1])?; @@ -966,6 +971,46 @@ pub fn simple_eval( }; values.insert(node.output[0].clone(), output); } + "RandomUniform" => { + let dt: i64 = get_attr_opt(node, "dtype")?.copied().unwrap_or(1); // 1 is float + // type by + // default + let dtype = match DataType::try_from(dt as i32) { + Ok(dt) => match dtype(dt) { + Some(DType::U8 | DType::U32 | DType::I64) => { + bail!( + "unsupported 'dtype' value {dt:?}, only floats are allowed, for RandomUnifrom {}", + node.name + ) + } + Some(dt) => dt, + None => { + bail!( + "unsupported 'dtype' value {dt:?} for RandomUnifrom {}", + node.name + ) + } + }, + Err(_) => { + bail!( + "unsupported 'dtype' value {dt:?} for RandomUniform {}", + node.name + ) + } + }; + let low: f32 = get_attr_opt(node, "low")?.copied().unwrap_or(0.0); + let high: f32 = get_attr_opt(node, "high")?.copied().unwrap_or(1.0); + let seed: Option = get_attr_opt(node, "seed")?.copied(); + if seed.is_some() { + bail!("seed for RandomUniform is currently not supported") + }; + let shape: Vec = get_attr::<[i64]>(node, "shape")? + .iter() + .map(|x| *x as usize) + .collect(); + let output = Tensor::rand(low, high, shape, &Device::Cpu)?.to_dtype(dtype)?; + values.insert(node.output[0].clone(), output); + } op_type => bail!("unsupported op_type {op_type} for op {node:?}"), } } diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs index 9b18170a..294b5511 100644 --- a/candle-onnx/tests/ops.rs +++ b/candle-onnx/tests/ops.rs @@ -231,6 +231,56 @@ fn test_div_operation() -> Result<()> { Ok(()) } +// "Exp" +#[test] +fn test_exp_operation() -> Result<()> { + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "Exp".to_string(), + domain: "".to_string(), + attribute: vec![], + input: vec![INPUT_X.to_string()], + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + name: "".to_string(), + initializer: vec![], + input: vec![], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![], + doc_string: "".to_string(), + sparse_initializer: vec![], + quantization_annotation: vec![], + })); + + let x = Tensor::from_vec( + vec![-1.0f32, 0.0f32, 1.0f32, 2.0f32], + &[2, 2], + &Device::Cpu, + )?; + + let mut inputs: HashMap = HashMap::new(); + inputs.insert(INPUT_X.to_string(), x); + + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); + + let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); + + let results = z.to_vec2::()?; + + assert_eq!(results[0][0], 0.36787944f32); + assert_eq!(results[0][1], 1.0f32); + assert_eq!(results[1], vec![std::f32::consts::E, 7.38905609f32]); + + Ok(()) +} + // "Equal" #[test] fn test_equal_operation() -> Result<()> { @@ -1828,6 +1878,152 @@ fn test_sqrt() -> Result<()> { Ok(()) } +// "RandomUniform" +#[test] +fn test_random_uniform() -> Result<()> { + test(vec![3, 2, 1, 4], None, None)?; + test(vec![2, 2, 2, 2], Some(-10.0), None)?; + test(vec![2, 2, 2, 2], None, Some(10.0))?; + test(vec![1, 2, 3, 4], Some(-10.0), Some(10.0))?; + + fn test(shape: Vec, low: Option, high: Option) -> Result<()> { + let att_low = AttributeProto { + name: "low".to_string(), + ref_attr_name: "low".to_string(), + i: 0, + doc_string: "low".to_string(), + r#type: 1, // FLOAT + f: low.unwrap_or(0.0), + s: vec![], + t: None, + g: None, + sparse_tensor: None, + tp: None, + floats: vec![], + ints: vec![], + strings: vec![], + tensors: vec![], + graphs: vec![], + sparse_tensors: vec![], + type_protos: vec![], + }; + let att_high = AttributeProto { + name: "high".to_string(), + ref_attr_name: "high".to_string(), + i: 0, + doc_string: "high".to_string(), + r#type: 1, // FLOAT + f: high.unwrap_or(1.0), + s: vec![], + t: None, + g: None, + sparse_tensor: None, + tp: None, + floats: vec![], + ints: vec![], + strings: vec![], + tensors: vec![], + graphs: vec![], + sparse_tensors: vec![], + type_protos: vec![], + }; + let att_shape = AttributeProto { + name: "shape".to_string(), + ref_attr_name: "shape".to_string(), + i: 0, + doc_string: "shape".to_string(), + r#type: 7, // INTS + f: 0.0, + s: vec![], + t: None, + g: None, + sparse_tensor: None, + tp: None, + floats: vec![], + ints: shape, + strings: vec![], + tensors: vec![], + graphs: vec![], + sparse_tensors: vec![], + type_protos: vec![], + }; + let att_dtype = AttributeProto { + name: "dtype".to_string(), + ref_attr_name: "dtype".to_string(), + i: 11, // DOUBLE + doc_string: "dtype".to_string(), + r#type: 2, // INT + f: 0.0, + s: vec![], + t: None, + g: None, + sparse_tensor: None, + tp: None, + floats: vec![], + ints: vec![], + strings: vec![], + tensors: vec![], + graphs: vec![], + sparse_tensors: vec![], + type_protos: vec![], + }; + let attrs = { + let mut mut_attrs = vec![att_shape, att_dtype]; + if low.is_some() { + mut_attrs.push(att_low); + } + if high.is_some() { + mut_attrs.push(att_high); + } + mut_attrs + }; + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "RandomUniform".to_string(), + domain: "".to_string(), + attribute: attrs, + input: vec![], + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + name: "".to_string(), + initializer: vec![], + input: vec![], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![], + doc_string: "".to_string(), + sparse_initializer: vec![], + quantization_annotation: vec![], + })); + let eval = candle_onnx::simple_eval(&manual_graph, HashMap::new())?; + assert_eq!(eval.len(), 1); + let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); + let min = z + .flatten_all()? + .to_vec1()? + .into_iter() + .reduce(f64::min) + .unwrap(); + let max = z + .flatten_all()? + .to_vec1()? + .into_iter() + .reduce(f64::max) + .unwrap(); + assert!(min >= low.unwrap_or(0.0).into()); + assert!(max <= high.unwrap_or(1.0).into()); + assert_ne!(min, max); + Ok(()) + } + + Ok(()) +} + // "Range" #[test] fn test_range() -> Result<()> {