From ed4d0959d37f365deb804706143fc06d08937e20 Mon Sep 17 00:00:00 2001
From: laurent
Date: Fri, 30 Jun 2023 15:01:39 +0100
Subject: [PATCH] Add a const to easily tweak the dtype used for llama internal
 computations.

---
 candle-core/examples/llama/main.rs | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/candle-core/examples/llama/main.rs b/candle-core/examples/llama/main.rs
index fac1e14f..72603bf7 100644
--- a/candle-core/examples/llama/main.rs
+++ b/candle-core/examples/llama/main.rs
@@ -24,6 +24,7 @@ mod var_store;
 mod weights;
 
 const MAX_SEQ_LEN: usize = 4096;
+const DTYPE: DType = DType::F16;
 const START_PROMPT: &str = r"
 EDWARD:
 I wonder how our princely father 'scaped,
@@ -138,7 +139,8 @@ impl Embedding {
     }
 
     fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
-        Ok(Tensor::embedding(indexes, &self.embeddings)?)
+        let embeddings = self.embeddings.to_dtype(DTYPE)?;
+        Ok(Tensor::embedding(indexes, &embeddings)?)
     }
 }
 
@@ -152,7 +154,8 @@ impl Linear {
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = x.matmul(&self.weight.t()?)?;
+        let weight = self.weight.to_dtype(DTYPE)?;
+        let x = x.matmul(&weight.t()?)?;
         Ok(x)
     }
 }
@@ -167,6 +170,7 @@ impl RmsNorm {
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        // This is a no-op if x's dtype is already f32.
         let x = x.to_dtype(DType::F32)?;
         let (seq_len, hidden_size) = x.shape().r2()?;
         let norm_x = ((&x * &x)?.sum(&[1])? / hidden_size as f64)?;
@@ -178,7 +182,7 @@ impl RmsNorm {
             .to_dtype(DType::F32)?
             .broadcast_as((seq_len, size))?;
         let x = (scale * x_normed)?;
-        let x = x.to_dtype(DType::F16)?;
+        let x = x.to_dtype(DTYPE)?;
         Ok(x)
     }
 }
@@ -339,7 +343,7 @@ impl CausalSelfAttention {
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = att.matmul(&v.contiguous()?)?;
         let y = y.transpose(0, 1)?.reshape(&[t, c])?;
-        let y = y.to_dtype(DType::F16)?;
+        let y = y.to_dtype(DTYPE)?;
         let y = self.c_proj.forward(&y)?;
         Ok(y)
     }
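Usage note (not part of the patch): the point of the new constant is that the internal compute precision can be changed in one place. Below is a minimal, self-contained sketch of the same pattern; the candle_core import path, the toy Linear struct, and the main function are illustrative assumptions, not code from the example.

// Sketch of the pattern the patch introduces: a single DTYPE constant that
// all internal casts go through. Adjust the import path to the candle
// version in use (older checkouts expose the crate as `candle`).
use candle_core::{DType, Device, Result, Tensor};

// Flip this one constant to run the internal matmuls at a different precision.
const DTYPE: DType = DType::F16;

struct Linear {
    weight: Tensor,
}

impl Linear {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Cast the stored weight to the compute dtype before the matmul,
        // mirroring the patched Linear::forward.
        let weight = self.weight.to_dtype(DTYPE)?;
        x.matmul(&weight.t()?)
    }
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    let layer = Linear {
        weight: Tensor::zeros((4, 4), DType::F32, &device)?,
    };
    let x = Tensor::zeros((2, 4), DTYPE, &device)?;
    let y = layer.forward(&x)?;
    println!("output dtype: {:?}", y.dtype());
    Ok(())
}

With this layout, changing the constant to DType::F32 (or another dtype supported by the candle version in use) is the only edit needed to switch the precision of all the internal computations.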