mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Avoid tensor copying in the quantized example. (#1770)
This commit is contained in:
@ -157,16 +157,16 @@ struct LayerWeights {
|
||||
head_dim: usize,
|
||||
cos: Tensor,
|
||||
sin: Tensor,
|
||||
neg_inf: Tensor,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
span_attn: tracing::Span,
|
||||
span_rot: tracing::Span,
|
||||
span_mlp: tracing::Span,
|
||||
}
|
||||
|
||||
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
|
||||
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result<Tensor> {
|
||||
let shape = mask.shape();
|
||||
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
|
||||
let m = mask.where_cond(&on_true, on_false)?;
|
||||
let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?;
|
||||
Ok(m)
|
||||
}
|
||||
|
||||
@ -240,7 +240,7 @@ impl LayerWeights {
|
||||
|
||||
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
|
||||
let mask = mask.broadcast_as(att.shape())?;
|
||||
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
|
||||
let att = masked_fill(&att, &mask, &self.neg_inf)?;
|
||||
let att = candle_nn::ops::softmax_last_dim(&att)?;
|
||||
// Convert to contiguous as matmul doesn't support strided vs for now.
|
||||
let y = att.matmul(&v.contiguous()?)?;
|
||||
@ -298,6 +298,7 @@ impl ModelWeights {
|
||||
pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
|
||||
let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
|
||||
let (cos, sin) = precomput_freqs_cis(head_dim, 10000., &ct.device)?;
|
||||
let neg_inf = Tensor::new(f32::NEG_INFINITY, &ct.device)?;
|
||||
let tok_embeddings = ct.remove("tok_embeddings.weight")?;
|
||||
let tok_embeddings = tok_embeddings.dequantize(&ct.device)?;
|
||||
let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
|
||||
@ -337,6 +338,7 @@ impl ModelWeights {
|
||||
head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
|
||||
cos: cos.clone(),
|
||||
sin: sin.clone(),
|
||||
neg_inf: neg_inf.clone(),
|
||||
kv_cache: None,
|
||||
span_attn,
|
||||
span_rot,
|
||||
@ -385,6 +387,7 @@ impl ModelWeights {
|
||||
.and_then(|m| m.to_f32())
|
||||
.unwrap_or(10000f32);
|
||||
let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?;
|
||||
let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
|
||||
|
||||
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
|
||||
let tok_embeddings = tok_embeddings.dequantize(device)?;
|
||||
@ -455,6 +458,7 @@ impl ModelWeights {
|
||||
head_dim: embedding_length / head_count,
|
||||
cos: cos.clone(),
|
||||
sin: sin.clone(),
|
||||
neg_inf: neg_inf.clone(),
|
||||
kv_cache: None,
|
||||
span_attn,
|
||||
span_rot,
|
||||
|
Reference in New Issue
Block a user