Added LeakyRelu implementation

This commit is contained in:
b1rtek
2024-05-10 00:49:54 +02:00
parent 4de76b89a2
commit 91b0d526ee

View File

@ -1051,6 +1051,19 @@ pub fn simple_eval(
}.to_dtype(DType::I64)?;
values.insert(node.output[0].clone(), output);
}
"LeakyRelu" => {
let input = get(&node.input[0])?;
let dt = input.dtype();
match dt {
DType::U8 | DType::U32 | DType::I64 => {
bail!("unsupported dtype {}, only float types are allowed for LeakyRelu", dt.as_str())
}
DType::BF16 | DType::F16 | DType::F32 | DType::F64 => {}
}
let alpha = get_attr_opt::<f32>(node, "alpha")?.copied().unwrap_or(0.01);
let output = candle_nn::ops::leaky_relu(input, alpha.into())?;
values.insert(node.output[0].clone(), output);
}
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
}
}