Add a simple Module trait and implement it for the various nn layers (#500)

* Start adding the module trait.

* Use the module trait.

* Implement module for qmatmul.
This commit is contained in:
Laurent Mazare
2023-08-18 09:38:22 +01:00
committed by GitHub
parent 13401df4d1
commit c78ce76501
33 changed files with 70 additions and 28 deletions

View File

@ -1,5 +1,5 @@
use candle::{DType, Device, Result, Tensor};
use candle_nn::{linear, AdamW, Linear, ParamsAdamW, VarBuilder, VarMap};
use candle_nn::{linear, AdamW, Linear, Module, ParamsAdamW, VarBuilder, VarMap};
fn gen_data() -> Result<(Tensor, Tensor)> {
// Generate some sample linear data.

View File

@ -7,8 +7,8 @@ pub enum Activation {
Elu(f64),
}
impl Activation {
pub fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
impl super::Module for Activation {
fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
match self {
Self::Gelu => xs.gelu(),
Self::Relu => xs.relu(),

View File

@ -35,8 +35,10 @@ impl Conv1d {
pub fn config(&self) -> &Conv1dConfig {
&self.config
}
}
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
impl crate::Module for Conv1d {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = x.conv1d(&self.weight, self.config.padding, self.config.stride)?;
match &self.bias {
None => Ok(x),
@ -84,8 +86,10 @@ impl Conv2d {
pub fn config(&self) -> &Conv2dConfig {
&self.config
}
}
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
impl crate::Module for Conv2d {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = x.conv2d(&self.weight, self.config.padding, self.config.stride)?;
match &self.bias {
None => Ok(x),

View File

@ -18,8 +18,10 @@ impl Embedding {
pub fn embeddings(&self) -> &Tensor {
&self.embeddings
}
}
pub fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
impl crate::Module for Embedding {
fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
let mut final_dims = indexes.dims().to_vec();
final_dims.push(self.hidden_size);
let indexes = indexes.flatten_all()?;

View File

@ -34,8 +34,10 @@ impl GroupNorm {
num_groups,
})
}
}
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
impl crate::Module for GroupNorm {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x_shape = x.dims();
if x_shape.len() <= 2 {
candle::bail!("input rank for GroupNorm should be at least 3");

View File

@ -8,7 +8,7 @@
//!
//! ```rust
//! use candle::{Tensor, Device::Cpu};
//! use candle_nn::LayerNorm;
//! use candle_nn::{LayerNorm, Module};
//! # fn main() -> candle::Result<()> {
//!
//! let w = Tensor::new(1f32, &Cpu)?;
@ -95,8 +95,10 @@ impl LayerNorm {
eps,
}
}
}
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
impl crate::Module for LayerNorm {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x_dtype = x.dtype();
let internal_dtype = match x_dtype {
DType::F16 | DType::BF16 => DType::F32,
@ -152,8 +154,10 @@ impl RmsNorm {
pub fn into_inner(self) -> LayerNorm {
self.0
}
}
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
impl crate::Module for RmsNorm {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
self.0.forward(xs)
}
}

View File

@ -1,5 +1,5 @@
// For now this crate shares its error type with candle-core. We may introduce some separate
// error type if needed or add some specialized cases on the candle-core side.
use candle::{Result, Tensor};
pub mod activation;
pub mod conv;
pub mod embedding;
@ -21,3 +21,20 @@ pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
pub use linear::{linear, linear_no_bias, Linear};
pub use optim::{AdamW, ParamsAdamW, SGD};
pub use var_builder::{VarBuilder, VarMap};
// A simple trait defining a module with forward method using a single argument.
pub trait Module: std::fmt::Debug {
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
/// Change the module to use training mode vs eval mode.
///
/// The default implementation does nothing as this is only used for a couple modules such as
/// dropout or batch-normalization.
fn set_training(&mut self, _training: bool) {}
}
impl Module for candle::quantized::QMatMul {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
self.forward(xs)
}
}

View File

@ -7,7 +7,7 @@
//!
//! ```rust
//! use candle::{Tensor, Device::Cpu};
//! use candle_nn::Linear;
//! use candle_nn::{Linear, Module};
//! # fn main() -> candle::Result<()> {
//!
//! let w = Tensor::new(&[[1f32, 2.], [3., 4.], [5., 6.]], &Cpu)?;
@ -29,8 +29,10 @@ impl Linear {
pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
Self { weight, bias }
}
}
pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
impl super::Module for Linear {
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
let w = match x.dims() {
&[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
_ => self.weight.t()?,

View File

@ -23,7 +23,7 @@ extern crate intel_mkl_src;
use anyhow::Result;
use candle::{Device, Tensor};
use candle_nn::GroupNorm;
use candle_nn::{GroupNorm, Module};
mod test_utils;
use test_utils::to_vec3_round;

View File

@ -3,7 +3,7 @@ extern crate intel_mkl_src;
use anyhow::Result;
use candle::{Device, Tensor};
use candle_nn::LayerNorm;
use candle_nn::{LayerNorm, Module};
#[test]
fn layer_norm() -> Result<()> {

View File

@ -6,7 +6,7 @@ use test_utils::{to_vec0_round, to_vec2_round};
use anyhow::Result;
use candle::{Device, Tensor, Var};
use candle_nn::{AdamW, Linear, ParamsAdamW, SGD};
use candle_nn::{AdamW, Linear, Module, ParamsAdamW, SGD};
#[test]
fn sgd_optim() -> Result<()> {