Add a simple Module trait and implement it for the various nn layers (#500)

* Start adding the module trait.

* Use the module trait.

* Implement module for qmatmul.
This commit is contained in:
Laurent Mazare
2023-08-18 09:38:22 +01:00
committed by GitHub
parent 13401df4d1
commit c78ce76501
33 changed files with 70 additions and 28 deletions

View File

@ -7,8 +7,8 @@ pub enum Activation {
Elu(f64),
}
impl Activation {
pub fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
impl super::Module for Activation {
fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
match self {
Self::Gelu => xs.gelu(),
Self::Relu => xs.relu(),