mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
Expose some helper functions to create quantized models. (#1837)
This commit is contained in:
@ -116,6 +116,12 @@ impl QMatMul {
|
|||||||
let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
|
let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
|
||||||
Ok(Self { inner, span })
|
Ok(Self { inner, span })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn from_weights(ws: std::sync::Arc<candle::quantized::QTensor>) -> Result<Self> {
|
||||||
|
let inner = candle::quantized::QMatMul::from_arc(ws)?;
|
||||||
|
let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
|
||||||
|
Ok(Self { inner, span })
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Module for QMatMul {
|
impl Module for QMatMul {
|
||||||
|
@ -35,6 +35,14 @@ pub struct Linear {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Linear {
|
impl Linear {
|
||||||
|
pub fn from_arc(
|
||||||
|
weight: std::sync::Arc<candle::quantized::QTensor>,
|
||||||
|
bias: Option<Tensor>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let weight = QMatMul::from_weights(weight)?;
|
||||||
|
Ok(Self { weight, bias })
|
||||||
|
}
|
||||||
|
|
||||||
pub fn from_weights(weight: QMatMul, bias: Option<Tensor>) -> Self {
|
pub fn from_weights(weight: QMatMul, bias: Option<Tensor>) -> Self {
|
||||||
Self { weight, bias }
|
Self { weight, bias }
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,7 @@ use candle::{Device, Result, Shape};
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
// VarBuilder specialized for QTensors
|
// VarBuilder specialized for QTensors
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct VarBuilder {
|
pub struct VarBuilder {
|
||||||
data: Arc<std::collections::HashMap<String, Arc<QTensor>>>,
|
data: Arc<std::collections::HashMap<String, Arc<QTensor>>>,
|
||||||
path: Vec<String>,
|
path: Vec<String>,
|
||||||
|
Reference in New Issue
Block a user