Compare commits

..

11 Commits

10 changed files with 52 additions and 109 deletions

View File

@@ -20,7 +20,7 @@ exclude = [
 resolver = "2"

 [workspace.package]
-version = "0.7.1"
+version = "0.7.0"
 edition = "2021"
 description = "Minimalist ML framework."
 repository = "https://github.com/huggingface/candle"
@@ -33,14 +33,14 @@ ab_glyph = "0.2.23"
 accelerate-src = { version = "0.3.2" }
 anyhow = { version = "1", features = ["backtrace"] }
 byteorder = "1.4.3"
-candle = { path = "./candle-core", package = "candle-core", version = "0.7.1" }
-candle-datasets = { path = "./candle-datasets", version = "0.7.1" }
-candle-flash-attn = { path = "./candle-flash-attn", version = "0.7.1" }
-candle-kernels = { path = "./candle-kernels", version = "0.7.1" }
-candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.7.1" }
-candle-nn = { path = "./candle-nn", version = "0.7.1" }
-candle-onnx = { path = "./candle-onnx", version = "0.7.1" }
-candle-transformers = { path = "./candle-transformers", version = "0.7.1" }
+candle = { path = "./candle-core", package = "candle-core", version = "0.7.0" }
+candle-datasets = { path = "./candle-datasets", version = "0.7.0" }
+candle-flash-attn = { path = "./candle-flash-attn", version = "0.7.0" }
+candle-kernels = { path = "./candle-kernels", version = "0.7.0" }
+candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.7.0" }
+candle-nn = { path = "./candle-nn", version = "0.7.0" }
+candle-onnx = { path = "./candle-onnx", version = "0.7.0" }
+candle-transformers = { path = "./candle-transformers", version = "0.7.0" }
 clap = { version = "4.2.4", features = ["derive"] }
 criterion = { version = "0.5.1", default-features=false }
 cudarc = { version = "0.12.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }

View File

@@ -37,12 +37,6 @@ fn pad(p: usize, q: usize) -> usize {
     ceil_div(p, q) * q
 }

-fn pad_for_alloc(p: usize) -> usize {
-    // Overallocate by q rather than just padding by q as this should pad the last row
-    // and we don't have enough information here to know how many elements to add :(
-    p + MATRIX_ROW_PADDING
-}
-
 fn quantize_q8_1(
     src: &CudaView<f32>,
     dst: &mut CudaSlice<u8>,
@@ -450,11 +444,8 @@ impl QCudaStorage {
         let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?;
         qcpu_storage.quantize(&src)?;
         let data = qcpu_storage.data()?;
-        let mut dst = self.device.alloc_zeros::<u8>(pad_for_alloc(src_len)).w()?;
-        self.device
-            .htod_sync_copy_into(data.as_ref(), &mut dst.slice_mut(..src_len))
-            .w()?;
-        self.data = dst;
+        let data = self.device.htod_sync_copy(data.as_ref()).w()?;
+        self.data = data;
         Ok(())
     }
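A dependency-free sketch of the padding arithmetic in the hunks above may help here; the body of ceil_div and the value of MATRIX_ROW_PADDING are assumptions for illustration (the real definitions live elsewhere in candle-core):

// Assumed to match candle-core's ceil_div: round p / q up when p is not a multiple of q.
fn ceil_div(p: usize, q: usize) -> usize {
    (p + q - 1) / q
}

// Same shape as the `pad` helper shown above: round p up to the next multiple of q.
fn pad(p: usize, q: usize) -> usize {
    ceil_div(p, q) * q
}

// Assumed value, for illustration only.
const MATRIX_ROW_PADDING: usize = 512;

fn main() {
    // `pad` rounds up to an alignment boundary...
    assert_eq!(pad(1000, 256), 1024);
    // ...while `pad_for_alloc` in the hunk above over-allocates by a fixed amount,
    // because the caller does not know how much padding the last row needs.
    assert_eq!(1000 + MATRIX_ROW_PADDING, 1512);
}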

View File

@@ -18,7 +18,7 @@ pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>(
     let actual_blocks = ys.len();

     // Validate that the input is the right size
-    if expected_blocks > actual_blocks {
+    if expected_blocks != actual_blocks {
         crate::bail!("quantize {dtype:?}: expected {expected_blocks} blocks but only {actual_blocks} were provided!")
     }

View File

@@ -1,6 +1,6 @@
 [package]
 name = "candle-flash-attn"
-version = "0.7.1"
+version = "0.7.0"
 edition = "2021"

 description = "Flash attention layer for the candle ML framework."
@@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
 readme = "README.md"

 [dependencies]
-candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.7.1" }
+candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.7.0" }
 half = { version = "2.3.1", features = ["num-traits"] }

 [build-dependencies]

View File

@@ -1,6 +1,6 @@
 [package]
 name = "candle-kernels"
-version = "0.7.1"
+version = "0.7.0"
 edition = "2021"

 description = "CUDA kernels for Candle"

View File

@@ -1,6 +1,6 @@
 [package]
 name = "candle-metal-kernels"
-version = "0.7.1"
+version = "0.7.0"
 edition = "2021"

 description = "Metal kernels for Candle"

View File

@@ -1,4 +1,4 @@
-use candle::{Device, Result, Tensor};
+use candle::{Result, Tensor};

 #[derive(Debug, Clone)]
 pub struct Cache {
@@ -255,56 +255,6 @@ impl RotatingCache {
             }
         }
     }
-
-    fn get_mask_abs(&self, size1: usize, size2: usize, device: &Device) -> Result<Tensor> {
-        let context = self.max_seq_len;
-        let mask: Vec<_> = (0..size1)
-            .flat_map(|i| {
-                (0..size2).map(move |j| {
-                    u8::from(size1 + j > size2 + i || size1 + j + context < size2 + i)
-                })
-            })
-            .collect();
-        Tensor::from_slice(&mask, (size1, size2), device)
-    }
-
-    fn get_mask_rel(&self, size1: usize, size2: usize, device: &Device) -> Result<Tensor> {
-        let context = self.max_seq_len;
-        let upd_offset = (self.offset + size1) % self.max_seq_len;
-        let mask: Vec<_> = (0..size1)
-            .flat_map(|pos_src| {
-                // The absolute position of the elements that will get added to the cache.
-                let pos_src = self.current_seq_len + pos_src;
-                (0..size2).map(move |pos_cache_rel| {
-                    // The absolute position of the cache elements after the addition.
-                    let pos_cache = self.current_seq_len + size1 + pos_cache_rel - upd_offset;
-                    let pos_cache = if pos_cache_rel < upd_offset {
-                        pos_cache
-                    } else {
-                        pos_cache - self.max_seq_len
-                    };
-                    u8::from(pos_cache > pos_src || pos_cache + context < pos_src)
-                })
-            })
-            .collect();
-        Tensor::from_slice(&mask, (size1, size2), device)
-    }
-
-    /// Returns the attn_mask to be applied *after* adding `seq_len` to the cache.
-    pub fn attn_mask(&self, seq_len: usize, device: &Device) -> Result<Option<Tensor>> {
-        let mask = if seq_len == 1 {
-            None
-        } else {
-            let mask = if seq_len < self.max_seq_len {
-                let cache_out_len = (self.current_seq_len + seq_len).min(self.max_seq_len);
-                self.get_mask_rel(seq_len, cache_out_len, device)?
-            } else {
-                self.get_mask_abs(seq_len, seq_len, device)?
-            };
-            Some(mask)
-        };
-        Ok(mask)
-    }
 }

 #[derive(Debug, Clone)]
@@ -358,10 +308,6 @@ impl RotatingKvCache {
         self.k.current_seq_len()
     }

-    pub fn attn_mask(&self, seq_len: usize, device: &Device) -> Result<Option<Tensor>> {
-        self.k.attn_mask(seq_len, device)
-    }
-
     pub fn reset(&mut self) {
         self.k.reset();
         self.v.reset();
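For readers following the mask logic, here is a dependency-free sketch of the rule implemented by get_mask_abs in the hunk above, using plain vectors instead of Tensor; the printed values are worked out by hand and only illustrative:

// Entry (i, j) is 1 (masked out) when key column j is in the future relative to
// query row i (queries and keys are right-aligned), or more than `context` steps
// in the past. This mirrors the predicate used by get_mask_abs.
fn mask_abs(size1: usize, size2: usize, context: usize) -> Vec<Vec<u8>> {
    (0..size1)
        .map(|i| {
            (0..size2)
                .map(|j| u8::from(size1 + j > size2 + i || size1 + j + context < size2 + i))
                .collect()
        })
        .collect()
}

fn main() {
    // 4 queries over 4 keys with a context of 2: a causal mask in which the last row
    // additionally drops position 0, which lies more than 2 steps behind it:
    // [0, 1, 1, 1]
    // [0, 0, 1, 1]
    // [0, 0, 0, 1]
    // [1, 0, 0, 0]
    for row in mask_abs(4, 4, 2) {
        println!("{row:?}");
    }
}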

View File

@@ -69,36 +69,13 @@ fn rotating_kv_cache() -> Result<()> {
     assert_eq!(cache.current_seq_len(), 13);
     assert_eq!(cache.offset(), 1);

-    let mask = cache.attn_mask(2, &Device::Cpu)?.unwrap();
-    assert_eq!(
-        mask.to_vec2::<u8>()?,
-        &[[0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 0, 0]]
-    );
-    let mask = cache.attn_mask(3, &Device::Cpu)?.unwrap();
-    assert_eq!(
-        mask.to_vec2::<u8>()?,
-        &[[0, 0, 1, 1, 0, 0], [0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 0]],
-    );
-
     let t = Tensor::new(&[0., 1., 2., 3., 4., 5., 6., 7., 8.], &Device::Cpu)?;
     let data = cache.append(&t)?;
     assert_eq!(data.to_vec1::<f64>()?, [0., 1., 2., 3., 4., 5., 6., 7., 8.]);
     assert_eq!(cache.current_seq_len(), 22);
     assert_eq!(cache.offset(), 0);

-    let mask = cache.attn_mask(1, &Device::Cpu)?;
-    assert!(mask.is_none());
-    let mask = cache.attn_mask(2, &Device::Cpu)?.unwrap();
-    assert_eq!(
-        mask.to_vec2::<u8>()?,
-        &[[0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]]
-    );
-    let mask = cache.attn_mask(3, &Device::Cpu)?.unwrap();
-    assert_eq!(
-        mask.to_vec2::<u8>()?,
-        &[[0, 1, 1, 0, 0, 0], [0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 0, 0]]
-    );
     let t = Tensor::new(&[42.], &Device::Cpu)?;
     let data = cache.append(&t)?;
     assert_eq!(data.to_vec1::<f64>()?, [42., 4., 5., 6., 7., 8.]);
     assert_eq!(cache.current_seq_len(), 23);

View File

@@ -1,6 +1,6 @@
 [package]
 name = "candle-onnx"
-version = "0.7.1"
+version = "0.7.0"
 edition = "2021"

 description = "ONNX support for Candle"
@@ -10,8 +10,8 @@ categories = ["science"]
 license = "MIT OR Apache-2.0"

 [dependencies]
-candle = { path = "../candle-core", package = "candle-core", version = "0.7.1" }
-candle-nn = { path = "../candle-nn", version = "0.7.1" }
+candle = { path = "../candle-core", package = "candle-core", version = "0.7.0" }
+candle-nn = { path = "../candle-nn", version = "0.7.0" }
 prost = "0.12.1"

 [build-dependencies]

View File

@@ -101,6 +101,21 @@ impl Module for LayerScale {
     }
 }

+pub(crate) fn get_mask(
+    size1: usize,
+    size2: usize,
+    context: usize,
+    device: &Device,
+) -> Result<Tensor> {
+    let mask: Vec<_> = (0..size1)
+        .flat_map(|i| {
+            (0..size2)
+                .map(move |j| u8::from(size1 + j > size2 + i || size1 + j + context < size2 + i))
+        })
+        .collect();
+    Tensor::from_slice(&mask, (size1, size2), device)
+}
+
 #[derive(Debug, Clone)]
 pub struct StreamingMultiheadAttention {
     q_proj: Linear,
@@ -575,6 +590,7 @@ impl StreamingTransformerLayer {
 #[derive(Debug, Clone)]
 pub struct StreamingTransformer {
     layers: Vec<StreamingTransformerLayer>,
+    context: usize,
     positional_embedding: PositionalEmbedding,
     max_period: usize,
 }
@@ -601,6 +617,7 @@ impl StreamingTransformer {
         }
         Ok(Self {
             layers,
+            context: cfg.context,
             positional_embedding: cfg.positional_embedding,
             max_period: cfg.max_period,
         })
@@ -612,11 +629,23 @@
     pub fn forward_ca(&mut self, xs: &Tensor, ca_src: Option<&Tensor>) -> Result<Tensor> {
         let (_b, t, c) = xs.dims3()?;
-        let pos = self.layers[0].self_attn.kv_cache.current_seq_len();
-        let mask = self.layers[0]
+        let pos = self.layers[0]
             .self_attn
             .kv_cache
-            .attn_mask(t, xs.device())?;
+            .k_cache()
+            .current_seq_len();
+        let mask = if t == 1 {
+            None
+        } else {
+            let cache_out_len = if t < self.context {
+                (pos + t).min(self.context)
+            } else {
+                t
+            };
+            // TODO: this is wrong, the mask depends on the kv-cache offset because of its rotating
+            // nature.
+            Some(get_mask(t, cache_out_len, self.context, xs.device())?)
+        };
         let mut xs = match self.positional_embedding {
             PositionalEmbedding::Rope | PositionalEmbedding::None => xs.clone(),
             PositionalEmbedding::Sin => {