From f7b2a0391d4a96607b5f208164e365b50ad0bbf7 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Fri, 4 Aug 2023 13:32:20 +0100
Subject: [PATCH] Transpose the weight matrixes for llama2.c. (#321)

---
 candle-core/src/device.rs                    |  7 ++++++
 candle-examples/examples/llama2-c/weights.rs | 23 +++++++++++++-------
 2 files changed, 22 insertions(+), 8 deletions(-)

diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs
index 563d892b..65232839 100644
--- a/candle-core/src/device.rs
+++ b/candle-core/src/device.rs
@@ -101,6 +101,13 @@ impl Device {
         }
     }
 
+    pub fn is_cpu(&self) -> bool {
+        match self {
+            Self::Cpu => true,
+            Self::Cuda(_) => false,
+        }
+    }
+
     pub fn is_cuda(&self) -> bool {
         match self {
             Self::Cpu => false,
diff --git a/candle-examples/examples/llama2-c/weights.rs b/candle-examples/examples/llama2-c/weights.rs
index 2daed057..b78418ce 100644
--- a/candle-examples/examples/llama2-c/weights.rs
+++ b/candle-examples/examples/llama2-c/weights.rs
@@ -105,6 +105,13 @@ impl TransformerWeights {
     }
 
     pub fn var_builder(&self, cfg: &Config, device: &Device) -> Result<VarBuilder<'static>> {
+        // TODO: As of 2023-08-04, gemm is slower than expected when multiplying a matrix of
+        // size (1, k) with the transpose of a matrix of size (k, n) as it ends up transposing the
+        // second matrix back. We detect this case here and as a temporary hack make the weight
+        // matrix column major rather than row major. This ends up speeding up text generation from
+        // 120 token/s to 220 token/s on a Ryzen 2600X.
+        let tr = device.is_cpu() && !candle::utils::has_mkl();
+        let tr = |x: Tensor| if tr { x.t()?.contiguous()?.t() } else { Ok(x) };
         let mut ws = std::collections::HashMap::new();
         let mut insert = |name: &str, t: Tensor| {
             ws.insert(name.to_string(), t);
@@ -115,36 +122,36 @@ impl TransformerWeights {
             "model.embed_tokens.weight",
             self.token_embedding_table.clone(),
         );
-        insert("lm_head.weight", self.token_embedding_table.clone());
+        insert("lm_head.weight", tr(self.token_embedding_table.clone())?);
         insert("model.norm.weight", self.rms_final_weight.clone());
         for layer in 0..cfg.n_layers {
             ws.insert(
                 format!("model.layers.{layer}.self_attn.q_proj.weight"),
-                self.wq.i(layer)?,
+                tr(self.wq.i(layer)?)?,
             );
             ws.insert(
                 format!("model.layers.{layer}.self_attn.k_proj.weight"),
-                self.wk.i(layer)?,
+                tr(self.wk.i(layer)?)?,
             );
             ws.insert(
                 format!("model.layers.{layer}.self_attn.v_proj.weight"),
-                self.wv.i(layer)?,
+                tr(self.wv.i(layer)?)?,
             );
             ws.insert(
                 format!("model.layers.{layer}.self_attn.o_proj.weight"),
-                self.wo.i(layer)?,
+                tr(self.wo.i(layer)?)?,
             );
             ws.insert(
                 format!("model.layers.{layer}.mlp.gate_proj.weight"),
-                self.w1.i(layer)?,
+                tr(self.w1.i(layer)?)?,
             );
             ws.insert(
                 format!("model.layers.{layer}.mlp.down_proj.weight"),
-                self.w2.i(layer)?,
+                tr(self.w2.i(layer)?)?,
             );
             ws.insert(
                 format!("model.layers.{layer}.mlp.up_proj.weight"),
-                self.w3.i(layer)?,
+                tr(self.w3.i(layer)?)?,
             );
             ws.insert(
                 format!("model.layers.{layer}.input_layernorm.weight"),
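
Side note on the TODO in the patch: `Tensor::t` in candle only swaps strides, so `x.t()?.contiguous()?.t()` keeps the logical weight shape while laying the data out column major; the transpose taken later inside the linear layers then presumably sees a contiguous operand and gemm no longer needs to copy it back. Below is a minimal standalone sketch of that trick, assuming candle-core is imported as `candle` (as in the examples crate) and using arbitrary `k`/`n` sizes that merely stand in for the model dimensions.

use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let (k, n) = (64, 128); // illustrative sizes, not the real hidden dims

    // A (1, k) activation row and an (n, k) weight, as the checkpoint stores it.
    let x = Tensor::zeros((1, k), DType::F32, &device)?;
    let w = Tensor::zeros((n, k), DType::F32, &device)?;

    // The projection is x @ w^T. `t()` only swaps strides, so `w.t()` is a
    // non-contiguous (k, n) view; this is the slow path the TODO describes.
    let y_slow = x.matmul(&w.t()?)?;

    // The hack from the patch: round-trip through `contiguous()` so the data
    // becomes column major while the logical shape stays (n, k).
    let w_col_major = w.t()?.contiguous()?.t()?;
    // Now `w_col_major.t()` is a contiguous (k, n) matrix that gemm can use as is.
    let y_fast = x.matmul(&w_col_major.t()?)?;

    assert_eq!(y_slow.dims(), y_fast.dims()); // both results are (1, n)
    Ok(())
}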