Add vecdot for q6k-q8k. (#476)

* Add vecdot for q6k-q8k.

* Add some testing for q8k.

* Use QMatMul for the output layer.
This commit is contained in:
Laurent Mazare
2023-08-16 20:59:40 +01:00
committed by GitHub
parent 3bedba1fce
commit 098909de40
3 changed files with 80 additions and 6 deletions

View File

@ -155,8 +155,7 @@ struct ModelWeights {
tok_embeddings: Embedding,
layers: Vec<LayerWeights>,
norm: RmsNorm,
// TODO: Switch to using QMatMul instead of linear once we have support for Q6K/Q8K.
output: candle_nn::Linear,
output: QMatMul,
masks: HashMap<usize, Tensor>,
span: tracing::Span,
span_output: tracing::Span,
@ -197,7 +196,6 @@ impl ModelWeights {
let tok_embeddings = tok_embeddings.dequantize(cpu)?;
let norm = RmsNorm::new(ct.remove("norm.weight")?)?;
let output = ct.remove("output.weight")?;
let output = candle_nn::Linear::new(output.dequantize(cpu)?, None);
let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
for layer_idx in 0..ct.hparams.n_layer {
let prefix = format!("layers.{layer_idx}");
@ -239,7 +237,7 @@ impl ModelWeights {
tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize),
layers,
norm,
output,
output: QMatMul::from_qtensor(output),
masks: HashMap::new(),
span,
span_output,