mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add a RotatingKVCache. (#2493)
* Add a RotatingKVCache.
* Add some KvCache tests.
* Test the reset too.
* More kv-cache testing.
* More tests for the rotating kv-cache.
* Improve the api for the rotating cache so that the whole src tensor gets returned when it's overlarge.
* Handle contiguity + bugfix + use in mimi.
* Add a way to test the mimi streaming mode.
* Mimi streaming fixes.
* More rotating kv-cache.
* Fix the attn mask generation.
* Handle the abs case.
* Add some tests for the generated mask.
This commit is contained in:
@ -39,6 +39,11 @@ struct Args {
|
||||
/// The model weight file, in safetensor format.
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
/// Whether to use streaming or not, when streaming slices of data of the given size are passed
|
||||
/// to the encoder/decoder one at a time.
|
||||
#[arg(long)]
|
||||
streaming: Option<usize>,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
@ -87,20 +92,49 @@ fn main() -> Result<()> {
|
||||
pcm
|
||||
}
|
||||
};
|
||||
let pcm_len = pcm.len();
|
||||
let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
|
||||
println!("input pcm shape: {:?}", pcm.shape());
|
||||
model.encode(&pcm)?
|
||||
match args.streaming {
|
||||
Some(chunk_size) => {
|
||||
let mut code_chunks = vec![];
|
||||
for pcm in pcm.chunks(chunk_size) {
|
||||
let pcm = Tensor::new(pcm, &device)?.reshape((1, 1, ()))?;
|
||||
let code_chunk = model.encode(&pcm)?;
|
||||
code_chunks.push(code_chunk)
|
||||
}
|
||||
Tensor::cat(&code_chunks, candle::D::Minus1)?
|
||||
}
|
||||
None => {
|
||||
let pcm_len = pcm.len();
|
||||
let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
|
||||
println!("input pcm shape: {:?}", pcm.shape());
|
||||
model.encode(&pcm)?
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
println!("codes shape: {:?}", codes.shape());
|
||||
model.reset_state();
|
||||
|
||||
match args.action {
|
||||
Action::AudioToCode => {
|
||||
codes.save_safetensors("codes", &args.out_file)?;
|
||||
}
|
||||
Action::AudioToAudio | Action::CodeToAudio => {
|
||||
let pcm = model.decode(&codes)?;
|
||||
let pcm = match args.streaming {
|
||||
Some(chunk_size) => {
|
||||
let seq_len = codes.dim(candle::D::Minus1)?;
|
||||
let mut pcm_chunks = vec![];
|
||||
for chunk_start in (0..seq_len).step_by(chunk_size) {
|
||||
let chunk_len = usize::min(chunk_size, seq_len - chunk_start);
|
||||
let codes = codes.narrow(candle::D::Minus1, chunk_start, chunk_len)?;
|
||||
let pcm = model.decode_step(&codes.into())?;
|
||||
if let Some(pcm) = pcm.as_option() {
|
||||
pcm_chunks.push(pcm.clone())
|
||||
}
|
||||
}
|
||||
Tensor::cat(&pcm_chunks, candle::D::Minus1)?
|
||||
}
|
||||
None => model.decode(&codes)?,
|
||||
};
|
||||
println!("output pcm shape: {:?}", pcm.shape());
|
||||
let pcm = pcm.i(0)?.i(0)?;
|
||||
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
|
||||
|
Reference in New Issue
Block a user