Mirror of https://github.com/huggingface/candle.git, synced 2025-06-21 12:20:46 +00:00.
Move the common quantized-nn code to a shared module. (#1063)
@@ -1,5 +1,4 @@
-use crate::models::quantized_t5::Embedding;
-use crate::models::with_tracing::QMatMul;
+use crate::quantized_nn::{layer_norm, linear_no_bias, Embedding, Linear};
 pub use crate::quantized_var_builder::VarBuilder;
 use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{Activation, LayerNorm};
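The import hunk above replaces model-local definitions with the shared crate::quantized_nn module; judging by the stable_lm imports in the next hunk, the file being patched is the quantized StableLM model. As a hedged illustration of how a model file typically consumes these helpers, the sketch below builds a few layers from a quantized VarBuilder; the struct, dimensions, and weight-name prefixes are hypothetical and not part of this diff.

// Hypothetical consumer of the shared helpers; only the imported names
// (Linear, layer_norm, linear_no_bias, VarBuilder) come from this diff.
use crate::quantized_nn::{layer_norm, linear_no_bias, Linear};
use crate::quantized_var_builder::VarBuilder;
use candle::Result;

struct Projections {
    q_proj: Linear,
    k_proj: Linear,
    input_layernorm: candle_nn::LayerNorm,
}

impl Projections {
    fn new(hidden: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
        Ok(Self {
            // Each helper loads quantized weights found under the given prefix.
            q_proj: linear_no_bias(hidden, hidden, vb.pp("q_proj"))?,
            k_proj: linear_no_bias(hidden, hidden, vb.pp("k_proj"))?,
            input_layernorm: layer_norm(hidden, eps, vb.pp("input_layernorm"))?,
        })
    }
}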
@@ -8,28 +7,6 @@ use std::sync::Arc;
 pub use crate::models::stable_lm::Config;
 use crate::models::stable_lm::RotaryEmbedding;
 
-#[derive(Debug)]
-struct Linear {
-    weight: QMatMul,
-}
-
-impl Module for Linear {
-    fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
-        x.apply(&self.weight)
-    }
-}
-
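The Linear wrapper deleted above holds a single QMatMul, and its Module impl simply forwards through x.apply(&self.weight). Below is a self-contained sketch of that apply pattern using candle's core quantized API rather than the with_tracing wrapper used in this file; the shapes, dtype, and the QTensor::quantize / QMatMul::from_qtensor calls reflect the current candle API and are illustrative only.

use candle::quantized::{GgmlDType, QMatMul, QTensor};
use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Hypothetical 64x64 weight; Q4_0 needs the last dim divisible by 32.
    let weight = Tensor::randn(0f32, 1.0, (64, 64), &dev)?;
    let qweight = QTensor::quantize(&weight, GgmlDType::Q4_0)?;
    let qmatmul = QMatMul::from_qtensor(qweight)?;
    // Same shape rules as a regular candle Linear: x @ w^T with
    // on-the-fly dequantization.
    let x = Tensor::randn(0f32, 1.0, (1, 64), &dev)?;
    let y = x.apply(&qmatmul)?; // what the deleted forward() amounts to
    println!("{:?}", y.shape()); // [1, 64]
    Ok(())
}

The hunk continues below with the two free-function constructors that also moved to the shared module.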
-fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
-    let weight = QMatMul::new(in_dim, out_dim, vb)?;
-    Ok(Linear { weight })
-}
-
-fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<candle_nn::LayerNorm> {
-    let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
-    let bias = vb.get(size, "bias")?.dequantize(vb.device())?;
-    Ok(candle_nn::LayerNorm::new(weight, bias, eps))
-}
-
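The deleted layer_norm helper highlights a detail of quantized models: the layer-norm parameters are stored quantized in the checkpoint, but candle_nn::LayerNorm computes in floating point, so both weight and bias are dequantized back to plain tensors before the layer is built. A minimal quantize/dequantize round trip in core candle (the 32-element size and Q8_0 dtype are illustrative):

use candle::quantized::{GgmlDType, QTensor};
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Pretend this is a stored layer-norm scale; Q8_0 packs 32-element blocks.
    let w = Tensor::ones(32, DType::F32, &dev)?;
    let qw = QTensor::quantize(&w, GgmlDType::Q8_0)?;
    // Back to a plain f32 tensor, mirroring vb.get(..)?.dequantize(..) above.
    let w_back = qw.dequantize(&dev)?;
    let b = Tensor::zeros(32, DType::F32, &dev)?;
    let ln = candle_nn::LayerNorm::new(w_back, b, 1e-5);
    let x = Tensor::randn(0f32, 1.0, (1, 32), &dev)?;
    let _y = x.apply(&ln)?; // LayerNorm itself runs entirely in float
    Ok(())
}

The remaining lines of the hunk are unchanged context.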
 #[derive(Debug)]
 #[allow(clippy::upper_case_acronyms)]
 struct MLP {
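For reference, the destination of the move: crate::quantized_nn presumably now exposes roughly the code deleted above. This sketch is reconstructed from the deleted lines plus the names in the new import; the real module also provides Embedding (re-exported in the import hunk) and may differ in details such as an optional bias on Linear.

// Hedged sketch of the shared module, reconstructed from the code removed above.
use crate::models::with_tracing::QMatMul;
use crate::quantized_var_builder::VarBuilder;
use candle::{Module, Result, Tensor};

#[derive(Debug)]
pub struct Linear {
    weight: QMatMul,
}

impl Module for Linear {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Delegate straight to the quantized matmul.
        x.apply(&self.weight)
    }
}

pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
    let weight = QMatMul::new(in_dim, out_dim, vb)?;
    Ok(Linear { weight })
}

pub fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<candle_nn::LayerNorm> {
    // Layer-norm parameters are dequantized to float before use.
    let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
    let bias = vb.get(size, "bias")?.dequantize(vb.device())?;
    Ok(candle_nn::LayerNorm::new(weight, bias, eps))
}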