mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Simplify usage of the pool functions. (#662)
* Simplify usage of the pool functions. * Small tweak. * Attempt at using apply to simplify the convnet definition.
This commit is contained in:
@ -1,5 +1,3 @@
|
||||
use candle::{Result, Tensor};
|
||||
|
||||
pub mod activation;
|
||||
pub mod batch_norm;
|
||||
pub mod conv;
|
||||
@ -28,19 +26,4 @@ pub use optim::{AdamW, ParamsAdamW, SGD};
|
||||
pub use var_builder::VarBuilder;
|
||||
pub use var_map::VarMap;
|
||||
|
||||
// A simple trait defining a module with forward method using a single argument.
|
||||
pub trait Module: std::fmt::Debug {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
|
||||
|
||||
/// Change the module to use training mode vs eval mode.
|
||||
///
|
||||
/// The default implementation does nothing as this is only used for a couple modules such as
|
||||
/// dropout or batch-normalization.
|
||||
fn set_training(&mut self, _training: bool) {}
|
||||
}
|
||||
|
||||
impl Module for candle::quantized::QMatMul {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
self.forward(xs)
|
||||
}
|
||||
}
|
||||
pub use candle::Module;
|
||||
|
Reference in New Issue
Block a user