Mirror of https://github.com/huggingface/candle.git
Synced 2025-06-19 19:58:35 +00:00
Add a simple Module trait and implement it for the various nn layers (#500)
* Start adding the module trait.
* Use the module trait.
* Implement module for qmatmul.
@@ -1,5 +1,5 @@
// For now this crate shares its error type with candle-core. We may introduce a separate
// error type if needed, or add some specialized cases on the candle-core side.
use candle::{Result, Tensor};

pub mod activation;
pub mod conv;
pub mod embedding;
@@ -21,3 +21,20 @@ pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
pub use linear::{linear, linear_no_bias, Linear};
pub use optim::{AdamW, ParamsAdamW, SGD};
pub use var_builder::{VarBuilder, VarMap};

// A simple trait defining a module with a forward method taking a single tensor argument.
pub trait Module: std::fmt::Debug {
    fn forward(&self, xs: &Tensor) -> Result<Tensor>;

    /// Switch the module between training mode and eval mode.
    ///
    /// The default implementation does nothing, as this is only relevant for a couple of
    /// modules such as dropout or batch-normalization.
    fn set_training(&mut self, _training: bool) {}
}
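For a downstream layer, implementing the trait means providing forward, and optionally overriding set_training when train/eval mode matters. Below is a minimal sketch with a hypothetical Scale layer (Scale, factor, and training are illustrative names, not part of candle-nn), assuming candle-core is available under the candle alias as in this crate:

use candle::{Result, Tensor};
use candle_nn::Module;

// Hypothetical toy layer: scales its input only while training,
// mirroring how dropout uses the training flag.
#[derive(Debug)]
struct Scale {
    factor: f64,
    training: bool,
}

impl Module for Scale {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        if self.training {
            // affine computes xs * factor + 0.0 element-wise.
            xs.affine(self.factor, 0.0)
        } else {
            Ok(xs.clone())
        }
    }

    fn set_training(&mut self, training: bool) {
        self.training = training;
    }
}

Deriving Debug satisfies the std::fmt::Debug supertrait bound that the Module trait declares.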

impl Module for candle::quantized::QMatMul {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // QMatMul's inherent forward method takes precedence over the trait
        // method during resolution, so this delegates rather than recursing.
        self.forward(xs)
    }
}
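With the trait in place, regular layers and quantized matmuls can be driven through one generic interface. A small usage sketch under the same crate-alias assumption (the apply helper and main are illustrative, not part of the crate):

use candle::{Device, Result, Tensor};
use candle_nn::{Linear, Module};

// Generic helper: run any Module through the shared forward interface.
fn apply<M: Module>(m: &M, xs: &Tensor) -> Result<Tensor> {
    m.forward(xs)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // A 4x3 weight matrix gives a Linear layer mapping 3 features to 4.
    let w = Tensor::randn(0f32, 1.0, (4, 3), &dev)?;
    let layer = Linear::new(w, None);
    let xs = Tensor::randn(0f32, 1.0, (2, 3), &dev)?;
    let ys = apply(&layer, &xs)?;
    println!("{:?}", ys.shape()); // dims: [2, 4]
    Ok(())
}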