mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Add PRelu operation (#2904)
* Add PRelu operation * Apply rustfmt. --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -71,6 +71,8 @@ impl candle::Module for PReLU {
|
|||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
let weight = if self.is_scalar {
|
let weight = if self.is_scalar {
|
||||||
self.weight.reshape(())?
|
self.weight.reshape(())?
|
||||||
|
} else if xs.shape() == self.weight.shape() {
|
||||||
|
self.weight.clone()
|
||||||
} else if xs.rank() >= 2 {
|
} else if xs.rank() >= 2 {
|
||||||
let num_channels = xs.dim(1)?;
|
let num_channels = xs.dim(1)?;
|
||||||
let num_weights = self.weight.elem_count();
|
let num_weights = self.weight.elem_count();
|
||||||
@ -78,7 +80,7 @@ impl candle::Module for PReLU {
|
|||||||
candle::bail!("error in prelu: unexpected number of channels for the input, got {num_channels}, weight dim is {num_weights}")
|
candle::bail!("error in prelu: unexpected number of channels for the input, got {num_channels}, weight dim is {num_weights}")
|
||||||
}
|
}
|
||||||
let mut s = vec![1; xs.rank()];
|
let mut s = vec![1; xs.rank()];
|
||||||
s[1] = self.weight.elem_count();
|
s[1] = num_weights;
|
||||||
self.weight.reshape(s)?
|
self.weight.reshape(s)?
|
||||||
} else {
|
} else {
|
||||||
self.weight.clone()
|
self.weight.clone()
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
use crate::onnx::attribute_proto::AttributeType;
|
use crate::onnx::attribute_proto::AttributeType;
|
||||||
use crate::onnx::tensor_proto::DataType;
|
use crate::onnx::tensor_proto::DataType;
|
||||||
use crate::onnx::{self, GraphProto};
|
use crate::onnx::{self, GraphProto};
|
||||||
|
use candle::Module;
|
||||||
use candle::{bail, DType, Device, Result, Tensor};
|
use candle::{bail, DType, Device, Result, Tensor};
|
||||||
|
use candle_nn::activation::PReLU;
|
||||||
use std::collections::{HashMap, HashSet};
|
use std::collections::{HashMap, HashSet};
|
||||||
|
|
||||||
pub type Value = Tensor;
|
pub type Value = Tensor;
|
||||||
@ -991,6 +993,14 @@ fn simple_eval_(
|
|||||||
let output = input.relu()?;
|
let output = input.relu()?;
|
||||||
values.insert(node.output[0].clone(), output);
|
values.insert(node.output[0].clone(), output);
|
||||||
}
|
}
|
||||||
|
"PRelu" => {
|
||||||
|
// https://onnx.ai/onnx/operators/onnx__PRelu.html
|
||||||
|
let input = get(&node.input[0])?;
|
||||||
|
let slope = get(&node.input[1])?;
|
||||||
|
|
||||||
|
let output = PReLU::new(slope.clone(), false).forward(input)?;
|
||||||
|
values.insert(node.output[0].clone(), output);
|
||||||
|
}
|
||||||
"Ceil" => {
|
"Ceil" => {
|
||||||
let input = get(&node.input[0])?;
|
let input = get(&node.input[0])?;
|
||||||
let output = input.ceil()?;
|
let output = input.ceil()?;
|
||||||
|
@ -1846,6 +1846,64 @@ fn test_relu_operation() -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// "PRelu"
|
||||||
|
#[test]
|
||||||
|
fn test_prelu_operation() -> Result<()> {
|
||||||
|
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: "PRelu".to_string(),
|
||||||
|
domain: "".to_string(),
|
||||||
|
attribute: vec![],
|
||||||
|
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
|
||||||
|
output: vec![OUTPUT_Z.to_string()],
|
||||||
|
name: "".to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
}],
|
||||||
|
name: "".to_string(),
|
||||||
|
initializer: vec![],
|
||||||
|
input: vec![
|
||||||
|
ValueInfoProto {
|
||||||
|
name: INPUT_X.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
},
|
||||||
|
ValueInfoProto {
|
||||||
|
name: INPUT_Y.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
output: vec![ValueInfoProto {
|
||||||
|
name: OUTPUT_Z.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
}],
|
||||||
|
value_info: vec![],
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
sparse_initializer: vec![],
|
||||||
|
quantization_annotation: vec![],
|
||||||
|
}));
|
||||||
|
let x: Tensor = Tensor::from_vec(
|
||||||
|
vec![-1.0f32, 1.0f32, -2.0f32, 3.0f32],
|
||||||
|
&[2, 2],
|
||||||
|
&Device::Cpu,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let y: Tensor = Tensor::from_vec(vec![1.0f32, 1.1f32, 1.2f32, 1.3f32], &[2, 2], &Device::Cpu)?;
|
||||||
|
|
||||||
|
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||||
|
inputs.insert(INPUT_X.to_string(), x);
|
||||||
|
inputs.insert(INPUT_Y.to_string(), y);
|
||||||
|
|
||||||
|
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||||
|
assert_eq!(eval.len(), 1);
|
||||||
|
|
||||||
|
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||||
|
let results = z.to_vec2::<f32>()?;
|
||||||
|
assert_eq!(results, vec![vec![-1.0, 1.0], vec![-2.4, 3.0]]);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
// "Constant"
|
// "Constant"
|
||||||
// #[test]
|
// #[test]
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user