Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 19:18:50 +00:00)

Compare commits: 0.7.1 ... cudarc-12- (11 commits)

SHA1:
42c702a023
d6f01f625d
3277844fd9
c79bf421c7
58c1e909d3
9964c6d86c
fc877920ce
6547c4bfc3
f9579f80be
1bddd44cb8
9cfe3c7141
@@ -43,7 +43,7 @@ candle-onnx = { path = "./candle-onnx", version = "0.7.0" }
 candle-transformers = { path = "./candle-transformers", version = "0.7.0" }
 clap = { version = "4.2.4", features = ["derive"] }
 criterion = { version = "0.5.1", default-features=false }
-cudarc = { version = "0.12.0", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
+cudarc = { version = "0.12.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
 fancy-regex = "0.13.0"
 gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
 hf-hub = "0.3.0"
@@ -39,6 +39,11 @@ struct Args {
     /// The model weight file, in safetensor format.
     #[arg(long)]
     model: Option<String>,
+
+    /// Whether to use streaming or not, when streaming slices of data of the given size are passed
+    /// to the encoder/decoder one at a time.
+    #[arg(long)]
+    streaming: Option<usize>,
 }
 
 fn main() -> Result<()> {
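The new flag is an optional chunk size. Below is a minimal sketch of how it is parsed, assuming clap 4 with the derive feature (as pinned in the workspace above); the trimmed-down Args struct and the binary name passed to parse_from are illustrative only, not the example's full argument set.

use clap::Parser;

/// Trimmed-down stand-in for the example's Args struct.
#[derive(Parser, Debug)]
struct Args {
    /// Chunk size used when streaming.
    #[arg(long)]
    streaming: Option<usize>,
}

fn main() {
    // `--streaming 1920` selects the chunked encode/decode path below.
    let args = Args::parse_from(["example", "--streaming", "1920"]);
    assert_eq!(args.streaming, Some(1920));
    // Omitting the flag keeps the original, non-streaming behaviour.
    let args = Args::parse_from(["example"]);
    assert_eq!(args.streaming, None);
}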
@@ -87,20 +92,49 @@ fn main() -> Result<()> {
                     pcm
                 }
             };
-            let pcm_len = pcm.len();
-            let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
-            println!("input pcm shape: {:?}", pcm.shape());
-            model.encode(&pcm)?
+            match args.streaming {
+                Some(chunk_size) => {
+                    let mut code_chunks = vec![];
+                    for pcm in pcm.chunks(chunk_size) {
+                        let pcm = Tensor::new(pcm, &device)?.reshape((1, 1, ()))?;
+                        let code_chunk = model.encode(&pcm)?;
+                        code_chunks.push(code_chunk)
+                    }
+                    Tensor::cat(&code_chunks, candle::D::Minus1)?
+                }
+                None => {
+                    let pcm_len = pcm.len();
+                    let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
+                    println!("input pcm shape: {:?}", pcm.shape());
+                    model.encode(&pcm)?
+                }
+            }
         }
     };
     println!("codes shape: {:?}", codes.shape());
+    model.reset_state();
 
     match args.action {
         Action::AudioToCode => {
             codes.save_safetensors("codes", &args.out_file)?;
         }
         Action::AudioToAudio | Action::CodeToAudio => {
-            let pcm = model.decode(&codes)?;
+            let pcm = match args.streaming {
+                Some(chunk_size) => {
+                    let seq_len = codes.dim(candle::D::Minus1)?;
+                    let mut pcm_chunks = vec![];
+                    for chunk_start in (0..seq_len).step_by(chunk_size) {
+                        let chunk_len = usize::min(chunk_size, seq_len - chunk_start);
+                        let codes = codes.narrow(candle::D::Minus1, chunk_start, chunk_len)?;
+                        let pcm = model.decode_step(&codes.into())?;
+                        if let Some(pcm) = pcm.as_option() {
+                            pcm_chunks.push(pcm.clone())
+                        }
+                    }
+                    Tensor::cat(&pcm_chunks, candle::D::Minus1)?
+                }
+                None => model.decode(&codes)?,
+            };
             println!("output pcm shape: {:?}", pcm.shape());
             let pcm = pcm.i(0)?.i(0)?;
             let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
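The streaming decode path walks the code sequence with step_by(chunk_size) and narrows out a possibly shorter final chunk. A self-contained sketch of just that index arithmetic, with made-up lengths:

// Mirrors the (chunk_start, chunk_len) pairs passed to `narrow` in the loop above.
fn chunk_ranges(seq_len: usize, chunk_size: usize) -> Vec<(usize, usize)> {
    (0..seq_len)
        .step_by(chunk_size)
        .map(|start| (start, usize::min(chunk_size, seq_len - start)))
        .collect()
}

fn main() {
    // A 10-step code sequence decoded 4 steps at a time: the last chunk is shorter.
    assert_eq!(chunk_ranges(10, 4), vec![(0, 4), (4, 4), (8, 2)]);
}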
@@ -145,3 +145,171 @@ impl KvCache {
         self.v.reset();
     }
 }
+
+#[derive(Debug, Clone)]
+pub struct RotatingCache {
+    all_data: Option<Tensor>,
+    dim: usize,
+    // `offset` is the current write index in the buffer
+    offset: usize,
+    // The total size of the sequence seen so far.
+    current_seq_len: usize,
+    // max_seq_len is the size of the rotating buffer, it is actually allowed for the full
+    // sequence to grow past this limit.
+    max_seq_len: usize,
+}
+
+impl RotatingCache {
+    pub fn new(dim: usize, max_seq_len: usize) -> Self {
+        Self {
+            all_data: None,
+            dim,
+            offset: 0,
+            current_seq_len: 0,
+            max_seq_len,
+        }
+    }
+
+    pub fn offset(&self) -> usize {
+        self.offset
+    }
+
+    pub fn dim(&self) -> usize {
+        self.dim
+    }
+
+    pub fn current_seq_len(&self) -> usize {
+        self.current_seq_len
+    }
+
+    pub fn max_seq_len(&self) -> usize {
+        self.max_seq_len
+    }
+
+    pub fn all_data(&self) -> &Option<Tensor> {
+        &self.all_data
+    }
+
+    pub fn current_data(&self) -> Result<Option<Tensor>> {
+        let data = match self.all_data.as_ref() {
+            None => None,
+            Some(d) => {
+                if self.current_seq_len >= self.max_seq_len {
+                    Some(d.clone())
+                } else {
+                    Some(d.narrow(self.dim, 0, self.current_seq_len)?)
+                }
+            }
+        };
+        Ok(data)
+    }
+
+    pub fn reset(&mut self) {
+        self.offset = 0;
+        self.current_seq_len = 0;
+        self.all_data = None;
+    }
+
+    pub fn append(&mut self, src: &Tensor) -> Result<Tensor> {
+        let seq_len = src.dim(self.dim)?;
+        // This doesn't seem very idiomatic but because the creation can fail, it's tricky to use
+        // self.all_data.get_or_insert_with.
+        if self.all_data.is_none() {
+            let mut shape = src.dims().to_vec();
+            shape[self.dim] = self.max_seq_len;
+            let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
+            self.all_data = Some(ad)
+        };
+        let ad = self.all_data.as_mut().unwrap();
+
+        self.current_seq_len += seq_len;
+        if seq_len >= self.max_seq_len {
+            let to_copy = src
+                .narrow(self.dim, seq_len - self.max_seq_len, self.max_seq_len)?
+                .contiguous()?;
+            ad.slice_set(&to_copy, self.dim, 0)?;
+            self.offset = 0;
+            // Here we return `src` rather than `ad` so that all the past can be used.
+            Ok(src.clone())
+        } else {
+            let rem_len = self.max_seq_len - self.offset;
+            if seq_len <= rem_len {
+                ad.slice_set(&src.contiguous()?, self.dim, self.offset)?;
+                self.offset = (self.offset + seq_len) % self.max_seq_len;
+            } else {
+                // We have to make two copies here as we go over the boundary of the cache.
+                if rem_len > 0 {
+                    let src1 = src.narrow(self.dim, 0, rem_len)?.contiguous()?;
+                    ad.slice_set(&src1, self.dim, self.offset)?;
+                }
+                let src2 = src
+                    .narrow(self.dim, rem_len, seq_len - rem_len)?
+                    .contiguous()?;
+                ad.slice_set(&src2, self.dim, 0)?;
+                self.offset = seq_len - rem_len;
+            }
+            if self.current_seq_len >= self.max_seq_len {
+                Ok(ad.clone())
+            } else {
+                Ok(ad.narrow(self.dim, 0, self.current_seq_len)?)
+            }
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct RotatingKvCache {
+    k: RotatingCache,
+    v: RotatingCache,
+}
+
+impl RotatingKvCache {
+    pub fn new(dim: usize, max_seq_len: usize) -> Self {
+        let k = RotatingCache::new(dim, max_seq_len);
+        let v = RotatingCache::new(dim, max_seq_len);
+        Self { k, v }
+    }
+
+    pub fn k_cache(&self) -> &RotatingCache {
+        &self.k
+    }
+
+    pub fn v_cache(&self) -> &RotatingCache {
+        &self.v
+    }
+
+    pub fn k_cache_mut(&mut self) -> &mut RotatingCache {
+        &mut self.k
+    }
+
+    pub fn v_cache_mut(&mut self) -> &mut RotatingCache {
+        &mut self.v
+    }
+
+    pub fn k(&self) -> Result<Option<Tensor>> {
+        self.k.current_data()
+    }
+
+    pub fn v(&self) -> Result<Option<Tensor>> {
+        self.v.current_data()
+    }
+
+    pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
+        let out_k = self.k.append(k)?;
+        let out_v = self.v.append(v)?;
+        Ok((out_k, out_v))
+    }
+
+    pub fn offset(&self) -> usize {
+        self.k.offset()
+    }
+
+    pub fn current_seq_len(&self) -> usize {
+        self.k.current_seq_len()
+    }
+
+    pub fn reset(&mut self) {
+        self.k.reset();
+        self.v.reset();
+    }
+}
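A minimal usage sketch of the new RotatingKvCache, assuming the candle and candle-nn crates from this branch; the (batch, heads, seq, head_dim) shape and the chosen limits are made up for illustration:

use candle::{DType, Device, Result, Tensor};
use candle_nn::kv_cache::RotatingKvCache;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Keys/values are cached along dimension 2, keeping at most 4 time steps.
    let mut cache = RotatingKvCache::new(2, 4);
    for step in 0..6 {
        let k = Tensor::zeros((1, 8, 1, 64), DType::F32, &dev)?;
        let v = Tensor::zeros((1, 8, 1, 64), DType::F32, &dev)?;
        let (k_all, v_all) = cache.append(&k, &v)?;
        // The returned tensors grow until the buffer is full, then stay at 4 steps.
        println!("step {step}: k {:?} v {:?}", k_all.dims(), v_all.dims());
    }
    Ok(())
}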
candle-nn/tests/kv_cache.rs (new file, 87 lines)
@@ -0,0 +1,87 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use candle::{Device, Result, Tensor};
+
+#[test]
+fn kv_cache() -> Result<()> {
+    let mut cache = candle_nn::kv_cache::Cache::new(0, 16);
+    for _ in [0, 1] {
+        assert_eq!(cache.current_seq_len(), 0);
+        let data = cache.current_data()?;
+        assert!(data.is_none());
+        let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
+        cache.append(&t)?;
+        let data = cache.current_data()?.unwrap();
+        assert_eq!(data.to_vec1::<f32>()?, [1., 2., 3.]);
+        let t = Tensor::new(&[4f32], &Device::Cpu)?;
+        cache.append(&t)?;
+        let data = cache.current_data()?.unwrap();
+        assert_eq!(data.to_vec1::<f32>()?, [1., 2., 3., 4.]);
+        let t = Tensor::new(&[0f32, 5., 6., 7.], &Device::Cpu)?;
+        cache.append(&t)?;
+        let data = cache.current_data()?.unwrap();
+        assert_eq!(data.to_vec1::<f32>()?, [1., 2., 3., 4., 0., 5., 6., 7.]);
+        assert_eq!(cache.current_seq_len(), 8);
+        cache.reset();
+    }
+    Ok(())
+}
+
+#[test]
+fn rotating_kv_cache() -> Result<()> {
+    let mut cache = candle_nn::kv_cache::RotatingCache::new(0, 6);
+    for _ in [0, 1] {
+        assert_eq!(cache.offset(), 0);
+        assert_eq!(cache.current_seq_len(), 0);
+        let data = cache.current_data()?;
+        assert!(data.is_none());
+        let t = Tensor::new(&[1., 2., 3.], &Device::Cpu)?;
+        let data = cache.append(&t)?;
+        assert_eq!(data.to_vec1::<f64>()?, [1., 2., 3.]);
+        let t = Tensor::new(&[4.], &Device::Cpu)?;
+        let data = cache.append(&t)?;
+        assert_eq!(data.to_vec1::<f64>()?, [1., 2., 3., 4.]);
+        let t = Tensor::new(&[0., 5., 6., 7.], &Device::Cpu)?;
+        let data = cache.append(&t)?;
+        assert_eq!(data.to_vec1::<f64>()?, [6., 7., 3., 4., 0., 5.]);
+        assert_eq!(cache.current_seq_len(), 8);
+        assert_eq!(cache.offset(), 2);
+
+        let t = Tensor::new(&[8.], &Device::Cpu)?;
+        let data = cache.append(&t)?;
+        assert_eq!(data.to_vec1::<f64>()?, [6., 7., 8., 4., 0., 5.]);
+        assert_eq!(cache.current_seq_len(), 9);
+        assert_eq!(cache.offset(), 3);
+
+        let t = Tensor::new(&[9., 10., 11.], &Device::Cpu)?;
+        let data = cache.append(&t)?;
+        assert_eq!(data.to_vec1::<f64>()?, [6., 7., 8., 9., 10., 11.]);
+        assert_eq!(cache.current_seq_len(), 12);
+        assert_eq!(cache.offset(), 0);
+
+        let t = Tensor::new(&[12.], &Device::Cpu)?;
+        let data = cache.append(&t)?;
+        assert_eq!(data.to_vec1::<f64>()?, [12., 7., 8., 9., 10., 11.]);
+        assert_eq!(cache.current_seq_len(), 13);
+        assert_eq!(cache.offset(), 1);
+
+        let t = Tensor::new(&[0., 1., 2., 3., 4., 5., 6., 7., 8.], &Device::Cpu)?;
+        let data = cache.append(&t)?;
+        assert_eq!(data.to_vec1::<f64>()?, [0., 1., 2., 3., 4., 5., 6., 7., 8.]);
+        assert_eq!(cache.current_seq_len(), 22);
+        assert_eq!(cache.offset(), 0);
+
+        let t = Tensor::new(&[42.], &Device::Cpu)?;
+        let data = cache.append(&t)?;
+        assert_eq!(data.to_vec1::<f64>()?, [42., 4., 5., 6., 7., 8.]);
+        assert_eq!(cache.current_seq_len(), 23);
+        assert_eq!(cache.offset(), 1);
+
+        cache.reset();
+    }
+    Ok(())
+}
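The [6., 7., 3., 4., 0., 5.] assertion above is the wrap-around case: with max_seq_len = 6 the buffer already holds [1, 2, 3, 4] at offset 4, and appending [0, 5, 6, 7] is split into two copies. A plain-Vec sketch of that bookkeeping, with the values taken from the test:

fn main() {
    let max_seq_len = 6;
    // Underlying buffer after appending [1, 2, 3] then [4]: the two trailing slots are still zero.
    let mut buf = vec![1., 2., 3., 4., 0., 0.];
    let mut offset = 4;
    let src = [0., 5., 6., 7.];
    let rem_len = max_seq_len - offset; // 2 free slots before the end of the buffer
    buf[offset..].copy_from_slice(&src[..rem_len]); // first copy: [0, 5] into positions 4..6
    buf[..src.len() - rem_len].copy_from_slice(&src[rem_len..]); // second copy: [6, 7] into positions 0..2
    offset = src.len() - rem_len; // the write index wraps around to 2
    assert_eq!(buf, [6., 7., 3., 4., 0., 5.]);
    assert_eq!(offset, 2);
}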
@@ -127,7 +127,7 @@ pub struct StreamingMultiheadAttention {
     context: usize,
     neg_inf: Tensor,
     rope: Option<Arc<RotaryEmbedding>>,
-    kv_cache: candle_nn::kv_cache::KvCache,
+    kv_cache: candle_nn::kv_cache::RotatingKvCache,
     pos: usize,
     use_flash_attn: bool,
     span: tracing::Span,
@@ -153,7 +153,7 @@ impl StreamingMultiheadAttention {
             num_heads: cfg.num_heads,
             context: cfg.context,
             neg_inf,
-            kv_cache: candle_nn::kv_cache::KvCache::new(2, cfg.max_seq_len),
+            kv_cache: candle_nn::kv_cache::RotatingKvCache::new(2, cfg.context),
             pos: 0,
             use_flash_attn: false,
             span: tracing::span!(tracing::Level::TRACE, "mha"),
@@ -236,7 +236,7 @@ impl StreamingMultiheadAttention {
         self.kv_cache.reset()
     }
 
-    pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::KvCache) {
+    pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::RotatingKvCache) {
         self.kv_cache = kv_cache
     }
 }
@@ -582,7 +582,7 @@ impl StreamingTransformerLayer {
         self.self_attn.reset_kv_cache()
     }
 
-    pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::KvCache) {
+    pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::RotatingKvCache) {
         self.self_attn.set_kv_cache(kv_cache)
     }
 }
@@ -629,18 +629,22 @@ impl StreamingTransformer {
 
     pub fn forward_ca(&mut self, xs: &Tensor, ca_src: Option<&Tensor>) -> Result<Tensor> {
         let (_b, t, c) = xs.dims3()?;
-        // We will extract at most "context" from the kv_cache.
-        // Note that the mask will discard the values that are before context.
         let pos = self.layers[0]
             .self_attn
             .kv_cache
             .k_cache()
-            .current_seq_len()
-            .min(self.context);
+            .current_seq_len();
         let mask = if t == 1 {
             None
         } else {
-            Some(get_mask(t, pos + t, self.context, xs.device())?)
+            let cache_out_len = if t < self.context {
+                (pos + t).min(self.context)
+            } else {
+                t
+            };
+            // TODO: this is wrong, the mask depends on the kv-cache offset because of its rotating
+            // nature.
+            Some(get_mask(t, cache_out_len, self.context, xs.device())?)
         };
         let mut xs = match self.positional_embedding {
             PositionalEmbedding::Rope | PositionalEmbedding::None => xs.clone(),
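A small self-contained check of the new mask-length computation; the concrete pos, t and context values below are made up, with context standing in for self.context:

// Same expression as in forward_ca above, pulled out as a free function.
fn cache_out_len(pos: usize, t: usize, context: usize) -> usize {
    if t < context {
        (pos + t).min(context)
    } else {
        t
    }
}

fn main() {
    // Steady-state streaming: 2 new tokens with a full cache of 250 -> mask width stays at the context.
    assert_eq!(cache_out_len(250, 2, 250), 250);
    // Early in the stream: 2 new tokens after 10 cached ones -> mask covers 12 positions.
    assert_eq!(cache_out_len(10, 2, 250), 12);
    // A prompt longer than the context window -> mask covers only the t freshly appended tokens.
    assert_eq!(cache_out_len(0, 400, 250), 400);
}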