mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Fixing quantized llama demo on metal. (#1703)
This commit is contained in:
@ -16,7 +16,7 @@ struct RmsNorm {
|
||||
impl RmsNorm {
|
||||
fn new(scale: QTensor, eps: f32) -> Result<Self> {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
|
||||
let scale = scale.dequantize(&Device::Cpu)?;
|
||||
let scale = scale.dequantize(&scale.device())?;
|
||||
let inner = candle_nn::LayerNorm::rms_norm(scale, eps as f64);
|
||||
Ok(Self { inner, span })
|
||||
}
|
||||
@ -275,13 +275,17 @@ pub struct ModelWeights {
|
||||
span_output: tracing::Span,
|
||||
}
|
||||
|
||||
fn precomput_freqs_cis(head_dim: usize, freq_base: f32) -> Result<(Tensor, Tensor)> {
|
||||
fn precomput_freqs_cis(
|
||||
head_dim: usize,
|
||||
freq_base: f32,
|
||||
device: &Device,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let theta: Vec<_> = (0..head_dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
|
||||
.collect();
|
||||
let theta = Tensor::new(theta.as_slice(), &Device::Cpu)?;
|
||||
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, &Device::Cpu)?
|
||||
let theta = Tensor::new(theta.as_slice(), device)?;
|
||||
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
|
||||
.to_dtype(DType::F32)?
|
||||
.reshape((MAX_SEQ_LEN, 1))?
|
||||
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
||||
@ -292,11 +296,10 @@ fn precomput_freqs_cis(head_dim: usize, freq_base: f32) -> Result<(Tensor, Tenso
|
||||
|
||||
impl ModelWeights {
|
||||
pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
|
||||
let cpu = &Device::Cpu;
|
||||
let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
|
||||
let (cos, sin) = precomput_freqs_cis(head_dim, 10000.)?;
|
||||
let (cos, sin) = precomput_freqs_cis(head_dim, 10000., &ct.device)?;
|
||||
let tok_embeddings = ct.remove("tok_embeddings.weight")?;
|
||||
let tok_embeddings = tok_embeddings.dequantize(cpu)?;
|
||||
let tok_embeddings = tok_embeddings.dequantize(&ct.device)?;
|
||||
let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
|
||||
let output = ct.remove("output.weight")?;
|
||||
let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
|
||||
@ -358,7 +361,6 @@ impl ModelWeights {
|
||||
reader: &mut R,
|
||||
device: &Device,
|
||||
) -> Result<Self> {
|
||||
let cpu = &Device::Cpu;
|
||||
let md_get = |s: &str| match ct.metadata.get(s) {
|
||||
None => candle::bail!("cannot find {s} in metadata"),
|
||||
Some(v) => Ok(v),
|
||||
@ -382,10 +384,10 @@ impl ModelWeights {
|
||||
let rope_freq_base = md_get("llama.rope.freq_base")
|
||||
.and_then(|m| m.to_f32())
|
||||
.unwrap_or(10000f32);
|
||||
let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?;
|
||||
let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?;
|
||||
|
||||
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
|
||||
let tok_embeddings = tok_embeddings.dequantize(cpu)?;
|
||||
let tok_embeddings = tok_embeddings.dequantize(device)?;
|
||||
let norm = RmsNorm::new(
|
||||
ct.tensor(reader, "output_norm.weight", device)?,
|
||||
rms_norm_eps,
|
||||
@ -472,14 +474,14 @@ impl ModelWeights {
|
||||
})
|
||||
}
|
||||
|
||||
fn mask(&mut self, t: usize) -> Result<Tensor> {
|
||||
fn mask(&mut self, t: usize, device: &Device) -> Result<Tensor> {
|
||||
if let Some(mask) = self.masks.get(&t) {
|
||||
Ok(mask.clone())
|
||||
} else {
|
||||
let mask: Vec<_> = (0..t)
|
||||
.flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
|
||||
.collect();
|
||||
let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
|
||||
let mask = Tensor::from_slice(&mask, (t, t), device)?;
|
||||
self.masks.insert(t, mask.clone());
|
||||
Ok(mask)
|
||||
}
|
||||
@ -487,7 +489,7 @@ impl ModelWeights {
|
||||
|
||||
pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||
let (_b_sz, seq_len) = x.dims2()?;
|
||||
let mask = self.mask(seq_len)?;
|
||||
let mask = self.mask(seq_len, x.device())?;
|
||||
let _enter = self.span.enter();
|
||||
let mut layer_in = self.tok_embeddings.forward(x)?;
|
||||
for layer in self.layers.iter_mut() {
|
||||
|
Reference in New Issue
Block a user