From 91b0d526eee216785932dab4de243c2a1c303148 Mon Sep 17 00:00:00 2001 From: b1rtek <53182944+B1rtek@users.noreply.github.com> Date: Fri, 10 May 2024 00:49:54 +0200 Subject: [PATCH] Added LeakyRelu implementation --- candle-onnx/src/eval.rs | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index 4d3f3ee4..83118fb9 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -1051,6 +1051,19 @@ pub fn simple_eval( }.to_dtype(DType::I64)?; values.insert(node.output[0].clone(), output); } + "LeakyRelu" => { + let input = get(&node.input[0])?; + let dt = input.dtype(); + match dt { + DType::U8 | DType::U32 | DType::I64 => { + bail!("unsupported dtype {}, only float types are allowed for LeakyRelu", dt.as_str()) + } + DType::BF16 | DType::F16 | DType::F32 | DType::F64 => {} + } + let alpha = get_attr_opt::(node, "alpha")?.copied().unwrap_or(0.01); + let output = candle_nn::ops::leaky_relu(input, alpha.into())?; + values.insert(node.output[0].clone(), output); + } op_type => bail!("unsupported op_type {op_type} for op {node:?}"), } }