Expose some helper functions to create quantized models. (#1837)

This commit is contained in:
Laurent Mazare
2024-03-12 11:30:24 +01:00
committed by GitHub
parent df5f69444e
commit ff03fd3fb3
3 changed files with 15 additions and 0 deletions

View File

@ -116,6 +116,12 @@ impl QMatMul {
let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
Ok(Self { inner, span })
}
/// Builds a traced `QMatMul` from a shared quantized weight tensor.
///
/// The tensor is handed to `candle::quantized::QMatMul::from_arc`, and the
/// result is stored together with a TRACE-level span named `"qmatmul"`.
///
/// # Errors
///
/// Returns the error produced by `candle::quantized::QMatMul::from_arc`.
pub fn from_weights(ws: std::sync::Arc<candle::quantized::QTensor>) -> Result<Self> {
    Ok(Self {
        // Evaluated first: on failure we return before the span is created.
        inner: candle::quantized::QMatMul::from_arc(ws)?,
        span: tracing::span!(tracing::Level::TRACE, "qmatmul"),
    })
}
}
impl Module for QMatMul {

View File

@ -35,6 +35,14 @@ pub struct Linear {
}
impl Linear {
/// Creates a `Linear` layer from a shared quantized weight tensor and an
/// optional bias.
///
/// The weight is converted with `QMatMul::from_weights`; the bias is stored
/// as given.
///
/// # Errors
///
/// Returns the error produced by `QMatMul::from_weights`.
pub fn from_arc(
    weight: std::sync::Arc<candle::quantized::QTensor>,
    bias: Option<Tensor>,
) -> Result<Self> {
    Ok(Self {
        weight: QMatMul::from_weights(weight)?,
        bias,
    })
}
pub fn from_weights(weight: QMatMul, bias: Option<Tensor>) -> Self {
Self { weight, bias }
}

View File

@ -3,6 +3,7 @@ use candle::{Device, Result, Shape};
use std::sync::Arc;
// VarBuilder specialized for QTensors
#[derive(Clone)]
pub struct VarBuilder {
data: Arc<std::collections::HashMap<String, Arc<QTensor>>>,
path: Vec<String>,