huggingface/candle (https://github.com/huggingface/candle.git)
Add vecdot for q6k-q8k. (#476)
* Add vecdot for q6k-q8k.
* Add some testing for q8k.
* Use QMatMul for the output layer.
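
For context, a q6k-q8k vecdot is a block-wise integer dot product: each block contributes an integer sum of quantized products, scaled by the two blocks' per-block scales. Below is a minimal, self-contained sketch of that pattern; the block layout is deliberately simplified (one f32 scale over 8 i8 values), whereas the real Q6K/Q8K formats pack 256-element super-blocks with 6-bit/8-bit quants and sub-block scales:

struct Block {
    scale: f32,
    qs: [i8; 8],
}

fn vec_dot(a: &[Block], b: &[Block]) -> f32 {
    a.iter()
        .zip(b)
        .map(|(x, y)| {
            // Stay in integer arithmetic inside the block; this inner loop
            // is the part the real SIMD kernels accelerate.
            let isum: i32 = x
                .qs
                .iter()
                .zip(y.qs.iter())
                .map(|(&p, &q)| p as i32 * q as i32)
                .sum();
            x.scale * y.scale * isum as f32
        })
        .sum()
}

fn main() {
    let a = vec![Block { scale: 0.5, qs: [1, 2, 3, 4, 5, 6, 7, 8] }];
    let b = vec![Block { scale: 0.25, qs: [8, 7, 6, 5, 4, 3, 2, 1] }];
    assert_eq!(vec_dot(&a, &b), 15.0); // 0.5 * 0.25 * 120
}
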
@@ -155,8 +155,7 @@ struct ModelWeights {
     tok_embeddings: Embedding,
     layers: Vec<LayerWeights>,
     norm: RmsNorm,
-    // TODO: Switch to using QMatMul instead of linear once we have support for Q6K/Q8K.
-    output: candle_nn::Linear,
+    output: QMatMul,
     masks: HashMap<usize, Tensor>,
     span: tracing::Span,
     span_output: tracing::Span,
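
The field swap above is what saves memory: the `candle_nn::Linear` path held a dequantized f32 copy of the output matrix, while `QMatMul` keeps the original Q6K blocks (210 bytes per 256-weight super-block, roughly 6.56 bits per weight). A back-of-envelope comparison, using hypothetical 7B-style shapes (32000 x 4096; not taken from this diff):

fn main() {
    let (vocab, n_embd) = (32_000u64, 4_096u64);
    let params = vocab * n_embd;
    let f32_bytes = params * 4; // what the dequantized Linear weight held
    let q6k_bytes = params / 256 * 210; // what the QTensor behind QMatMul holds
    println!("f32 output head: {} MiB", f32_bytes >> 20); // 500 MiB
    println!("q6k output head: {} MiB", q6k_bytes >> 20); // ~102 MiB
}
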
@@ -197,7 +196,6 @@ impl ModelWeights {
         let tok_embeddings = tok_embeddings.dequantize(cpu)?;
         let norm = RmsNorm::new(ct.remove("norm.weight")?)?;
         let output = ct.remove("output.weight")?;
-        let output = candle_nn::Linear::new(output.dequantize(cpu)?, None);
         let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
         for layer_idx in 0..ct.hparams.n_layer {
             let prefix = format!("layers.{layer_idx}");
@@ -239,7 +237,7 @@ impl ModelWeights {
             tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize),
             layers,
             norm,
-            output,
+            output: QMatMul::from_qtensor(output),
             masks: HashMap::new(),
             span,
             span_output,
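
The commit message also mentions new q8k tests. A toy, self-contained version of the idea behind such a quantization roundtrip test is sketched below; this is not the test added in the PR (Q8K proper uses one scale per 256-element super-block), just the same check in miniature: quantize f32 values to i8 with a single scale, dequantize, and verify the worst-case error stays within half a quantization step:

fn main() {
    let src: Vec<f32> = (0..256).map(|i| (i as f32 - 128.0) / 10.0).collect();
    // Scale so the largest magnitude maps to the i8 extreme of 127.
    let amax = src.iter().fold(0f32, |m, &v| m.max(v.abs()));
    let scale = amax / 127.0;
    let q: Vec<i8> = src.iter().map(|&v| (v / scale).round() as i8).collect();
    let back: Vec<f32> = q.iter().map(|&v| v as f32 * scale).collect();
    let max_err = src
        .iter()
        .zip(&back)
        .fold(0f32, |m, (a, b)| m.max((a - b).abs()));
    // Rounding to the nearest level bounds the error by half a step.
    assert!(max_err <= scale / 2.0 + 1e-6, "max_err={max_err}");
    println!("max roundtrip error: {max_err}");
}
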