mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
45 lines
1.3 KiB
Rust
45 lines
1.3 KiB
Rust
//! Linear layer
//!
//! This layer applies a linear transformation to the incoming data, `y = x@w.t() + b`, where the
//! weight `w` has shape `(out_c, in_c)` and the optional bias `b` is broadcast-added to the result.
//! The `forward` method can be used to apply the layer. It supports input with a batch dimension,
//! of shape `(b_sz, in_c)`, as well as input with one extra leading dimension, of shape
//! `(b_sz, n, in_c)`; the output then has shape `(b_sz, out_c)` and `(b_sz, n, out_c)`
//! respectively.
//!
//! ```rust
//! use candle::{Tensor, Device::Cpu};
//! use candle_nn::Linear;
//! # fn main() -> candle::Result<()> {
//!
//! let w = Tensor::new(&[[1f32, 2.], [3., 4.], [5., 6.]], &Cpu)?;
//! let layer = Linear::new(w, None); // Use no bias.
//! let xs = Tensor::new(&[[10f32, 100.]], &Cpu)?;
//! let ys = layer.forward(&xs)?;
//! assert_eq!(ys.to_vec2::<f32>()?, &[[210.0, 430.0, 650.0]]);
//! # Ok(()) }
//! ```
|
|
use candle::Tensor;
|
|
|
|
#[derive(Debug)]
|
|
pub struct Linear {
|
|
weight: Tensor,
|
|
bias: Option<Tensor>,
|
|
}
|
|
|
|
impl Linear {
|
|
pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
|
|
Self { weight, bias }
|
|
}
|
|
|
|
pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
|
|
let w = match x.dims() {
|
|
&[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
|
|
_ => self.weight.t()?,
|
|
};
|
|
let x = x.matmul(&w)?;
|
|
match &self.bias {
|
|
None => Ok(x),
|
|
Some(bias) => x.broadcast_add(bias),
|
|
}
|
|
}
|
|
}
|