candle-onnx: add operators RandomUniform and Exp (#2116)

* Add basic RandomUniform implementation

* Use is_some to check if seed is present

* Added Exp operator implementation

---------

Co-authored-by: Mateusz Okulus <mmokulus@gmail.com>
Author: B1rtek
Date: 2024-04-23 19:02:19 +02:00 (committed by GitHub)
Parent: 8a05743a21
Commit: 6fadaf2eff
2 changed files with 241 additions and 0 deletions
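
For orientation, a minimal standalone sketch (illustrative, not part of the diff) of the two candle primitives the new match arms delegate to: Tensor::exp backs the ONNX "Exp" operator and Tensor::rand backs "RandomUniform".

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    // "Exp": element-wise e^x.
    let xs = Tensor::from_vec(vec![0.0f32, 1.0], &[2], &Device::Cpu)?;
    let exp = xs.exp()?; // [1.0, e]

    // "RandomUniform": uniform samples in [low, high).
    let uniform = Tensor::rand(0.0f32, 1.0f32, &[2, 3], &Device::Cpu)?;
    println!("{exp}\n{uniform}");
    Ok(())
}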


@@ -327,6 +327,11 @@ pub fn simple_eval(
let output = input0.broadcast_pow(input1)?;
values.insert(node.output[0].clone(), output);
}
"Exp" => {
let xs = get(&node.input[0])?;
let output = xs.exp()?;
values.insert(node.output[0].clone(), output);
}
"Equal" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
@@ -966,6 +971,46 @@ pub fn simple_eval(
};
values.insert(node.output[0].clone(), output);
}
"RandomUniform" => {
let dt: i64 = get_attr_opt(node, "dtype")?.copied().unwrap_or(1); // 1 is float
// type by
// default
let dtype = match DataType::try_from(dt as i32) {
Ok(dt) => match dtype(dt) {
Some(DType::U8 | DType::U32 | DType::I64) => {
bail!(
"unsupported 'dtype' value {dt:?}, only floats are allowed, for RandomUnifrom {}",
node.name
)
}
Some(dt) => dt,
None => {
bail!(
"unsupported 'dtype' value {dt:?} for RandomUnifrom {}",
node.name
)
}
},
Err(_) => {
bail!(
"unsupported 'dtype' value {dt:?} for RandomUniform {}",
node.name
)
}
};
let low: f32 = get_attr_opt(node, "low")?.copied().unwrap_or(0.0);
let high: f32 = get_attr_opt(node, "high")?.copied().unwrap_or(1.0);
let seed: Option<f32> = get_attr_opt(node, "seed")?.copied();
if seed.is_some() {
bail!("seed for RandomUniform is currently not supported")
};
let shape: Vec<usize> = get_attr::<[i64]>(node, "shape")?
.iter()
.map(|x| *x as usize)
.collect();
let output = Tensor::rand(low, high, shape, &Device::Cpu)?.to_dtype(dtype)?;
values.insert(node.output[0].clone(), output);
}
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
}
}
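
The dtype attribute follows ONNX's TensorProto.DataType numbering. A sketch of how a few values flow through the arm above (the numbers come from the ONNX spec; only floating-point dtypes survive the match, since Tensor::rand samples floats):

// ONNX TensorProto.DataType values, as consumed by the RandomUniform arm:
//   1  = FLOAT  -> DType::F32  (accepted; also the default when dtype is absent)
//   11 = DOUBLE -> DType::F64  (accepted; exercised by the test below)
//   7  = INT64  -> DType::I64  (rejected with "only floats are allowed",
//                               since Tensor::rand only samples float dtypes)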


@@ -231,6 +231,56 @@ fn test_div_operation() -> Result<()> {
Ok(())
}
// "Exp"
#[test]
fn test_exp_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Exp".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
vec![-1.0f32, 0.0f32, 1.0f32, 2.0f32],
&[2, 2],
&Device::Cpu,
)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<f32>()?;
assert_eq!(results[0][0], 0.36787944f32);
assert_eq!(results[0][1], 1.0f32);
assert_eq!(results[1], vec![std::f32::consts::E, 7.38905609f32]);
Ok(())
}
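
The expected values are simply e^x for each input. A quick standalone check with plain f32::exp, using an approximate comparison since bit-exact equality depends on the underlying libm:

fn main() {
    let cases = [
        (-1.0f32, 0.36787944),
        (0.0, 1.0),
        (1.0, std::f32::consts::E),
        (2.0, 7.3890561),
    ];
    for (x, want) in cases {
        assert!((x.exp() - want).abs() < 1e-5);
    }
}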
// "Equal"
#[test]
fn test_equal_operation() -> Result<()> {
@@ -1828,6 +1878,152 @@ fn test_sqrt() -> Result<()> {
Ok(())
}
// "RandomUniform"
#[test]
fn test_random_uniform() -> Result<()> {
test(vec![3, 2, 1, 4], None, None)?;
test(vec![2, 2, 2, 2], Some(-10.0), None)?;
test(vec![2, 2, 2, 2], None, Some(10.0))?;
test(vec![1, 2, 3, 4], Some(-10.0), Some(10.0))?;
fn test(shape: Vec<i64>, low: Option<f32>, high: Option<f32>) -> Result<()> {
let att_low = AttributeProto {
name: "low".to_string(),
ref_attr_name: "low".to_string(),
i: 0,
doc_string: "low".to_string(),
r#type: 1, // FLOAT
f: low.unwrap_or(0.0),
s: vec![],
t: None,
g: None,
sparse_tensor: None,
tp: None,
floats: vec![],
ints: vec![],
strings: vec![],
tensors: vec![],
graphs: vec![],
sparse_tensors: vec![],
type_protos: vec![],
};
let att_high = AttributeProto {
name: "high".to_string(),
ref_attr_name: "high".to_string(),
i: 0,
doc_string: "high".to_string(),
r#type: 1, // FLOAT
f: high.unwrap_or(1.0),
s: vec![],
t: None,
g: None,
sparse_tensor: None,
tp: None,
floats: vec![],
ints: vec![],
strings: vec![],
tensors: vec![],
graphs: vec![],
sparse_tensors: vec![],
type_protos: vec![],
};
let att_shape = AttributeProto {
name: "shape".to_string(),
ref_attr_name: "shape".to_string(),
i: 0,
doc_string: "shape".to_string(),
r#type: 7, // INTS
f: 0.0,
s: vec![],
t: None,
g: None,
sparse_tensor: None,
tp: None,
floats: vec![],
ints: shape,
strings: vec![],
tensors: vec![],
graphs: vec![],
sparse_tensors: vec![],
type_protos: vec![],
};
let att_dtype = AttributeProto {
name: "dtype".to_string(),
ref_attr_name: "dtype".to_string(),
i: 11, // DOUBLE
doc_string: "dtype".to_string(),
r#type: 2, // INT
f: 0.0,
s: vec![],
t: None,
g: None,
sparse_tensor: None,
tp: None,
floats: vec![],
ints: vec![],
strings: vec![],
tensors: vec![],
graphs: vec![],
sparse_tensors: vec![],
type_protos: vec![],
};
let attrs = {
let mut mut_attrs = vec![att_shape, att_dtype];
if low.is_some() {
mut_attrs.push(att_low);
}
if high.is_some() {
mut_attrs.push(att_high);
}
mut_attrs
};
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "RandomUniform".to_string(),
domain: "".to_string(),
attribute: attrs,
input: vec![],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let eval = candle_onnx::simple_eval(&manual_graph, HashMap::new())?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let min = z
.flatten_all()?
.to_vec1()?
.into_iter()
.reduce(f64::min)
.unwrap();
let max = z
.flatten_all()?
.to_vec1()?
.into_iter()
.reduce(f64::max)
.unwrap();
assert!(min >= low.unwrap_or(0.0).into());
assert!(max <= high.unwrap_or(1.0).into());
assert_ne!(min, max);
Ok(())
}
Ok(())
}
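
Note that the test requests dtype 11 (DOUBLE), so the output is read back as f64 by to_vec1 before the range checks. The assertions verify the bounds (min >= low, max <= high) and that the tensor is not constant (min != max); they do not attempt to verify uniformity of the distribution.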
// "Range"
#[test]
fn test_range() -> Result<()> {