mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Add the quantized mixformer model. (#953)
* Add the quantized mixformer model.
* Add the quantized option in the phi example.
This commit is contained in:
@ -76,3 +76,35 @@ pub fn conv2d(
|
||||
let inner = candle_nn::conv2d(in_channels, out_channels, kernel_size, cfg, vs)?;
|
||||
Ok(Conv2d { inner, span })
|
||||
}
|
||||
|
||||
// QMatMul wrapper adding some tracing.
|
||||
pub struct QMatMul {
|
||||
inner: candle::quantized::QMatMul,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl QMatMul {
|
||||
pub fn new(
|
||||
out_dim: usize,
|
||||
in_dim: usize,
|
||||
vb: crate::quantized_var_builder::VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let ws = vb.get((in_dim, out_dim), "weight")?;
|
||||
let inner = candle::quantized::QMatMul::from_arc(ws);
|
||||
let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
|
||||
Ok(Self { inner, span })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for QMatMul {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
self.inner.forward(xs)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for QMatMul {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "QMatMul")
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user