diff --git a/candle-examples/examples/mimi/main.rs b/candle-examples/examples/mimi/main.rs index cfc1a553..788b3fd9 100644 --- a/candle-examples/examples/mimi/main.rs +++ b/candle-examples/examples/mimi/main.rs @@ -39,6 +39,11 @@ struct Args { /// The model weight file, in safetensor format. #[arg(long)] model: Option, + + /// Whether to use streaming or not, when streaming slices of data of the given size are passed + /// to the encoder/decoder one at a time. + #[arg(long)] + streaming: Option, } fn main() -> Result<()> { @@ -87,20 +92,46 @@ fn main() -> Result<()> { pcm } }; - let pcm_len = pcm.len(); - let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?; - println!("input pcm shape: {:?}", pcm.shape()); - model.encode(&pcm)? + match args.streaming { + Some(chunk_size) => { + let mut code_chunks = vec![]; + for pcm in pcm.chunks(chunk_size) { + let pcm = Tensor::new(pcm, &device)?.reshape((1, 1, ()))?; + let code_chunk = model.encode(&pcm)?; + code_chunks.push(code_chunk) + } + Tensor::cat(&code_chunks, candle::D::Minus1)? + } + None => { + let pcm_len = pcm.len(); + let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?; + println!("input pcm shape: {:?}", pcm.shape()); + model.encode(&pcm)? + } + } } }; println!("codes shape: {:?}", codes.shape()); + model.reset_state(); match args.action { Action::AudioToCode => { codes.save_safetensors("codes", &args.out_file)?; } Action::AudioToAudio | Action::CodeToAudio => { - let pcm = model.decode(&codes)?; + let pcm = match args.streaming { + Some(chunk_size) => { + let seq_len = codes.dim(candle::D::Minus1)?; + let mut pcm_chunks = vec![]; + for chunk_start in (0..seq_len).step_by(chunk_size) { + let chunk_len = usize::min(chunk_size, seq_len - chunk_start); + let codes = codes.narrow(candle::D::Minus1, chunk_start, chunk_len)?; + pcm_chunks.push(model.decode(&codes)?) + } + Tensor::cat(&pcm_chunks, candle::D::Minus1)? + } + None => model.decode(&codes)?, + }; println!("output pcm shape: {:?}", pcm.shape()); let pcm = pcm.i(0)?.i(0)?; let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;