diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 1256a076..32de1af9 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -45,7 +45,8 @@ pub fn sigmoid(xs: &Tensor) -> Result { } pub fn leaky_relu(xs: &Tensor, negative_slope: f64) -> Result { - xs.relu()?.minimum(&(xs * negative_slope)?) + let zeros = xs.zeros_like()?; + xs.maximum(&zeros)? + xs.minimum(&zeros)? * negative_slope } pub fn dropout(xs: &Tensor, drop_p: f32) -> Result {