From 0f679fe42e7d1fb2aba7b3db9bcfb91d3c6c0026 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Thu, 6 Jul 2023 19:23:54 +0100
Subject: [PATCH] Fix some shape issues in falcon. (#95)

* Fix some shape issues.

* Use different dtypes.
---
 candle-examples/examples/falcon/main.rs  |  5 ++++-
 candle-examples/examples/falcon/model.rs | 23 +++++++++++++++++------
 2 files changed, 21 insertions(+), 7 deletions(-)

diff --git a/candle-examples/examples/falcon/main.rs b/candle-examples/examples/falcon/main.rs
index 832183ad..c66a7784 100644
--- a/candle-examples/examples/falcon/main.rs
+++ b/candle-examples/examples/falcon/main.rs
@@ -10,7 +10,10 @@ use clap::Parser;
 mod model;
 use model::{Config, Falcon, VarBuilder};
 
-const DTYPE: DType = DType::F16;
+#[cfg(feature = "mkl")]
+const DTYPE: DType = DType::F32;
+#[cfg(not(feature = "mkl"))]
+const DTYPE: DType = DType::BF16;
 
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
diff --git a/candle-examples/examples/falcon/model.rs b/candle-examples/examples/falcon/model.rs
index d7a2e8b8..efab97ca 100644
--- a/candle-examples/examples/falcon/model.rs
+++ b/candle-examples/examples/falcon/model.rs
@@ -421,13 +421,24 @@ impl FalconAttention {
         };
         let mask = masked_fill(&mask.to_dtype(DType::F32)?, mask, -1e9)?.to_dtype(query.dtype())?;
         // TODO: layer_past, use_cache?
-        let query = query.reshape((b_sz, self.num_heads, q_len, head_dim))?;
-        let key = key.reshape((b_sz, self.n_head_kv, q_len, head_dim))?;
-        let value = value.reshape((b_sz, self.n_head_kv, q_len, head_dim))?;
+        let query = query.reshape((b_sz * self.num_heads, q_len, head_dim))?;
+        let key = key.reshape((b_sz * self.n_head_kv, q_len, head_dim))?;
+        let value = value.reshape((b_sz * self.n_head_kv, q_len, head_dim))?;
+
+        let (key, value) = if self.n_head_kv == 1 {
+            (
+                key.broadcast_as(query.dims())?,
+                value.broadcast_as(query.dims())?,
+            )
+        } else {
+            (key, value)
+        };
 
         // Only handle alibi is None here, and non-flash attention.
         let attention_scores = (query.matmul(&key.t()?)? * self.inv_norm_factor)?;
-        let attention_scores = (attention_scores + mask)?.softmax(D::Minus1)?;
+        let attention_scores = attention_scores
+            .broadcast_add(&mask.squeeze(1)?)?
+            .softmax(D::Minus1)?;
         let attn_output = attention_scores
             .matmul(&value)?
             .reshape((b_sz, self.num_heads, q_len, head_dim))?
@@ -459,8 +470,8 @@ impl FalconMlp {
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = self.dense_4h_to_h.forward(x)?.gelu()?;
-        let x = self.dense_h_to_4h.forward(&x)?;
+        let x = self.dense_h_to_4h.forward(x)?.gelu()?;
+        let x = self.dense_4h_to_h.forward(&x)?;
         Ok(x)
     }
 }
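Note on the attention change above: Falcon uses multi-query attention, so when n_head_kv == 1 every
query head attends against the same key/value head; broadcasting the key to the query's dimensions
with broadcast_as is equivalent to reusing the single key slice for each head, and it lets one batched
matmul cover all heads at once. Below is a minimal sketch of that equivalence in plain Rust (no candle
dependency); the helper matmul_qk and the toy sizes are purely illustrative, not part of the patch or
the candle API.

// Illustrative only: shows that "broadcasting" a single shared key head across
// all query heads is just reusing the same key data for every head's q * k^T.
fn matmul_qk(q: &[f32], k: &[f32], q_len: usize, head_dim: usize) -> Vec<f32> {
    // q: (q_len, head_dim), k: (q_len, head_dim); returns q * k^T as (q_len, q_len).
    let mut out = vec![0f32; q_len * q_len];
    for i in 0..q_len {
        for j in 0..q_len {
            let mut acc = 0f32;
            for d in 0..head_dim {
                acc += q[i * head_dim + d] * k[j * head_dim + d];
            }
            out[i * q_len + j] = acc;
        }
    }
    out
}

fn main() {
    let (num_heads, q_len, head_dim) = (4, 3, 2);
    // Per-head queries, flattened to (num_heads, q_len, head_dim).
    let query: Vec<f32> = (0..num_heads * q_len * head_dim).map(|v| v as f32).collect();
    // A single shared key head (n_head_kv == 1), shape (q_len, head_dim).
    let key: Vec<f32> = (0..q_len * head_dim).map(|v| (v as f32) * 0.5).collect();

    // The broadcast amounts to using the same key slice for every query head.
    for h in 0..num_heads {
        let q_h = &query[h * q_len * head_dim..(h + 1) * q_len * head_dim];
        let scores = matmul_qk(q_h, &key, q_len, head_dim);
        println!("head {h}: scores = {scores:?}");
    }
}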