mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00

* Add a quantized blip model.
* Integrate the quantized blip model into the actual example.
99 lines · 2.8 KiB · Rust
use crate::models::with_tracing::QMatMul;
|
|
use crate::quantized_var_builder::VarBuilder;
|
|
use candle::{Module, Result, Tensor};
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct Embedding {
|
|
inner: candle_nn::Embedding,
|
|
span: tracing::Span,
|
|
}
|
|
|
|
impl Embedding {
|
|
pub fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
|
|
let embeddings = vb.get((d1, d2), "weight")?.dequantize(vb.device())?;
|
|
let inner = candle_nn::Embedding::new(embeddings, d2);
|
|
let span = tracing::span!(tracing::Level::TRACE, "embedding");
|
|
Ok(Self { inner, span })
|
|
}
|
|
|
|
pub fn embeddings(&self) -> &Tensor {
|
|
self.inner.embeddings()
|
|
}
|
|
}
|
|
|
|
impl Module for Embedding {
|
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
|
let _enter = self.span.enter();
|
|
self.inner.forward(xs)
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct Linear {
|
|
weight: QMatMul,
|
|
bias: Option<Tensor>,
|
|
}
|
|
|
|
impl Linear {
|
|
pub fn from_weights(weight: QMatMul, bias: Option<Tensor>) -> Self {
|
|
Self { weight, bias }
|
|
}
|
|
}
|
|
|
|
impl Module for Linear {
|
|
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
|
|
let x = x.apply(&self.weight)?;
|
|
match &self.bias {
|
|
None => Ok(x),
|
|
Some(bias) => x.broadcast_add(bias),
|
|
}
|
|
}
|
|
}
|
|
|
|
pub fn linear(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
|
|
let bias = vb.get(out_dim, "bias")?.dequantize(vb.device())?;
|
|
let weight = QMatMul::new(in_dim, out_dim, vb)?;
|
|
Ok(Linear {
|
|
weight,
|
|
bias: Some(bias),
|
|
})
|
|
}
|
|
|
|
pub fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<candle_nn::LayerNorm> {
|
|
let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
|
|
let bias = vb.get(size, "bias")?.dequantize(vb.device())?;
|
|
Ok(candle_nn::LayerNorm::new(weight, bias, eps))
|
|
}
|
|
|
|
pub fn layer_norm_no_bias(size: usize, eps: f64, vb: VarBuilder) -> Result<candle_nn::LayerNorm> {
|
|
let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
|
|
Ok(candle_nn::LayerNorm::new_no_bias(weight, eps))
|
|
}
|
|
|
|
pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
|
|
let weight = QMatMul::new(in_dim, out_dim, vb)?;
|
|
Ok(Linear { weight, bias: None })
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct RmsNorm {
|
|
inner: candle_nn::RmsNorm,
|
|
span: tracing::Span,
|
|
}
|
|
|
|
impl RmsNorm {
|
|
pub fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
|
|
let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
|
|
let weight = vb.get(size, "weight")?.dequantize(vb.device())?;
|
|
let inner = candle_nn::RmsNorm::new(weight, eps);
|
|
Ok(Self { inner, span })
|
|
}
|
|
}
|
|
|
|
impl Module for RmsNorm {
|
|
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
|
let _enter = self.span.enter();
|
|
self.inner.forward(x)
|
|
}
|
|
}
|