Add a simple Module trait and implement it for the various nn layers (#500)

* Start adding the module trait.

* Use the module trait.

* Implement module for qmatmul.
Laurent Mazare
2023-08-18 09:38:22 +01:00
committed by GitHub
parent 13401df4d1
commit c78ce76501
33 changed files with 70 additions and 28 deletions
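
For reference, the Module trait this commit introduces reduces to a single
forward method. A minimal sketch, assuming the definition matches what the
diffs below rely on (the real candle-nn item may carry more detail):

use candle::{Result, Tensor};

/// A layer that maps an input tensor to an output tensor.
pub trait Module {
    /// Run the layer's forward pass on `xs`.
    fn forward(&self, xs: &Tensor) -> Result<Tensor>;
}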

candle-nn/src/layer_norm.rs

@@ -8,7 +8,7 @@
 //!
 //! ```rust
 //! use candle::{Tensor, Device::Cpu};
-//! use candle_nn::LayerNorm;
+//! use candle_nn::{LayerNorm, Module};
 //! # fn main() -> candle::Result<()> {
 //!
 //! let w = Tensor::new(1f32, &Cpu)?;
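
Since forward now comes from the trait, the doc example needs Module in
scope for the method call to resolve. A self-contained version of the
example, reconstructed under the assumption that the truncated lines build
a LayerNorm via LayerNorm::new(weight, bias, eps) and then call forward:

use candle::{Device::Cpu, Result, Tensor};
use candle_nn::{LayerNorm, Module};

fn main() -> Result<()> {
    // Scalar weight/bias broadcast over the normalized values.
    let w = Tensor::new(1f32, &Cpu)?;
    let b = Tensor::new(0f32, &Cpu)?;
    let layer = LayerNorm::new(w, b, 1e-5);
    let xs = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Cpu)?;
    // Resolves through the Module trait imported above.
    let ys = layer.forward(&xs)?;
    println!("{ys}");
    Ok(())
}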
@@ -95,8 +95,10 @@ impl LayerNorm {
             eps,
         }
     }
+}
 
-    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+impl crate::Module for LayerNorm {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let x_dtype = x.dtype();
         let internal_dtype = match x_dtype {
             DType::F16 | DType::BF16 => DType::F32,
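
Moving forward from an inherent method into a trait impl keeps existing
call sites compiling (once Module is imported) while letting downstream
code abstract over layers. A hypothetical helper, not part of this commit,
to illustrate what the trait buys:

use candle::{Result, Tensor};
use candle_nn::Module;

/// Hypothetical: run a pipeline of layers in order, feeding each
/// layer's output into the next one.
fn sequential(layers: &[&dyn Module], xs: &Tensor) -> Result<Tensor> {
    let mut xs = xs.clone();
    for layer in layers {
        xs = layer.forward(&xs)?;
    }
    Ok(xs)
}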
@@ -152,8 +154,10 @@ impl RmsNorm {
     pub fn into_inner(self) -> LayerNorm {
         self.0
     }
+}
 
-    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+impl crate::Module for RmsNorm {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         self.0.forward(xs)
     }
 }
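
RmsNorm's impl above simply delegates to the wrapped LayerNorm. The commit
message also mentions QMatMul; a hedged sketch of that impl, assuming it
sits next to the trait in candle-nn and that QMatMul keeps its pre-existing
inherent forward (the actual impl in this commit may instead inline the
quantized matmul logic):

use candle::quantized::QMatMul;
use candle::{Result, Tensor};

impl crate::Module for QMatMul {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // Delegate to the inherent method (assumed to still exist);
        // the fully-qualified path keeps the call from resolving
        // back into this trait impl.
        QMatMul::forward(self, xs)
    }
}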