mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Add a simple Module trait and implement it for the various nn layers (#500)
* Start adding the module trait. * Use the module trait. * Implement module for qmatmul.
This commit is contained in:
@ -8,7 +8,7 @@
|
||||
//!
|
||||
//! ```rust
|
||||
//! use candle::{Tensor, Device::Cpu};
|
||||
//! use candle_nn::LayerNorm;
|
||||
//! use candle_nn::{LayerNorm, Module};
|
||||
//! # fn main() -> candle::Result<()> {
|
||||
//!
|
||||
//! let w = Tensor::new(1f32, &Cpu)?;
|
||||
@ -95,8 +95,10 @@ impl LayerNorm {
|
||||
eps,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
impl crate::Module for LayerNorm {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x_dtype = x.dtype();
|
||||
let internal_dtype = match x_dtype {
|
||||
DType::F16 | DType::BF16 => DType::F32,
|
||||
@ -152,8 +154,10 @@ impl RmsNorm {
|
||||
pub fn into_inner(self) -> LayerNorm {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
impl crate::Module for RmsNorm {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
self.0.forward(xs)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user