From 03dffe9ecc6a857d7ad449b1b0d69dc4f82c5b32 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Fri, 7 Jul 2023 17:55:21 +0100 Subject: [PATCH] Use F32 for the reduce ops. (#105) --- candle-examples/examples/falcon/model.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/candle-examples/examples/falcon/model.rs b/candle-examples/examples/falcon/model.rs index a877bd69..e7c53e50 100644 --- a/candle-examples/examples/falcon/model.rs +++ b/candle-examples/examples/falcon/model.rs @@ -122,12 +122,15 @@ impl LayerNorm { } fn forward(&self, x: &Tensor) -> Result { + let dtype = x.dtype(); let (_bsize, _seq_len, hidden_size) = x.shape().r3()?; + let x = x.to_dtype(DType::F32)?; let mean_x = (x.sum(&[2])? / hidden_size as f64)?; let x = x.broadcast_sub(&mean_x)?; let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?; let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?; let x = x_normed + .to_dtype(dtype)? .broadcast_mul(&self.weight)? .broadcast_add(&self.bias)?; Ok(x) @@ -470,7 +473,9 @@ impl FalconAttention { let attention_scores = (query.matmul(&key.t()?)? * self.inv_norm_factor)?; let attention_scores = attention_scores .broadcast_add(&mask.squeeze(1)?)? - .softmax(D::Minus1)?; + .to_dtype(DType::F32)? + .softmax(D::Minus1)? + .to_dtype(x.dtype())?; let attn_output = attention_scores .matmul(&value)? .reshape((b_sz, self.num_heads, seq_len, head_dim))?