mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add a const to easily tweak the dtype used for llama internal computations.
This commit is contained in:
@ -24,6 +24,7 @@ mod var_store;
|
||||
mod weights;
|
||||
|
||||
const MAX_SEQ_LEN: usize = 4096;
|
||||
const DTYPE: DType = DType::F16;
|
||||
const START_PROMPT: &str = r"
|
||||
EDWARD:
|
||||
I wonder how our princely father 'scaped,
|
||||
@ -138,7 +139,8 @@ impl Embedding {
|
||||
}
|
||||
|
||||
fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
|
||||
Ok(Tensor::embedding(indexes, &self.embeddings)?)
|
||||
let embeddings = self.embeddings.to_dtype(DTYPE)?;
|
||||
Ok(Tensor::embedding(indexes, &embeddings)?)
|
||||
}
|
||||
}
|
||||
|
||||
@ -152,7 +154,8 @@ impl Linear {
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x = x.matmul(&self.weight.t()?)?;
|
||||
let weight = self.weight.to_dtype(DTYPE)?;
|
||||
let x = x.matmul(&weight.t()?)?;
|
||||
Ok(x)
|
||||
}
|
||||
}
|
||||
@ -167,6 +170,7 @@ impl RmsNorm {
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
// This is a no-op if x's dtype is already f32.
|
||||
let x = x.to_dtype(DType::F32)?;
|
||||
let (seq_len, hidden_size) = x.shape().r2()?;
|
||||
let norm_x = ((&x * &x)?.sum(&[1])? / hidden_size as f64)?;
|
||||
@ -178,7 +182,7 @@ impl RmsNorm {
|
||||
.to_dtype(DType::F32)?
|
||||
.broadcast_as((seq_len, size))?;
|
||||
let x = (scale * x_normed)?;
|
||||
let x = x.to_dtype(DType::F16)?;
|
||||
let x = x.to_dtype(DTYPE)?;
|
||||
Ok(x)
|
||||
}
|
||||
}
|
||||
@ -339,7 +343,7 @@ impl CausalSelfAttention {
|
||||
// Convert to contiguous as matmul doesn't support strided vs for now.
|
||||
let y = att.matmul(&v.contiguous()?)?;
|
||||
let y = y.transpose(0, 1)?.reshape(&[t, c])?;
|
||||
let y = y.to_dtype(DType::F16)?;
|
||||
let y = y.to_dtype(DTYPE)?;
|
||||
let y = self.c_proj.forward(&y)?;
|
||||
Ok(y)
|
||||
}
|
||||
|
Reference in New Issue
Block a user