mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Add a way to test the mimi streaming mode.
This commit is contained in:
@ -39,6 +39,11 @@ struct Args {
|
|||||||
/// The model weight file, in safetensor format.
|
/// The model weight file, in safetensor format.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
model: Option<String>,
|
model: Option<String>,
|
||||||
|
|
||||||
|
/// Whether to use streaming or not, when streaming slices of data of the given size are passed
|
||||||
|
/// to the encoder/decoder one at a time.
|
||||||
|
#[arg(long)]
|
||||||
|
streaming: Option<usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
@ -87,20 +92,46 @@ fn main() -> Result<()> {
|
|||||||
pcm
|
pcm
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
match args.streaming {
|
||||||
|
Some(chunk_size) => {
|
||||||
|
let mut code_chunks = vec![];
|
||||||
|
for pcm in pcm.chunks(chunk_size) {
|
||||||
|
let pcm = Tensor::new(pcm, &device)?.reshape((1, 1, ()))?;
|
||||||
|
let code_chunk = model.encode(&pcm)?;
|
||||||
|
code_chunks.push(code_chunk)
|
||||||
|
}
|
||||||
|
Tensor::cat(&code_chunks, candle::D::Minus1)?
|
||||||
|
}
|
||||||
|
None => {
|
||||||
let pcm_len = pcm.len();
|
let pcm_len = pcm.len();
|
||||||
let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
|
let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
|
||||||
println!("input pcm shape: {:?}", pcm.shape());
|
println!("input pcm shape: {:?}", pcm.shape());
|
||||||
model.encode(&pcm)?
|
model.encode(&pcm)?
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
println!("codes shape: {:?}", codes.shape());
|
println!("codes shape: {:?}", codes.shape());
|
||||||
|
model.reset_state();
|
||||||
|
|
||||||
match args.action {
|
match args.action {
|
||||||
Action::AudioToCode => {
|
Action::AudioToCode => {
|
||||||
codes.save_safetensors("codes", &args.out_file)?;
|
codes.save_safetensors("codes", &args.out_file)?;
|
||||||
}
|
}
|
||||||
Action::AudioToAudio | Action::CodeToAudio => {
|
Action::AudioToAudio | Action::CodeToAudio => {
|
||||||
let pcm = model.decode(&codes)?;
|
let pcm = match args.streaming {
|
||||||
|
Some(chunk_size) => {
|
||||||
|
let seq_len = codes.dim(candle::D::Minus1)?;
|
||||||
|
let mut pcm_chunks = vec![];
|
||||||
|
for chunk_start in (0..seq_len).step_by(chunk_size) {
|
||||||
|
let chunk_len = usize::min(chunk_size, seq_len - chunk_start);
|
||||||
|
let codes = codes.narrow(candle::D::Minus1, chunk_start, chunk_len)?;
|
||||||
|
pcm_chunks.push(model.decode(&codes)?)
|
||||||
|
}
|
||||||
|
Tensor::cat(&pcm_chunks, candle::D::Minus1)?
|
||||||
|
}
|
||||||
|
None => model.decode(&codes)?,
|
||||||
|
};
|
||||||
println!("output pcm shape: {:?}", pcm.shape());
|
println!("output pcm shape: {:?}", pcm.shape());
|
||||||
let pcm = pcm.i(0)?.i(0)?;
|
let pcm = pcm.i(0)?.i(0)?;
|
||||||
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
|
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
|
||||||
|
Reference in New Issue
Block a user