mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Avoid tensor copying in the quantized example. (#1770)
This commit is contained in:
@ -157,16 +157,16 @@ struct LayerWeights {
|
|||||||
head_dim: usize,
|
head_dim: usize,
|
||||||
cos: Tensor,
|
cos: Tensor,
|
||||||
sin: Tensor,
|
sin: Tensor,
|
||||||
|
neg_inf: Tensor,
|
||||||
kv_cache: Option<(Tensor, Tensor)>,
|
kv_cache: Option<(Tensor, Tensor)>,
|
||||||
span_attn: tracing::Span,
|
span_attn: tracing::Span,
|
||||||
span_rot: tracing::Span,
|
span_rot: tracing::Span,
|
||||||
span_mlp: tracing::Span,
|
span_mlp: tracing::Span,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
|
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result<Tensor> {
|
||||||
let shape = mask.shape();
|
let shape = mask.shape();
|
||||||
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
|
let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?;
|
||||||
let m = mask.where_cond(&on_true, on_false)?;
|
|
||||||
Ok(m)
|
Ok(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -240,7 +240,7 @@ impl LayerWeights {
|
|||||||
|
|
||||||
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
|
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
|
||||||
let mask = mask.broadcast_as(att.shape())?;
|
let mask = mask.broadcast_as(att.shape())?;
|
||||||
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
|
let att = masked_fill(&att, &mask, &self.neg_inf)?;
|
||||||
let att = candle_nn::ops::softmax_last_dim(&att)?;
|
let att = candle_nn::ops::softmax_last_dim(&att)?;
|
||||||
// Convert to contiguous as matmul doesn't support strided vs for now.
|
// Convert to contiguous as matmul doesn't support strided vs for now.
|
||||||
let y = att.matmul(&v.contiguous()?)?;
|
let y = att.matmul(&v.contiguous()?)?;
|
||||||
@ -298,6 +298,7 @@ impl ModelWeights {
|
|||||||
pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
|
pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
|
||||||
let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
|
let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
|
||||||
let (cos, sin) = precomput_freqs_cis(head_dim, 10000., &ct.device)?;
|
let (cos, sin) = precomput_freqs_cis(head_dim, 10000., &ct.device)?;
|
||||||
|
let neg_inf = Tensor::new(f32::NEG_INFINITY, &ct.device)?;
|
||||||
let tok_embeddings = ct.remove("tok_embeddings.weight")?;
|
let tok_embeddings = ct.remove("tok_embeddings.weight")?;
|
||||||
let tok_embeddings = tok_embeddings.dequantize(&ct.device)?;
|
let tok_embeddings = tok_embeddings.dequantize(&ct.device)?;
|
||||||
let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
|
let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
|
||||||
@ -337,6 +338,7 @@ impl ModelWeights {
|
|||||||
head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
|
head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
|
||||||
cos: cos.clone(),
|
cos: cos.clone(),
|
||||||
sin: sin.clone(),
|
sin: sin.clone(),
|
||||||
|
neg_inf: neg_inf.clone(),
|
||||||
kv_cache: None,
|
kv_cache: None,
|
||||||
span_attn,
|
span_attn,
|
||||||
span_rot,
|
span_rot,
|
||||||
@ -385,6 +387,7 @@ impl ModelWeights {
|
|||||||
.and_then(|m| m.to_f32())
|
.and_then(|m| m.to_f32())
|
||||||
.unwrap_or(10000f32);
|
.unwrap_or(10000f32);
|
||||||
let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?;
|
let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?;
|
||||||
|
let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
|
||||||
|
|
||||||
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
|
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
|
||||||
let tok_embeddings = tok_embeddings.dequantize(device)?;
|
let tok_embeddings = tok_embeddings.dequantize(device)?;
|
||||||
@ -455,6 +458,7 @@ impl ModelWeights {
|
|||||||
head_dim: embedding_length / head_count,
|
head_dim: embedding_length / head_count,
|
||||||
cos: cos.clone(),
|
cos: cos.clone(),
|
||||||
sin: sin.clone(),
|
sin: sin.clone(),
|
||||||
|
neg_inf: neg_inf.clone(),
|
||||||
kv_cache: None,
|
kv_cache: None,
|
||||||
span_attn,
|
span_attn,
|
||||||
span_rot,
|
span_rot,
|
||||||
|
Reference in New Issue
Block a user