mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Add a simple Module trait and implement it for the various nn layers (#500)
* Start adding the module trait. * Use the module trait. * Implement module for qmatmul.
This commit is contained in:
@ -7,7 +7,7 @@
|
||||
//!
|
||||
//! ```rust
|
||||
//! use candle::{Tensor, Device::Cpu};
|
||||
//! use candle_nn::Linear;
|
||||
//! use candle_nn::{Linear, Module};
|
||||
//! # fn main() -> candle::Result<()> {
|
||||
//!
|
||||
//! let w = Tensor::new(&[[1f32, 2.], [3., 4.], [5., 6.]], &Cpu)?;
|
||||
@ -29,8 +29,10 @@ impl Linear {
|
||||
pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
|
||||
Self { weight, bias }
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
|
||||
impl super::Module for Linear {
|
||||
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
|
||||
let w = match x.dims() {
|
||||
&[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
|
||||
_ => self.weight.t()?,
|
||||
|
Reference in New Issue
Block a user