Avoid tensor copying in the quantized example. (#1770)

Laurent Mazare
2024-02-27 20:32:30 +01:00
committed by GitHub
parent 5e526abc8c
commit 205767f9de

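This change removes a per-call tensor allocation from the quantized example's attention path: rather than building a fresh `f32::NEG_INFINITY` tensor inside `masked_fill` on every step, a scalar `neg_inf` tensor is created once at load time, stored in each `LayerWeights`, and passed by reference.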

@@ -157,16 +157,16 @@ struct LayerWeights {
     head_dim: usize,
     cos: Tensor,
     sin: Tensor,
+    neg_inf: Tensor,
     kv_cache: Option<(Tensor, Tensor)>,
     span_attn: tracing::Span,
     span_rot: tracing::Span,
     span_mlp: tracing::Span,
 }
 
-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
+fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result<Tensor> {
     let shape = mask.shape();
-    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
-    let m = mask.where_cond(&on_true, on_false)?;
+    let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?;
     Ok(m)
 }
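With the old signature, each call to `masked_fill` allocated a new scalar tensor via `Tensor::new` before broadcasting it; the new version takes a precomputed `&Tensor` and only broadcasts. A minimal, self-contained sketch of the same pattern, assuming a `candle-core` dependency (the toy `att` and `mask` values here are made up for illustration):

    use candle_core::{DType, Device, Result, Tensor};

    fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result<Tensor> {
        // Broadcast the cached scalar to the mask's shape; no fresh scalar
        // tensor is allocated on this call.
        mask.where_cond(&on_true.broadcast_as(mask.shape().dims())?, on_false)
    }

    fn main() -> Result<()> {
        let device = Device::Cpu;
        // Built once, e.g. at model-load time, then reused by reference.
        let neg_inf = Tensor::new(f32::NEG_INFINITY, &device)?;
        // Toy stand-ins for attention scores and a causal mask (1 = masked).
        let att = Tensor::zeros((2, 3), DType::F32, &device)?;
        let mask = Tensor::new(&[[0u8, 1, 1], [0, 0, 1]], &device)?;
        let att = masked_fill(&att, &mask, &neg_inf)?;
        println!("{att}");
        Ok(())
    }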
@@ -240,7 +240,7 @@ impl LayerWeights {
         let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
         let mask = mask.broadcast_as(att.shape())?;
-        let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
+        let att = masked_fill(&att, &mask, &self.neg_inf)?;
         let att = candle_nn::ops::softmax_last_dim(&att)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = att.matmul(&v.contiguous()?)?;
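With the cached `self.neg_inf` in place of the `f32::NEG_INFINITY` literal, the attention hot path no longer constructs a tensor on every forward step.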
@@ -298,6 +298,7 @@ impl ModelWeights {
     pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
         let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
         let (cos, sin) = precomput_freqs_cis(head_dim, 10000., &ct.device)?;
+        let neg_inf = Tensor::new(f32::NEG_INFINITY, &ct.device)?;
         let tok_embeddings = ct.remove("tok_embeddings.weight")?;
         let tok_embeddings = tok_embeddings.dequantize(&ct.device)?;
         let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
@@ -337,6 +338,7 @@ impl ModelWeights {
                 head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
                 cos: cos.clone(),
                 sin: sin.clone(),
+                neg_inf: neg_inf.clone(),
                 kv_cache: None,
                 span_attn,
                 span_rot,
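Note that `Tensor::clone` is cheap here: a candle `Tensor` keeps its storage behind an `Arc`, so cloning `neg_inf` (like `cos` and `sin`) into each layer bumps a reference count instead of copying data.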
@@ -385,6 +387,7 @@ impl ModelWeights {
             .and_then(|m| m.to_f32())
             .unwrap_or(10000f32);
         let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?;
+        let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
 
         let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
         let tok_embeddings = tok_embeddings.dequantize(device)?;
@@ -455,6 +458,7 @@ impl ModelWeights {
                 head_dim: embedding_length / head_count,
                 cos: cos.clone(),
                 sin: sin.clone(),
+                neg_inf: neg_inf.clone(),
                 kv_cache: None,
                 span_attn,
                 span_rot,
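The GGUF loader gets the same treatment as the GGML one: the scalar is created once on the target device and shared across all layers.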