mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Add RandomNormal ONNX operator (#2200)
This commit is contained in:
@ -2020,6 +2020,150 @@ fn test_random_uniform() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// "RandomNormal"
|
||||
#[test]
|
||||
fn test_random_normal() -> Result<()> {
|
||||
test(vec![3, 2, 1, 4], None, None)?;
|
||||
test(vec![2, 2, 2, 2], Some(-10.0), None)?;
|
||||
test(vec![2, 2, 2, 2], None, Some(10.0))?;
|
||||
test(vec![1, 2, 3, 4], Some(-10.0), Some(10.0))?;
|
||||
|
||||
fn test(shape: Vec<i64>, mean: Option<f32>, scale: Option<f32>) -> Result<()> {
|
||||
let att_mean = AttributeProto {
|
||||
name: "mean".to_string(),
|
||||
ref_attr_name: "mean".to_string(),
|
||||
i: 0,
|
||||
doc_string: "mean".to_string(),
|
||||
r#type: 1, // FLOAT
|
||||
f: mean.unwrap_or(0.0),
|
||||
s: vec![],
|
||||
t: None,
|
||||
g: None,
|
||||
sparse_tensor: None,
|
||||
tp: None,
|
||||
floats: vec![],
|
||||
ints: vec![],
|
||||
strings: vec![],
|
||||
tensors: vec![],
|
||||
graphs: vec![],
|
||||
sparse_tensors: vec![],
|
||||
type_protos: vec![],
|
||||
};
|
||||
let att_scale = AttributeProto {
|
||||
name: "scale".to_string(),
|
||||
ref_attr_name: "scale".to_string(),
|
||||
i: 0,
|
||||
doc_string: "scale".to_string(),
|
||||
r#type: 1, // FLOAT
|
||||
f: scale.unwrap_or(1.0),
|
||||
s: vec![],
|
||||
t: None,
|
||||
g: None,
|
||||
sparse_tensor: None,
|
||||
tp: None,
|
||||
floats: vec![],
|
||||
ints: vec![],
|
||||
strings: vec![],
|
||||
tensors: vec![],
|
||||
graphs: vec![],
|
||||
sparse_tensors: vec![],
|
||||
type_protos: vec![],
|
||||
};
|
||||
let att_shape = AttributeProto {
|
||||
name: "shape".to_string(),
|
||||
ref_attr_name: "shape".to_string(),
|
||||
i: 0,
|
||||
doc_string: "shape".to_string(),
|
||||
r#type: 7, // INTS
|
||||
f: 0.0,
|
||||
s: vec![],
|
||||
t: None,
|
||||
g: None,
|
||||
sparse_tensor: None,
|
||||
tp: None,
|
||||
floats: vec![],
|
||||
ints: shape,
|
||||
strings: vec![],
|
||||
tensors: vec![],
|
||||
graphs: vec![],
|
||||
sparse_tensors: vec![],
|
||||
type_protos: vec![],
|
||||
};
|
||||
let att_dtype = AttributeProto {
|
||||
name: "dtype".to_string(),
|
||||
ref_attr_name: "dtype".to_string(),
|
||||
i: 11, // DOUBLE
|
||||
doc_string: "dtype".to_string(),
|
||||
r#type: 2, // INT
|
||||
f: 0.0,
|
||||
s: vec![],
|
||||
t: None,
|
||||
g: None,
|
||||
sparse_tensor: None,
|
||||
tp: None,
|
||||
floats: vec![],
|
||||
ints: vec![],
|
||||
strings: vec![],
|
||||
tensors: vec![],
|
||||
graphs: vec![],
|
||||
sparse_tensors: vec![],
|
||||
type_protos: vec![],
|
||||
};
|
||||
let attrs = {
|
||||
let mut mut_attrs = vec![att_shape, att_dtype];
|
||||
if mean.is_some() {
|
||||
mut_attrs.push(att_mean);
|
||||
}
|
||||
if scale.is_some() {
|
||||
mut_attrs.push(att_scale);
|
||||
}
|
||||
mut_attrs
|
||||
};
|
||||
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||
node: vec![NodeProto {
|
||||
op_type: "RandomNormal".to_string(),
|
||||
domain: "".to_string(),
|
||||
attribute: attrs,
|
||||
input: vec![],
|
||||
output: vec![OUTPUT_Z.to_string()],
|
||||
name: "".to_string(),
|
||||
doc_string: "".to_string(),
|
||||
}],
|
||||
name: "".to_string(),
|
||||
initializer: vec![],
|
||||
input: vec![],
|
||||
output: vec![ValueInfoProto {
|
||||
name: OUTPUT_Z.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
}],
|
||||
value_info: vec![],
|
||||
doc_string: "".to_string(),
|
||||
sparse_initializer: vec![],
|
||||
quantization_annotation: vec![],
|
||||
}));
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, HashMap::new())?;
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
let data = z.flatten_all()?.to_vec1::<f64>()?;
|
||||
|
||||
// test if values are unique
|
||||
for (i, a) in data.iter().enumerate() {
|
||||
for (j, b) in data.iter().enumerate() {
|
||||
if i == j {
|
||||
continue;
|
||||
};
|
||||
assert_ne!(a, b);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// "Range"
|
||||
#[test]
|
||||
fn test_range() -> Result<()> {
|
||||
|
Reference in New Issue
Block a user