Add a simple Module trait and implement it for the various nn layers (#500)

* Start adding the module trait.

* Use the module trait.

* Implement module for qmatmul.
This commit is contained in:
Laurent Mazare
2023-08-18 09:38:22 +01:00
committed by GitHub
parent 13401df4d1
commit c78ce76501
33 changed files with 70 additions and 28 deletions

View File

@ -7,7 +7,7 @@
//!
//! ```rust
//! use candle::{Tensor, Device::Cpu};
//! use candle_nn::Linear;
//! use candle_nn::{Linear, Module};
//! # fn main() -> candle::Result<()> {
//!
//! let w = Tensor::new(&[[1f32, 2.], [3., 4.], [5., 6.]], &Cpu)?;
@ -29,8 +29,10 @@ impl Linear {
pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
Self { weight, bias }
}
}
pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
impl super::Module for Linear {
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
let w = match x.dims() {
&[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
_ => self.weight.t()?,