mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Add a slice_set op. (#2193)
* Add a slice_set op. * Add some testing. * Add the dedicated kv-cache module. * Derive debug and clone. * Expose more kv-cache functions. * Return the current data when appending. * Use the new cache in the quantized phi3 model.
This commit is contained in:
@ -3,9 +3,7 @@ use std::collections::HashMap;
|
||||
use candle::quantized::gguf_file;
|
||||
use candle::quantized::QTensor;
|
||||
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
|
||||
use candle_nn::{Embedding, RmsNorm};
|
||||
|
||||
pub const MAX_SEQ_LEN: usize = 4096;
|
||||
use candle_nn::{kv_cache::KvCache, Embedding, RmsNorm};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct QLinear {
|
||||
@ -70,7 +68,7 @@ struct LayerWeights {
|
||||
cos: Tensor,
|
||||
sin: Tensor,
|
||||
neg_inf: Tensor,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
kv_cache: KvCache,
|
||||
span_attn: tracing::Span,
|
||||
span_rot: tracing::Span,
|
||||
}
|
||||
@ -122,19 +120,7 @@ impl LayerWeights {
|
||||
let q = self.apply_rotary_emb(&q, index_pos)?.contiguous()?;
|
||||
let k = self.apply_rotary_emb(&k, index_pos)?;
|
||||
|
||||
let (k, v) = match &self.kv_cache {
|
||||
None => (k.contiguous()?, v.contiguous()?),
|
||||
Some((k_cache, v_cache)) => {
|
||||
if index_pos == 0 {
|
||||
(k.contiguous()?, v.contiguous()?)
|
||||
} else {
|
||||
let k = Tensor::cat(&[k_cache, &k], 2)?;
|
||||
let v = Tensor::cat(&[v_cache, &v], 2)?;
|
||||
(k.contiguous()?, v.contiguous()?)
|
||||
}
|
||||
}
|
||||
};
|
||||
self.kv_cache = Some((k.clone(), v.clone()));
|
||||
let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?;
|
||||
|
||||
let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
|
||||
let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;
|
||||
@ -169,6 +155,7 @@ pub struct ModelWeights {
|
||||
|
||||
fn precomput_freqs_cis(
|
||||
head_dim: usize,
|
||||
max_seq_len: usize,
|
||||
freq_base: f32,
|
||||
device: &Device,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
@ -177,9 +164,9 @@ fn precomput_freqs_cis(
|
||||
.map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
|
||||
.collect();
|
||||
let theta = Tensor::new(theta.as_slice(), device)?;
|
||||
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
|
||||
let idx_theta = Tensor::arange(0, max_seq_len as u32, device)?
|
||||
.to_dtype(DType::F32)?
|
||||
.reshape((MAX_SEQ_LEN, 1))?
|
||||
.reshape((max_seq_len, 1))?
|
||||
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
||||
let cos = idx_theta.cos()?;
|
||||
let sin = idx_theta.sin()?;
|
||||
@ -188,6 +175,7 @@ fn precomput_freqs_cis(
|
||||
|
||||
impl ModelWeights {
|
||||
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
|
||||
batch_size: usize,
|
||||
ct: gguf_file::Content,
|
||||
reader: &mut R,
|
||||
device: &Device,
|
||||
@ -202,16 +190,19 @@ impl ModelWeights {
|
||||
let head_count_kv = md_get("phi3.attention.head_count_kv")?.to_u32()? as usize;
|
||||
let block_count = md_get("phi3.block_count")?.to_u32()? as usize;
|
||||
let embedding_length = md_get("phi3.embedding_length")?.to_u32()? as usize;
|
||||
let max_seq_len = md_get("phi3.context_length")?.to_u32()? as usize;
|
||||
let head_dim = embedding_length / head_count;
|
||||
let i_size = md_get("phi3.feed_forward_length")?.to_u32()? as usize;
|
||||
let rope_dim = md_get("phi3.rope.dimension_count")?.to_u32()? as usize;
|
||||
let rms_eps = md_get("phi3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
|
||||
let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device)?;
|
||||
let (cos, sin) = precomput_freqs_cis(rope_dim, max_seq_len, 10_000., device)?;
|
||||
let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
|
||||
|
||||
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
|
||||
let tok_embeddings = tok_embeddings.dequantize(device)?;
|
||||
let output_norm = rms_norm(ct.tensor(reader, "output_norm.weight", device)?, rms_eps)?;
|
||||
let output = QLinear::new(&ct, reader, "output", device)?;
|
||||
|
||||
let mut layers = Vec::with_capacity(block_count);
|
||||
for layer_idx in 0..block_count {
|
||||
let prefix = format!("blk.{layer_idx}");
|
||||
@ -232,6 +223,12 @@ impl ModelWeights {
|
||||
)?;
|
||||
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
|
||||
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
|
||||
let kv_cache = KvCache::new(
|
||||
2,
|
||||
(batch_size, head_count_kv, max_seq_len, head_dim),
|
||||
DType::F32,
|
||||
device,
|
||||
)?;
|
||||
layers.push(LayerWeights {
|
||||
attn_qkv: QLinear::new(&ct, reader, &format!("{prefix}.attn_qkv"), device)?,
|
||||
attn_output: QLinear::new(&ct, reader, &format!("{prefix}.attn_output"), device)?,
|
||||
@ -240,11 +237,11 @@ impl ModelWeights {
|
||||
mlp,
|
||||
n_head: head_count,
|
||||
n_kv_head: head_count_kv,
|
||||
head_dim: embedding_length / head_count,
|
||||
head_dim,
|
||||
cos: cos.clone(),
|
||||
sin: sin.clone(),
|
||||
neg_inf: neg_inf.clone(),
|
||||
kv_cache: None,
|
||||
kv_cache,
|
||||
span_attn,
|
||||
span_rot,
|
||||
})
|
||||
|
Reference in New Issue
Block a user