Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 11:08:52 +00:00
Merge pull request #47 from LaurentMazare/llama-f32
Add a const to easily tweak the dtype used by llama
@@ -24,6 +24,7 @@ mod var_store;
 mod weights;
 
 const MAX_SEQ_LEN: usize = 4096;
+const DTYPE: DType = DType::F16;
 const START_PROMPT: &str = r"
 EDWARD:
 I wonder how our princely father 'scaped,
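The point of the change is that the example's working precision is now controlled by this single constant. A minimal sketch of how one would tweak it, using only identifiers already visible in this diff (the commented-out alternative is an illustration, not part of the commit):

    // As merged: run the llama example in half precision.
    const DTYPE: DType = DType::F16;
    // To experiment with full precision instead, edit only this line:
    // const DTYPE: DType = DType::F32;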
@@ -138,7 +139,8 @@ impl Embedding {
     }
 
     fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
-        Ok(Tensor::embedding(indexes, &self.embeddings)?)
+        let embeddings = self.embeddings.to_dtype(DTYPE)?;
+        Ok(Tensor::embedding(indexes, &embeddings)?)
     }
 }
 
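Reassembled from the hunk above, Embedding::forward now reads roughly as follows; the embeddings field is implied by the use of self.embeddings, and the comments are added here for explanation only:

    fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
        // Cast the embedding table to the working dtype before the lookup,
        // instead of assuming the weights are already in f16.
        let embeddings = self.embeddings.to_dtype(DTYPE)?;
        Ok(Tensor::embedding(indexes, &embeddings)?)
    }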
@@ -152,7 +154,8 @@ impl Linear {
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = x.matmul(&self.weight.t()?)?;
+        let weight = self.weight.to_dtype(DTYPE)?;
+        let x = x.matmul(&weight.t()?)?;
         Ok(x)
     }
 }
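Linear gets the same treatment; a sketch of the patched forward, again reassembled from the hunk with explanatory comments added:

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Cast the weight matrix to DTYPE so both matmul operands are in the
        // precision selected by the new constant.
        let weight = self.weight.to_dtype(DTYPE)?;
        let x = x.matmul(&weight.t()?)?;
        Ok(x)
    }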
@@ -167,6 +170,7 @@ impl RmsNorm {
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        // This is a no-op if x's dtype is already f32.
         let x = x.to_dtype(DType::F32)?;
         let (seq_len, hidden_size) = x.shape().r2()?;
         let norm_x = ((&x * &x)?.sum(&[1])? / hidden_size as f64)?;
@@ -178,7 +182,7 @@ impl RmsNorm {
             .to_dtype(DType::F32)?
             .broadcast_as((seq_len, size))?;
         let x = (scale * x_normed)?;
-        let x = x.to_dtype(DType::F16)?;
+        let x = x.to_dtype(DTYPE)?;
         Ok(x)
     }
 }
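RmsNorm is the one place where the working dtype is deliberately set aside: the input is up-cast, the mean of squares is reduced in f32 for accuracy, and only the final result is cast back to DTYPE. A sketch of just that dtype round-trip, with the normalization body elided exactly as it falls between the two hunks:

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Up-cast to f32 for the reduction; a no-op if x is already f32.
        let x = x.to_dtype(DType::F32)?;
        // ... rms normalization computed in f32 (lines not shown in this diff) ...
        // Cast the normalized output back to the working dtype.
        let x = x.to_dtype(DTYPE)?;
        Ok(x)
    }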
@@ -339,7 +343,7 @@ impl CausalSelfAttention {
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = att.matmul(&v.contiguous()?)?;
         let y = y.transpose(0, 1)?.reshape(&[t, c])?;
-        let y = y.to_dtype(DType::F16)?;
+        let y = y.to_dtype(DTYPE)?;
         let y = self.c_proj.forward(&y)?;
         Ok(y)
     }
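The attention block follows the same pattern: the previously hard-coded f16 cast before the output projection becomes a cast to DTYPE, so c_proj (the Linear shown earlier) receives its input in whichever precision the constant selects. A sketch of that tail end, with a comment added:

    let y = y.transpose(0, 1)?.reshape(&[t, c])?;
    // Cast to the working dtype before the output projection.
    let y = y.to_dtype(DTYPE)?;
    let y = self.c_proj.forward(&y)?;
    Ok(y)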