diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs index 61fabc63..94e6bd23 100644 --- a/candle-core/src/quantized/mod.rs +++ b/candle-core/src/quantized/mod.rs @@ -232,19 +232,25 @@ impl QTensor { } #[derive(Clone, Debug)] -pub struct QMatMul(std::sync::Arc); +pub enum QMatMul { + QTensor(std::sync::Arc), + Tensor(Tensor), +} impl QMatMul { - pub fn from_arc(qtensor: std::sync::Arc) -> Self { - Self(qtensor) + pub fn from_arc(qtensor: std::sync::Arc) -> Result { + let t = match qtensor.dtype() { + GgmlDType::F32 | GgmlDType::F16 => { + let tensor = qtensor.dequantize(&Device::Cpu)?; + Self::Tensor(tensor) + } + _ => Self::QTensor(qtensor), + }; + Ok(t) } - pub fn from_qtensor(qtensor: QTensor) -> Self { - Self(std::sync::Arc::new(qtensor)) - } - - pub fn inner(&self) -> &std::sync::Arc { - &self.0 + pub fn from_qtensor(qtensor: QTensor) -> Result { + Self::from_arc(std::sync::Arc::new(qtensor)) } } @@ -289,6 +295,9 @@ impl crate::CustomOp1 for QTensor { impl QMatMul { pub fn forward(&self, xs: &Tensor) -> Result { - xs.apply_op1_no_bwd(self.0.as_ref()) + match self { + Self::QTensor(t) => xs.apply_op1_no_bwd(t.as_ref()), + Self::Tensor(t) => xs.matmul(&t.t()?), + } } } diff --git a/candle-core/tests/quantized_tests.rs b/candle-core/tests/quantized_tests.rs index f21d7767..06f1ee47 100644 --- a/candle-core/tests/quantized_tests.rs +++ b/candle-core/tests/quantized_tests.rs @@ -43,7 +43,7 @@ fn quantized_matmul() -> Result<()> { ); let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?; - let matmul = quantized::QMatMul::from_qtensor(qtensor); + let matmul = quantized::QMatMul::from_qtensor(qtensor)?; let res = matmul.forward(&tensor_lhs)?; assert_eq!( to_vec2_round(&res, 0)?, @@ -91,7 +91,7 @@ fn quantized_matmul_neg() -> Result<()> { ); let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?; - let matmul = quantized::QMatMul::from_qtensor(qtensor); + let matmul = quantized::QMatMul::from_qtensor(qtensor)?; let res = matmul.forward(&tensor_lhs)?; assert_eq!( to_vec2_round(&res, 0)?, @@ -576,7 +576,7 @@ fn quantized_matmul_q2k() -> Result<()> { assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); let rhs = quantized::QTensor::quantize::(&rhs)?; - let rhs = quantized::QMatMul::from_qtensor(rhs); + let rhs = quantized::QMatMul::from_qtensor(rhs)?; let mm = rhs.forward(&lhs)?; assert_eq!(mm.dims(), [m, n]); @@ -602,7 +602,7 @@ fn quantized_matmul_q3k() -> Result<()> { assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); let rhs = quantized::QTensor::quantize::(&rhs)?; - let rhs = quantized::QMatMul::from_qtensor(rhs); + let rhs = quantized::QMatMul::from_qtensor(rhs)?; let mm = rhs.forward(&lhs)?; assert_eq!(mm.dims(), [m, n]); @@ -628,7 +628,7 @@ fn quantized_matmul_q4k() -> Result<()> { assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); let rhs = quantized::QTensor::quantize::(&rhs)?; - let rhs = quantized::QMatMul::from_qtensor(rhs); + let rhs = quantized::QMatMul::from_qtensor(rhs)?; let mm = rhs.forward(&lhs)?; assert_eq!(mm.dims(), [m, n]); @@ -654,7 +654,7 @@ fn quantized_matmul_q5k() -> Result<()> { assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); let rhs = quantized::QTensor::quantize::(&rhs)?; - let rhs = quantized::QMatMul::from_qtensor(rhs); + let rhs = quantized::QMatMul::from_qtensor(rhs)?; let mm = rhs.forward(&lhs)?; assert_eq!(mm.dims(), [m, n]); @@ -681,7 +681,7 @@ fn quantized_matmul_q6k() -> Result<()> { assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]); let rhs = quantized::QTensor::quantize::(&rhs)?; - let rhs = quantized::QMatMul::from_qtensor(rhs); + let rhs = quantized::QMatMul::from_qtensor(rhs)?; let mm = rhs.forward(&lhs)?; assert_eq!(mm.dims(), [m, n]); diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index 5249ed34..5d4b624e 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -484,17 +484,25 @@ fn main() -> Result<()> { println!("No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav"); dataset.get("samples_jfk.wav")? }; - let config = if args.quantized { - repo.get("config-tiny.json")? + let (config, tokenizer, model) = if args.quantized { + let ext = match args.model { + WhichModel::TinyEn => "tiny-en", + WhichModel::Tiny => "tiny", + _ => unimplemented!("no quantized support for {:?}", args.model), + }; + ( + repo.get(&format!("config-{ext}.json"))?, + repo.get(&format!("tokenizer-{ext}.json"))?, + repo.get(&format!("model-{ext}-q40.gguf"))?, + ) } else { - repo.get("config.json")? + ( + repo.get("config.json")?, + repo.get("tokenizer.json")?, + repo.get("model.safetensors")?, + ) }; - let model = if args.quantized { - repo.get("model-tiny-q40.gguf")? - } else { - repo.get("model.safetensors")? - }; - (config, repo.get("tokenizer.json")?, model, sample) + (config, tokenizer, model, sample) }; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; diff --git a/candle-nn/examples/cpu_benchmarks.rs b/candle-nn/examples/cpu_benchmarks.rs index 204a7109..e58ea727 100644 --- a/candle-nn/examples/cpu_benchmarks.rs +++ b/candle-nn/examples/cpu_benchmarks.rs @@ -206,7 +206,7 @@ impl Benchmark for QMatMul { fn preprocess() -> Result { let zeros = vec![candle::quantized::k_quants::BlockQ4_0::zeros(); 4096 * 11008 / 32]; let mm = candle::quantized::QTensor::new(zeros, (4096, 11008))?; - let mm = candle::quantized::QMatMul::from_qtensor(mm); + let mm = candle::quantized::QMatMul::from_qtensor(mm)?; let arg = Tensor::randn(0f32, 1., (128, 11008), &Device::Cpu)?; Ok((mm, arg)) } diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 55b7a888..64b6dd2c 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -867,7 +867,7 @@ impl PyQTensor { /// Performs a quantized matrix multiplication, with the quantized tensor as the right hand side. /// &RETURNS&: Tensor fn matmul_t(&self, lhs: &PyTensor) -> PyResult { - let qmatmul = ::candle::quantized::QMatMul::from_arc(self.0.clone()); + let qmatmul = ::candle::quantized::QMatMul::from_arc(self.0.clone()).map_err(wrap_err)?; let res = qmatmul.forward(lhs).map_err(wrap_err)?; Ok(PyTensor(res)) } diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs index 2988b0fb..8ac1d460 100644 --- a/candle-transformers/src/models/quantized_llama.rs +++ b/candle-transformers/src/models/quantized_llama.rs @@ -33,10 +33,10 @@ struct QMatMul { } impl QMatMul { - fn from_qtensor(qtensor: QTensor) -> Self { - let inner = candle::quantized::QMatMul::from_qtensor(qtensor); + fn from_qtensor(qtensor: QTensor) -> Result { + let inner = candle::quantized::QMatMul::from_qtensor(qtensor)?; let span = tracing::span!(tracing::Level::TRACE, "qmatmul"); - Self { inner, span } + Ok(Self { inner, span }) } fn forward(&self, xs: &Tensor) -> Result { @@ -217,14 +217,14 @@ impl ModelWeights { let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp"); layers.push(LayerWeights { - attention_wq: QMatMul::from_qtensor(attention_wq), - attention_wk: QMatMul::from_qtensor(attention_wk), - attention_wv: QMatMul::from_qtensor(attention_wv), - attention_wo: QMatMul::from_qtensor(attention_wo), + attention_wq: QMatMul::from_qtensor(attention_wq)?, + attention_wk: QMatMul::from_qtensor(attention_wk)?, + attention_wv: QMatMul::from_qtensor(attention_wv)?, + attention_wo: QMatMul::from_qtensor(attention_wo)?, attention_norm: RmsNorm::new(attention_norm, 1e-5)?, - feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1), - feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2), - feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3), + feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?, + feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?, + feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?, ffn_norm: RmsNorm::new(ffn_norm, 1e-5)?, n_head: ct.hparams.n_head as usize, n_kv_head: ct.hparams.n_head as usize / gqa, @@ -243,7 +243,7 @@ impl ModelWeights { tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize), layers, norm, - output: QMatMul::from_qtensor(output), + output: QMatMul::from_qtensor(output)?, masks: HashMap::new(), span, span_output, @@ -294,14 +294,14 @@ impl ModelWeights { let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp"); layers.push(LayerWeights { - attention_wq: QMatMul::from_qtensor(attention_wq), - attention_wk: QMatMul::from_qtensor(attention_wk), - attention_wv: QMatMul::from_qtensor(attention_wv), - attention_wo: QMatMul::from_qtensor(attention_wo), + attention_wq: QMatMul::from_qtensor(attention_wq)?, + attention_wk: QMatMul::from_qtensor(attention_wk)?, + attention_wv: QMatMul::from_qtensor(attention_wv)?, + attention_wo: QMatMul::from_qtensor(attention_wo)?, attention_norm: RmsNorm::new(attention_norm, rms_norm_eps)?, - feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1), - feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2), - feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3), + feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?, + feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?, + feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?, ffn_norm: RmsNorm::new(ffn_norm, rms_norm_eps)?, n_head: head_count, n_kv_head: head_count_kv, @@ -320,7 +320,7 @@ impl ModelWeights { tok_embeddings: Embedding::new(tok_embeddings, embedding_length), layers, norm, - output: QMatMul::from_qtensor(output), + output: QMatMul::from_qtensor(output)?, masks: HashMap::new(), span, span_output, diff --git a/candle-transformers/src/models/with_tracing.rs b/candle-transformers/src/models/with_tracing.rs index 6a6c69e7..09a243ac 100644 --- a/candle-transformers/src/models/with_tracing.rs +++ b/candle-transformers/src/models/with_tracing.rs @@ -90,7 +90,7 @@ impl QMatMul { vb: crate::quantized_var_builder::VarBuilder, ) -> Result { let ws = vb.get((in_dim, out_dim), "weight")?; - let inner = candle::quantized::QMatMul::from_arc(ws); + let inner = candle::quantized::QMatMul::from_arc(ws)?; let span = tracing::span!(tracing::Level::TRACE, "qmatmul"); Ok(Self { inner, span }) } diff --git a/candle-wasm-tests/tests/quantized_tests.rs b/candle-wasm-tests/tests/quantized_tests.rs index 84ca05bc..8beeab60 100644 --- a/candle-wasm-tests/tests/quantized_tests.rs +++ b/candle-wasm-tests/tests/quantized_tests.rs @@ -41,7 +41,7 @@ fn quantized_matmul_neg() -> Result<()> { ); let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?; - let matmul = quantized::QMatMul::from_qtensor(qtensor); + let matmul = quantized::QMatMul::from_qtensor(qtensor)?; let res = matmul.forward(&tensor_lhs)?; assert_eq!( to_vec2_round(&res, 0)?,