Files
candle/candle-nn/src/activation.rs
Laurent Mazare c78ce76501 Add a simple Module trait and implement it for the various nn layers (#500)
* Start adding the module trait.

* Use the module trait.

* Implement module for qmatmul.
2023-08-18 09:38:22 +01:00

19 lines
381 B
Rust

use candle::Tensor;
/// Selects which activation function `Module::forward` applies to its input tensor.
///
/// The enum is `Copy`, so it can be passed and stored by value.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Activation {
    /// Dispatches to `Tensor::gelu`.
    Gelu,
    /// Dispatches to `Tensor::relu`.
    Relu,
    /// Dispatches to `Tensor::elu`, passing the wrapped value as its `alpha` argument.
    Elu(f64),
}
impl super::Module for Activation {
fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
match self {
Self::Gelu => xs.gelu(),
Self::Relu => xs.relu(),
&Self::Elu(alpha) => xs.elu(alpha),
}
}
}