Simplify usage of the pool functions. (#662)

* Simplify usage of the pool functions.

* Small tweak.

* Attempt at using apply to simplify the convnet definition.
This commit is contained in:
Laurent Mazare
2023-08-29 19:12:16 +01:00
committed by GitHub
parent b31d41e26a
commit 2d3fcad267
9 changed files with 86 additions and 42 deletions

View File

@ -1,5 +1,3 @@
use candle::{Result, Tensor};
pub mod activation;
pub mod batch_norm;
pub mod conv;
@ -28,19 +26,4 @@ pub use optim::{AdamW, ParamsAdamW, SGD};
pub use var_builder::VarBuilder;
pub use var_map::VarMap;
// A simple trait defining a module with forward method using a single argument.
pub trait Module: std::fmt::Debug {
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
/// Change the module to use training mode vs eval mode.
///
/// The default implementation does nothing as this is only used for a couple modules such as
/// dropout or batch-normalization.
fn set_training(&mut self, _training: bool) {}
}
impl Module for candle::quantized::QMatMul {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
self.forward(xs)
}
}
pub use candle::Module;