mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Add a slice_set op. (#2193)
* Add a slice_set op. * Add some testing. * Add the dedicated kv-cache module. * Derive debug and clone. * Expose more kv-cache functions. * Return the current data when appending. * Use the new cache in the quantized phi3 model.
This commit is contained in:
@ -235,4 +235,66 @@ impl Tensor {
|
||||
}
|
||||
Ok(crate::tensor::from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
/// Set the values on `self` using values from `src`. The copy starts at the specified
|
||||
/// `offset` for the target dimension `dim` on `self`.
|
||||
/// `self` and `src` must have the same shape except on dimension `dim` where the `self` size
|
||||
/// has to be greater than or equal to `offset` plus the `src` size.
|
||||
///
|
||||
/// Note that this modifies `self` in place and as such is not compatibel with
|
||||
/// back-propagation.
|
||||
pub fn slice_set<D: Dim>(&self, src: &Self, dim: D, offset: usize) -> Result<()> {
|
||||
let dim = dim.to_index(self.shape(), "slice-set")?;
|
||||
if !self.is_contiguous() || !src.is_contiguous() {
|
||||
Err(Error::RequiresContiguous { op: "slice-set" }.bt())?
|
||||
}
|
||||
if self.dtype() != src.dtype() {
|
||||
Err(Error::DTypeMismatchBinaryOp {
|
||||
lhs: self.dtype(),
|
||||
rhs: src.dtype(),
|
||||
op: "slice-set",
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
if self.device().location() != src.device().location() {
|
||||
Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: self.device().location(),
|
||||
rhs: src.device().location(),
|
||||
op: "slice-set",
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
if self.rank() != src.rank() {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: self.rank(),
|
||||
got: src.rank(),
|
||||
shape: self.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
for (dim_idx, (v1, v2)) in self.dims().iter().zip(src.dims().iter()).enumerate() {
|
||||
if dim_idx == dim && *v2 + offset > *v1 {
|
||||
crate::bail!("shape mismatch on target dim, dst: {v1}, src: {v2} + {offset}")
|
||||
}
|
||||
if dim_idx != dim && v1 != v2 {
|
||||
crate::bail!("shape mismatch on dim {dim_idx}, {v1} <> {v2}")
|
||||
}
|
||||
}
|
||||
let block_size: usize = src.dims().iter().skip(1 + dim).product();
|
||||
let d1: usize = src.dims().iter().take(dim).product();
|
||||
let d2 = block_size * src.dims()[dim];
|
||||
let dst_o = self.layout().start_offset() + offset * block_size;
|
||||
let src_o = src.layout().start_offset();
|
||||
src.storage().copy2d(
|
||||
&mut self.storage_mut(),
|
||||
d1,
|
||||
d2,
|
||||
/* src_s */ d2,
|
||||
/* dst_s */ block_size * self.dims()[dim],
|
||||
src_o,
|
||||
dst_o,
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -665,6 +665,30 @@ fn broadcast(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn slice_set(device: &Device) -> Result<()> {
|
||||
let (b, h, max_t, d) = (2, 4, 7, 3);
|
||||
let cache = Tensor::zeros((b, h, max_t, d), DType::F32, device)?;
|
||||
let tensor = Tensor::randn(0f32, 1f32, (b, h, 4, d), device)?;
|
||||
cache.slice_set(&tensor, 2, 0)?;
|
||||
let cache_t = cache.narrow(2, 0, 4)?;
|
||||
let diff = (cache_t - &tensor)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
cache.slice_set(&tensor, 2, 1)?;
|
||||
let cache_t = cache.narrow(2, 1, 4)?;
|
||||
let diff = (cache_t - &tensor)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
let ones = Tensor::ones((b, h, 1, d), DType::F32, device)?;
|
||||
cache.slice_set(&ones, 2, 6)?;
|
||||
let diff = cache.narrow(2, 5, 1)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
let diff = (cache.narrow(2, 6, 1)? - 1.)?
|
||||
.abs()?
|
||||
.sum_all()?
|
||||
.to_vec0::<f32>()?;
|
||||
assert_eq!(diff, 0.);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn cat(device: &Device) -> Result<()> {
|
||||
// 1D
|
||||
let t1 = Tensor::new(&[3f32, 1., 4.], device)?;
|
||||
@ -1146,6 +1170,7 @@ test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
|
||||
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
|
||||
test_device!(narrow, narrow_cpu, narrow_gpu, narrow_metal);
|
||||
test_device!(broadcast, broadcast_cpu, broadcast_gpu, broadcast_metal);
|
||||
test_device!(slice_set, ss_cpu, ss_gpu, ss_metal);
|
||||
test_device!(cat, cat_cpu, cat_gpu, cat_metal);
|
||||
test_device!(sum, sum_cpu, sum_gpu, sum_metal);
|
||||
test_device!(min, min_cpu, min_gpu, min_metal);
|
||||
|
@ -213,7 +213,7 @@ fn main() -> anyhow::Result<()> {
|
||||
);
|
||||
match args.which {
|
||||
Which::Phi2 => Model::Phi2(Phi2::from_gguf(model, &mut file, &device)?),
|
||||
Which::Phi3 => Model::Phi3(Phi3::from_gguf(model, &mut file, &device)?),
|
||||
Which::Phi3 => Model::Phi3(Phi3::from_gguf(1, model, &mut file, &device)?),
|
||||
Which::Phi3b => Model::Phi3b(Phi3b::from_gguf(model, &mut file, &device)?),
|
||||
}
|
||||
};
|
||||
|
101
candle-nn/src/kv_cache.rs
Normal file
101
candle-nn/src/kv_cache.rs
Normal file
@ -0,0 +1,101 @@
|
||||
use candle::{DType, Device, Result, Shape, Tensor};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Cache {
|
||||
all_data: Tensor,
|
||||
dim: usize,
|
||||
current_seq_len: usize,
|
||||
max_seq_len: usize,
|
||||
}
|
||||
|
||||
impl Cache {
|
||||
pub fn new<S: Into<Shape>, D: candle::shape::Dim>(
|
||||
dim: D,
|
||||
shape: S,
|
||||
dtype: DType,
|
||||
dev: &Device,
|
||||
) -> Result<Self> {
|
||||
let shape = shape.into();
|
||||
let dim = dim.to_index(&shape, "kv-cache")?;
|
||||
let max_seq_len = shape.dims()[dim];
|
||||
let all_data = Tensor::zeros(shape, dtype, dev)?;
|
||||
Ok(Self {
|
||||
all_data,
|
||||
dim,
|
||||
current_seq_len: 0,
|
||||
max_seq_len,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn dim(&self) -> usize {
|
||||
self.dim
|
||||
}
|
||||
|
||||
pub fn current_seq_len(&self) -> usize {
|
||||
self.current_seq_len
|
||||
}
|
||||
|
||||
pub fn max_seq_len(&self) -> usize {
|
||||
self.max_seq_len
|
||||
}
|
||||
|
||||
pub fn all_data(&self) -> &Tensor {
|
||||
&self.all_data
|
||||
}
|
||||
|
||||
pub fn current_data(&self) -> Result<Tensor> {
|
||||
self.all_data.narrow(self.dim, 0, self.current_seq_len)
|
||||
}
|
||||
|
||||
pub fn append(&mut self, src: &Tensor) -> Result<()> {
|
||||
let seq_len = src.dim(self.dim)?;
|
||||
if self.current_seq_len + seq_len > self.max_seq_len {
|
||||
candle::bail!(
|
||||
"kv-cache: above max-seq-len {}+{seq_len}>{}",
|
||||
self.current_seq_len,
|
||||
self.max_seq_len
|
||||
)
|
||||
}
|
||||
self.all_data
|
||||
.slice_set(src, self.dim, self.current_seq_len)?;
|
||||
self.current_seq_len += seq_len;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct KvCache {
|
||||
k: Cache,
|
||||
v: Cache,
|
||||
}
|
||||
|
||||
impl KvCache {
|
||||
pub fn new<S: Into<Shape>, D: candle::shape::Dim>(
|
||||
dim: D,
|
||||
shape: S,
|
||||
dtype: DType,
|
||||
dev: &Device,
|
||||
) -> Result<Self> {
|
||||
let shape = shape.into();
|
||||
let dim = dim.to_index(&shape, "kv-cache")?;
|
||||
let k = Cache::new(dim, &shape, dtype, dev)?;
|
||||
let v = Cache::new(dim, &shape, dtype, dev)?;
|
||||
Ok(Self { k, v })
|
||||
}
|
||||
|
||||
pub fn k(&self) -> Result<Tensor> {
|
||||
self.k.current_data()
|
||||
}
|
||||
|
||||
pub fn v(&self) -> Result<Tensor> {
|
||||
self.v.current_data()
|
||||
}
|
||||
|
||||
pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
|
||||
self.k.append(k)?;
|
||||
self.v.append(v)?;
|
||||
let k = self.k.current_data()?;
|
||||
let v = self.v.current_data()?;
|
||||
Ok((k, v))
|
||||
}
|
||||
}
|
@ -6,6 +6,7 @@ pub mod encoding;
|
||||
pub mod func;
|
||||
pub mod group_norm;
|
||||
pub mod init;
|
||||
pub mod kv_cache;
|
||||
pub mod layer_norm;
|
||||
pub mod linear;
|
||||
pub mod loss;
|
||||
|
@ -3,9 +3,7 @@ use std::collections::HashMap;
|
||||
use candle::quantized::gguf_file;
|
||||
use candle::quantized::QTensor;
|
||||
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
|
||||
use candle_nn::{Embedding, RmsNorm};
|
||||
|
||||
pub const MAX_SEQ_LEN: usize = 4096;
|
||||
use candle_nn::{kv_cache::KvCache, Embedding, RmsNorm};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct QLinear {
|
||||
@ -70,7 +68,7 @@ struct LayerWeights {
|
||||
cos: Tensor,
|
||||
sin: Tensor,
|
||||
neg_inf: Tensor,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
kv_cache: KvCache,
|
||||
span_attn: tracing::Span,
|
||||
span_rot: tracing::Span,
|
||||
}
|
||||
@ -122,19 +120,7 @@ impl LayerWeights {
|
||||
let q = self.apply_rotary_emb(&q, index_pos)?.contiguous()?;
|
||||
let k = self.apply_rotary_emb(&k, index_pos)?;
|
||||
|
||||
let (k, v) = match &self.kv_cache {
|
||||
None => (k.contiguous()?, v.contiguous()?),
|
||||
Some((k_cache, v_cache)) => {
|
||||
if index_pos == 0 {
|
||||
(k.contiguous()?, v.contiguous()?)
|
||||
} else {
|
||||
let k = Tensor::cat(&[k_cache, &k], 2)?;
|
||||
let v = Tensor::cat(&[v_cache, &v], 2)?;
|
||||
(k.contiguous()?, v.contiguous()?)
|
||||
}
|
||||
}
|
||||
};
|
||||
self.kv_cache = Some((k.clone(), v.clone()));
|
||||
let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?;
|
||||
|
||||
let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
|
||||
let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;
|
||||
@ -169,6 +155,7 @@ pub struct ModelWeights {
|
||||
|
||||
fn precomput_freqs_cis(
|
||||
head_dim: usize,
|
||||
max_seq_len: usize,
|
||||
freq_base: f32,
|
||||
device: &Device,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
@ -177,9 +164,9 @@ fn precomput_freqs_cis(
|
||||
.map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
|
||||
.collect();
|
||||
let theta = Tensor::new(theta.as_slice(), device)?;
|
||||
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
|
||||
let idx_theta = Tensor::arange(0, max_seq_len as u32, device)?
|
||||
.to_dtype(DType::F32)?
|
||||
.reshape((MAX_SEQ_LEN, 1))?
|
||||
.reshape((max_seq_len, 1))?
|
||||
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
||||
let cos = idx_theta.cos()?;
|
||||
let sin = idx_theta.sin()?;
|
||||
@ -188,6 +175,7 @@ fn precomput_freqs_cis(
|
||||
|
||||
impl ModelWeights {
|
||||
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
|
||||
batch_size: usize,
|
||||
ct: gguf_file::Content,
|
||||
reader: &mut R,
|
||||
device: &Device,
|
||||
@ -202,16 +190,19 @@ impl ModelWeights {
|
||||
let head_count_kv = md_get("phi3.attention.head_count_kv")?.to_u32()? as usize;
|
||||
let block_count = md_get("phi3.block_count")?.to_u32()? as usize;
|
||||
let embedding_length = md_get("phi3.embedding_length")?.to_u32()? as usize;
|
||||
let max_seq_len = md_get("phi3.context_length")?.to_u32()? as usize;
|
||||
let head_dim = embedding_length / head_count;
|
||||
let i_size = md_get("phi3.feed_forward_length")?.to_u32()? as usize;
|
||||
let rope_dim = md_get("phi3.rope.dimension_count")?.to_u32()? as usize;
|
||||
let rms_eps = md_get("phi3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
|
||||
let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device)?;
|
||||
let (cos, sin) = precomput_freqs_cis(rope_dim, max_seq_len, 10_000., device)?;
|
||||
let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
|
||||
|
||||
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
|
||||
let tok_embeddings = tok_embeddings.dequantize(device)?;
|
||||
let output_norm = rms_norm(ct.tensor(reader, "output_norm.weight", device)?, rms_eps)?;
|
||||
let output = QLinear::new(&ct, reader, "output", device)?;
|
||||
|
||||
let mut layers = Vec::with_capacity(block_count);
|
||||
for layer_idx in 0..block_count {
|
||||
let prefix = format!("blk.{layer_idx}");
|
||||
@ -232,6 +223,12 @@ impl ModelWeights {
|
||||
)?;
|
||||
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
|
||||
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
|
||||
let kv_cache = KvCache::new(
|
||||
2,
|
||||
(batch_size, head_count_kv, max_seq_len, head_dim),
|
||||
DType::F32,
|
||||
device,
|
||||
)?;
|
||||
layers.push(LayerWeights {
|
||||
attn_qkv: QLinear::new(&ct, reader, &format!("{prefix}.attn_qkv"), device)?,
|
||||
attn_output: QLinear::new(&ct, reader, &format!("{prefix}.attn_output"), device)?,
|
||||
@ -240,11 +237,11 @@ impl ModelWeights {
|
||||
mlp,
|
||||
n_head: head_count,
|
||||
n_kv_head: head_count_kv,
|
||||
head_dim: embedding_length / head_count,
|
||||
head_dim,
|
||||
cos: cos.clone(),
|
||||
sin: sin.clone(),
|
||||
neg_inf: neg_inf.clone(),
|
||||
kv_cache: None,
|
||||
kv_cache,
|
||||
span_attn,
|
||||
span_rot,
|
||||
})
|
||||
|
Reference in New Issue
Block a user