mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
Add RandomNormal ONNX operator (#2200)
This commit is contained in:
@ -971,7 +971,7 @@ pub fn simple_eval(
|
||||
};
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"RandomUniform" => {
|
||||
random_type @ ("RandomUniform" | "RandomNormal") => {
|
||||
let dt: i64 = get_attr_opt(node, "dtype")?.copied().unwrap_or(1); // 1 is float
|
||||
// type by
|
||||
// default
|
||||
@ -979,36 +979,42 @@ pub fn simple_eval(
|
||||
Ok(dt) => match dtype(dt) {
|
||||
Some(DType::U8 | DType::U32 | DType::I64) => {
|
||||
bail!(
|
||||
"unsupported 'dtype' value {dt:?}, only floats are allowed, for RandomUnifrom {}",
|
||||
"unsupported 'dtype' value {dt:?}, only floats are allowed, for {random_type} {}",
|
||||
node.name
|
||||
)
|
||||
}
|
||||
Some(dt) => dt,
|
||||
None => {
|
||||
bail!(
|
||||
"unsupported 'dtype' value {dt:?} for RandomUnifrom {}",
|
||||
"unsupported 'dtype' value {dt:?} for {random_type} {}",
|
||||
node.name
|
||||
)
|
||||
}
|
||||
},
|
||||
Err(_) => {
|
||||
bail!(
|
||||
"unsupported 'dtype' value {dt:?} for RandomUniform {}",
|
||||
"unsupported 'dtype' value {dt:?} for {random_type} {}",
|
||||
node.name
|
||||
)
|
||||
}
|
||||
};
|
||||
let low: f32 = get_attr_opt(node, "low")?.copied().unwrap_or(0.0);
|
||||
let high: f32 = get_attr_opt(node, "high")?.copied().unwrap_or(1.0);
|
||||
let seed: Option<f32> = get_attr_opt(node, "seed")?.copied();
|
||||
if seed.is_some() {
|
||||
bail!("seed for RandomUniform is currently not supported")
|
||||
bail!("seed for {random_type} is currently not supported")
|
||||
};
|
||||
let shape: Vec<usize> = get_attr::<[i64]>(node, "shape")?
|
||||
.iter()
|
||||
.map(|x| *x as usize)
|
||||
.collect();
|
||||
let output = Tensor::rand(low, high, shape, &Device::Cpu)?.to_dtype(dtype)?;
|
||||
let output = if random_type == "RandomUniform" {
|
||||
let low: f32 = get_attr_opt(node, "low")?.copied().unwrap_or(0.0);
|
||||
let high: f32 = get_attr_opt(node, "high")?.copied().unwrap_or(1.0);
|
||||
Tensor::rand(low, high, shape, &Device::Cpu)?.to_dtype(dtype)?
|
||||
} else {
|
||||
let mean: f32 = get_attr_opt(node, "mean")?.copied().unwrap_or(0.0);
|
||||
let scale: f32 = get_attr_opt(node, "scale")?.copied().unwrap_or(1.0);
|
||||
Tensor::randn(mean, scale, shape, &Device::Cpu)?.to_dtype(dtype)?
|
||||
};
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
||||
|
Reference in New Issue
Block a user