mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Add RandomNormal ONNX operator (#2200)
This commit is contained in:
@ -971,7 +971,7 @@ pub fn simple_eval(
|
|||||||
};
|
};
|
||||||
values.insert(node.output[0].clone(), output);
|
values.insert(node.output[0].clone(), output);
|
||||||
}
|
}
|
||||||
"RandomUniform" => {
|
random_type @ ("RandomUniform" | "RandomNormal") => {
|
||||||
let dt: i64 = get_attr_opt(node, "dtype")?.copied().unwrap_or(1); // 1 is float
|
let dt: i64 = get_attr_opt(node, "dtype")?.copied().unwrap_or(1); // 1 is float
|
||||||
// type by
|
// type by
|
||||||
// default
|
// default
|
||||||
@ -979,36 +979,42 @@ pub fn simple_eval(
|
|||||||
Ok(dt) => match dtype(dt) {
|
Ok(dt) => match dtype(dt) {
|
||||||
Some(DType::U8 | DType::U32 | DType::I64) => {
|
Some(DType::U8 | DType::U32 | DType::I64) => {
|
||||||
bail!(
|
bail!(
|
||||||
"unsupported 'dtype' value {dt:?}, only floats are allowed, for RandomUnifrom {}",
|
"unsupported 'dtype' value {dt:?}, only floats are allowed, for {random_type} {}",
|
||||||
node.name
|
node.name
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
Some(dt) => dt,
|
Some(dt) => dt,
|
||||||
None => {
|
None => {
|
||||||
bail!(
|
bail!(
|
||||||
"unsupported 'dtype' value {dt:?} for RandomUnifrom {}",
|
"unsupported 'dtype' value {dt:?} for {random_type} {}",
|
||||||
node.name
|
node.name
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
bail!(
|
bail!(
|
||||||
"unsupported 'dtype' value {dt:?} for RandomUniform {}",
|
"unsupported 'dtype' value {dt:?} for {random_type} {}",
|
||||||
node.name
|
node.name
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
let low: f32 = get_attr_opt(node, "low")?.copied().unwrap_or(0.0);
|
|
||||||
let high: f32 = get_attr_opt(node, "high")?.copied().unwrap_or(1.0);
|
|
||||||
let seed: Option<f32> = get_attr_opt(node, "seed")?.copied();
|
let seed: Option<f32> = get_attr_opt(node, "seed")?.copied();
|
||||||
if seed.is_some() {
|
if seed.is_some() {
|
||||||
bail!("seed for RandomUniform is currently not supported")
|
bail!("seed for {random_type} is currently not supported")
|
||||||
};
|
};
|
||||||
let shape: Vec<usize> = get_attr::<[i64]>(node, "shape")?
|
let shape: Vec<usize> = get_attr::<[i64]>(node, "shape")?
|
||||||
.iter()
|
.iter()
|
||||||
.map(|x| *x as usize)
|
.map(|x| *x as usize)
|
||||||
.collect();
|
.collect();
|
||||||
let output = Tensor::rand(low, high, shape, &Device::Cpu)?.to_dtype(dtype)?;
|
let output = if random_type == "RandomUniform" {
|
||||||
|
let low: f32 = get_attr_opt(node, "low")?.copied().unwrap_or(0.0);
|
||||||
|
let high: f32 = get_attr_opt(node, "high")?.copied().unwrap_or(1.0);
|
||||||
|
Tensor::rand(low, high, shape, &Device::Cpu)?.to_dtype(dtype)?
|
||||||
|
} else {
|
||||||
|
let mean: f32 = get_attr_opt(node, "mean")?.copied().unwrap_or(0.0);
|
||||||
|
let scale: f32 = get_attr_opt(node, "scale")?.copied().unwrap_or(1.0);
|
||||||
|
Tensor::randn(mean, scale, shape, &Device::Cpu)?.to_dtype(dtype)?
|
||||||
|
};
|
||||||
values.insert(node.output[0].clone(), output);
|
values.insert(node.output[0].clone(), output);
|
||||||
}
|
}
|
||||||
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
||||||
|
@ -2020,6 +2020,150 @@ fn test_random_uniform() -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// "RandomNormal"
|
||||||
|
#[test]
|
||||||
|
fn test_random_normal() -> Result<()> {
|
||||||
|
test(vec![3, 2, 1, 4], None, None)?;
|
||||||
|
test(vec![2, 2, 2, 2], Some(-10.0), None)?;
|
||||||
|
test(vec![2, 2, 2, 2], None, Some(10.0))?;
|
||||||
|
test(vec![1, 2, 3, 4], Some(-10.0), Some(10.0))?;
|
||||||
|
|
||||||
|
fn test(shape: Vec<i64>, mean: Option<f32>, scale: Option<f32>) -> Result<()> {
|
||||||
|
let att_mean = AttributeProto {
|
||||||
|
name: "mean".to_string(),
|
||||||
|
ref_attr_name: "mean".to_string(),
|
||||||
|
i: 0,
|
||||||
|
doc_string: "mean".to_string(),
|
||||||
|
r#type: 1, // FLOAT
|
||||||
|
f: mean.unwrap_or(0.0),
|
||||||
|
s: vec![],
|
||||||
|
t: None,
|
||||||
|
g: None,
|
||||||
|
sparse_tensor: None,
|
||||||
|
tp: None,
|
||||||
|
floats: vec![],
|
||||||
|
ints: vec![],
|
||||||
|
strings: vec![],
|
||||||
|
tensors: vec![],
|
||||||
|
graphs: vec![],
|
||||||
|
sparse_tensors: vec![],
|
||||||
|
type_protos: vec![],
|
||||||
|
};
|
||||||
|
let att_scale = AttributeProto {
|
||||||
|
name: "scale".to_string(),
|
||||||
|
ref_attr_name: "scale".to_string(),
|
||||||
|
i: 0,
|
||||||
|
doc_string: "scale".to_string(),
|
||||||
|
r#type: 1, // FLOAT
|
||||||
|
f: scale.unwrap_or(1.0),
|
||||||
|
s: vec![],
|
||||||
|
t: None,
|
||||||
|
g: None,
|
||||||
|
sparse_tensor: None,
|
||||||
|
tp: None,
|
||||||
|
floats: vec![],
|
||||||
|
ints: vec![],
|
||||||
|
strings: vec![],
|
||||||
|
tensors: vec![],
|
||||||
|
graphs: vec![],
|
||||||
|
sparse_tensors: vec![],
|
||||||
|
type_protos: vec![],
|
||||||
|
};
|
||||||
|
let att_shape = AttributeProto {
|
||||||
|
name: "shape".to_string(),
|
||||||
|
ref_attr_name: "shape".to_string(),
|
||||||
|
i: 0,
|
||||||
|
doc_string: "shape".to_string(),
|
||||||
|
r#type: 7, // INTS
|
||||||
|
f: 0.0,
|
||||||
|
s: vec![],
|
||||||
|
t: None,
|
||||||
|
g: None,
|
||||||
|
sparse_tensor: None,
|
||||||
|
tp: None,
|
||||||
|
floats: vec![],
|
||||||
|
ints: shape,
|
||||||
|
strings: vec![],
|
||||||
|
tensors: vec![],
|
||||||
|
graphs: vec![],
|
||||||
|
sparse_tensors: vec![],
|
||||||
|
type_protos: vec![],
|
||||||
|
};
|
||||||
|
let att_dtype = AttributeProto {
|
||||||
|
name: "dtype".to_string(),
|
||||||
|
ref_attr_name: "dtype".to_string(),
|
||||||
|
i: 11, // DOUBLE
|
||||||
|
doc_string: "dtype".to_string(),
|
||||||
|
r#type: 2, // INT
|
||||||
|
f: 0.0,
|
||||||
|
s: vec![],
|
||||||
|
t: None,
|
||||||
|
g: None,
|
||||||
|
sparse_tensor: None,
|
||||||
|
tp: None,
|
||||||
|
floats: vec![],
|
||||||
|
ints: vec![],
|
||||||
|
strings: vec![],
|
||||||
|
tensors: vec![],
|
||||||
|
graphs: vec![],
|
||||||
|
sparse_tensors: vec![],
|
||||||
|
type_protos: vec![],
|
||||||
|
};
|
||||||
|
let attrs = {
|
||||||
|
let mut mut_attrs = vec![att_shape, att_dtype];
|
||||||
|
if mean.is_some() {
|
||||||
|
mut_attrs.push(att_mean);
|
||||||
|
}
|
||||||
|
if scale.is_some() {
|
||||||
|
mut_attrs.push(att_scale);
|
||||||
|
}
|
||||||
|
mut_attrs
|
||||||
|
};
|
||||||
|
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: "RandomNormal".to_string(),
|
||||||
|
domain: "".to_string(),
|
||||||
|
attribute: attrs,
|
||||||
|
input: vec![],
|
||||||
|
output: vec![OUTPUT_Z.to_string()],
|
||||||
|
name: "".to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
}],
|
||||||
|
name: "".to_string(),
|
||||||
|
initializer: vec![],
|
||||||
|
input: vec![],
|
||||||
|
output: vec![ValueInfoProto {
|
||||||
|
name: OUTPUT_Z.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
}],
|
||||||
|
value_info: vec![],
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
sparse_initializer: vec![],
|
||||||
|
quantization_annotation: vec![],
|
||||||
|
}));
|
||||||
|
let eval = candle_onnx::simple_eval(&manual_graph, HashMap::new())?;
|
||||||
|
assert_eq!(eval.len(), 1);
|
||||||
|
|
||||||
|
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||||
|
let data = z.flatten_all()?.to_vec1::<f64>()?;
|
||||||
|
|
||||||
|
// test if values are unique
|
||||||
|
for (i, a) in data.iter().enumerate() {
|
||||||
|
for (j, b) in data.iter().enumerate() {
|
||||||
|
if i == j {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
assert_ne!(a, b);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
// "Range"
|
// "Range"
|
||||||
#[test]
|
#[test]
|
||||||
fn test_range() -> Result<()> {
|
fn test_range() -> Result<()> {
|
||||||
|
Reference in New Issue
Block a user