use crate::models::with_tracing::QMatMul; use crate::quantized_var_builder::VarBuilder; use candle::{Module, Result, Tensor}; #[derive(Debug, Clone)] pub struct Embedding { inner: candle_nn::Embedding, span: tracing::Span, } impl Embedding { pub fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result { let embeddings = vb.get((d1, d2), "weight")?.dequantize(vb.device())?; let inner = candle_nn::Embedding::new(embeddings, d2); let span = tracing::span!(tracing::Level::TRACE, "embedding"); Ok(Self { inner, span }) } pub fn embeddings(&self) -> &Tensor { self.inner.embeddings() } } impl Module for Embedding { fn forward(&self, xs: &Tensor) -> Result { let _enter = self.span.enter(); self.inner.forward(xs) } } #[derive(Debug, Clone)] pub struct Linear { weight: QMatMul, bias: Option, } impl Linear { pub fn from_weights(weight: QMatMul, bias: Option) -> Self { Self { weight, bias } } } impl Module for Linear { fn forward(&self, x: &Tensor) -> candle::Result { let x = x.apply(&self.weight)?; match &self.bias { None => Ok(x), Some(bias) => x.broadcast_add(bias), } } } pub fn linear(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result { let bias = vb.get(out_dim, "bias")?.dequantize(vb.device())?; let weight = QMatMul::new(in_dim, out_dim, vb)?; Ok(Linear { weight, bias: Some(bias), }) } pub fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result { let weight = vb.get(size, "weight")?.dequantize(vb.device())?; let bias = vb.get(size, "bias")?.dequantize(vb.device())?; Ok(candle_nn::LayerNorm::new(weight, bias, eps)) } pub fn layer_norm_no_bias(size: usize, eps: f64, vb: VarBuilder) -> Result { let weight = vb.get(size, "weight")?.dequantize(vb.device())?; Ok(candle_nn::LayerNorm::new_no_bias(weight, eps)) } pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result { let weight = QMatMul::new(in_dim, out_dim, vb)?; Ok(Linear { weight, bias: None }) } #[derive(Debug, Clone)] pub struct RmsNorm { inner: candle_nn::RmsNorm, span: tracing::Span, } 
impl RmsNorm {
    /// Builds the layer from the `"weight"` tensor of length `size` in `vb`,
    /// dequantized on the builder's device, using `eps` as the epsilon passed
    /// to `candle_nn::RmsNorm::new`.
    ///
    /// # Errors
    ///
    /// Returns an error if the tensor lookup or its dequantization fails.
    pub fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
        let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
        let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
        let inner = candle_nn::RmsNorm::new(weight, eps);
        Ok(Self { inner, span })
    }
}

impl Module for RmsNorm {
    /// Runs the wrapped normalization on `x` inside the `"rms-norm"` tracing span.
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        self.inner.forward(x)
    }
}