Add a const to easily tweak the dtype used for llama internal computations.

Author: laurent
Date: 2023-06-30 15:01:39 +01:00
parent a243504f53
commit ed4d0959d3

@@ -24,6 +24,7 @@ mod var_store;
 mod weights;
 const MAX_SEQ_LEN: usize = 4096;
+const DTYPE: DType = DType::F16;
 const START_PROMPT: &str = r"
 EDWARD:
 I wonder how our princely father 'scaped,
@@ -138,7 +139,8 @@ impl Embedding {
     }
     fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
-        Ok(Tensor::embedding(indexes, &self.embeddings)?)
+        let embeddings = self.embeddings.to_dtype(DTYPE)?;
+        Ok(Tensor::embedding(indexes, &embeddings)?)
     }
 }
@@ -152,7 +154,8 @@ impl Linear {
     }
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = x.matmul(&self.weight.t()?)?;
+        let weight = self.weight.to_dtype(DTYPE)?;
+        let x = x.matmul(&weight.t()?)?;
         Ok(x)
     }
 }
@@ -167,6 +170,7 @@ impl RmsNorm {
     }
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        // This is a no-op if x's dtype is already f32.
         let x = x.to_dtype(DType::F32)?;
         let (seq_len, hidden_size) = x.shape().r2()?;
         let norm_x = ((&x * &x)?.sum(&[1])? / hidden_size as f64)?;
@@ -178,7 +182,7 @@ impl RmsNorm {
             .to_dtype(DType::F32)?
             .broadcast_as((seq_len, size))?;
         let x = (scale * x_normed)?;
-        let x = x.to_dtype(DType::F16)?;
+        let x = x.to_dtype(DTYPE)?;
         Ok(x)
     }
 }
@@ -339,7 +343,7 @@ impl CausalSelfAttention {
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = att.matmul(&v.contiguous()?)?;
         let y = y.transpose(0, 1)?.reshape(&[t, c])?;
-        let y = y.to_dtype(DType::F16)?;
+        let y = y.to_dtype(DTYPE)?;
         let y = self.c_proj.forward(&y)?;
         Ok(y)
     }
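Taken together, the diff boils down to the pattern sketched below: a single DTYPE const picks the compute dtype, weights stay in whatever dtype the checkpoint provides, and each forward pass casts them on the fly, so switching the model to e.g. F32 is a one-line edit. This is a minimal sketch, not part of the commit; it assumes the crate is imported as candle and only uses the Tensor methods already visible above (to_dtype, t, matmul).

use candle::{DType, Result, Tensor};

// Single knob for the compute dtype used throughout the model.
const DTYPE: DType = DType::F16;

struct Linear {
    weight: Tensor, // stored in the checkpoint's original dtype
}

impl Linear {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Cast the weight to the shared compute dtype before the matmul;
        // as with the RmsNorm cast above, this is a no-op when the dtypes
        // already match.
        let weight = self.weight.to_dtype(DTYPE)?;
        x.matmul(&weight.t()?)
    }
}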