mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Expose some helper functions to create quantized models. (#1837)
This commit is contained in:
@ -116,6 +116,12 @@ impl QMatMul {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
|
||||
Ok(Self { inner, span })
|
||||
}
|
||||
|
||||
pub fn from_weights(ws: std::sync::Arc<candle::quantized::QTensor>) -> Result<Self> {
|
||||
let inner = candle::quantized::QMatMul::from_arc(ws)?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
|
||||
Ok(Self { inner, span })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for QMatMul {
|
||||
|
@ -35,6 +35,14 @@ pub struct Linear {
|
||||
}
|
||||
|
||||
impl Linear {
|
||||
pub fn from_arc(
|
||||
weight: std::sync::Arc<candle::quantized::QTensor>,
|
||||
bias: Option<Tensor>,
|
||||
) -> Result<Self> {
|
||||
let weight = QMatMul::from_weights(weight)?;
|
||||
Ok(Self { weight, bias })
|
||||
}
|
||||
|
||||
pub fn from_weights(weight: QMatMul, bias: Option<Tensor>) -> Self {
|
||||
Self { weight, bias }
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ use candle::{Device, Result, Shape};
|
||||
use std::sync::Arc;
|
||||
|
||||
// VarBuilder specialized for QTensors
|
||||
#[derive(Clone)]
|
||||
pub struct VarBuilder {
|
||||
data: Arc<std::collections::HashMap<String, Arc<QTensor>>>,
|
||||
path: Vec<String>,
|
||||
|
Reference in New Issue
Block a user