candle-onnx: add operators RandomUniform and Exp (#2116)

* Add basic RandomUniform implementation

* Use is_some to check if seed is present

* Added Exp operator implementation

---------

Co-authored-by: Mateusz Okulus <mmokulus@gmail.com>
This commit is contained in:
B1rtek
2024-04-23 19:02:19 +02:00
committed by GitHub
parent 8a05743a21
commit 6fadaf2eff
2 changed files with 241 additions and 0 deletions

View File

@ -327,6 +327,11 @@ pub fn simple_eval(
let output = input0.broadcast_pow(input1)?;
values.insert(node.output[0].clone(), output);
}
"Exp" => {
let xs = get(&node.input[0])?;
let output = xs.exp()?;
values.insert(node.output[0].clone(), output);
}
"Equal" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
@ -966,6 +971,46 @@ pub fn simple_eval(
};
values.insert(node.output[0].clone(), output);
}
"RandomUniform" => {
let dt: i64 = get_attr_opt(node, "dtype")?.copied().unwrap_or(1); // 1 is float
// type by
// default
let dtype = match DataType::try_from(dt as i32) {
Ok(dt) => match dtype(dt) {
Some(DType::U8 | DType::U32 | DType::I64) => {
bail!(
"unsupported 'dtype' value {dt:?}, only floats are allowed, for RandomUnifrom {}",
node.name
)
}
Some(dt) => dt,
None => {
bail!(
"unsupported 'dtype' value {dt:?} for RandomUnifrom {}",
node.name
)
}
},
Err(_) => {
bail!(
"unsupported 'dtype' value {dt:?} for RandomUniform {}",
node.name
)
}
};
let low: f32 = get_attr_opt(node, "low")?.copied().unwrap_or(0.0);
let high: f32 = get_attr_opt(node, "high")?.copied().unwrap_or(1.0);
let seed: Option<f32> = get_attr_opt(node, "seed")?.copied();
if seed.is_some() {
bail!("seed for RandomUniform is currently not supported")
};
let shape: Vec<usize> = get_attr::<[i64]>(node, "shape")?
.iter()
.map(|x| *x as usize)
.collect();
let output = Tensor::rand(low, high, shape, &Device::Cpu)?.to_dtype(dtype)?;
values.insert(node.output[0].clone(), output);
}
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
}
}