Add the quantized mixformer model. (#953)

* Add the quantized mixformer model.

* Add the quantized option in the phi example.
This commit is contained in:
Laurent Mazare
2023-09-24 15:03:48 +01:00
committed by GitHub
parent e15862cfdb
commit 0007ae9c11
6 changed files with 418 additions and 48 deletions

View File

@ -76,3 +76,35 @@ pub fn conv2d(
let inner = candle_nn::conv2d(in_channels, out_channels, kernel_size, cfg, vs)?;
Ok(Conv2d { inner, span })
}
/// QMatMul wrapper adding some tracing.
///
/// Holds a `candle::quantized::QMatMul` together with the tracing span that is
/// entered whenever the wrapped operation is run through `Module::forward`.
pub struct QMatMul {
    /// The wrapped quantized matmul that performs the actual computation.
    inner: candle::quantized::QMatMul,
    /// Span entered for the duration of each forward call.
    span: tracing::Span,
}
impl QMatMul {
    /// Builds a traced quantized matmul from the tensor stored as `"weight"` in `vb`.
    ///
    /// The weight is requested with shape `(in_dim, out_dim)`. Note that the
    /// output dimension comes first in the argument list.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `vb.get` when the weight cannot be retrieved.
    pub fn new(
        out_dim: usize,
        in_dim: usize,
        vb: crate::quantized_var_builder::VarBuilder,
    ) -> Result<Self> {
        let weight_shape = (in_dim, out_dim);
        let weight = vb.get(weight_shape, "weight")?;
        // Fields are evaluated in the order written: the inner op is built
        // first, the span is only created once the weight lookup succeeded.
        Ok(Self {
            inner: candle::quantized::QMatMul::from_arc(weight),
            span: tracing::span!(tracing::Level::TRACE, "qmatmul"),
        })
    }
}
impl Module for QMatMul {
    /// Runs the wrapped quantized matmul on `xs` while the `qmatmul` span is entered.
    ///
    /// The result (or error) of the inner `forward` call is returned unchanged.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // `in_scope` enters the span, runs the closure and exits afterwards.
        self.span.in_scope(|| self.inner.forward(xs))
    }
}
impl std::fmt::Debug for QMatMul {
    /// Formats the value as the bare type name; the wrapped fields are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("QMatMul")
    }
}