Encodec encoding demo. (#1775)

This commit is contained in:
Laurent Mazare
2024-02-28 06:49:03 +01:00
committed by GitHub
parent 15e8644149
commit d0aca6c3c6
2 changed files with 18 additions and 2 deletions

View File

@ -29,6 +29,10 @@ struct Args {
/// Output file that will be generated in wav format. /// Output file that will be generated in wav format.
#[arg(long)] #[arg(long)]
out: String, out: String,
/// Do another step of encoding the PCM data and decoding the resulting codes.
#[arg(long)]
roundtrip: bool,
} }
fn main() -> Result<()> { fn main() -> Result<()> {
@ -48,8 +52,19 @@ fn main() -> Result<()> {
let codes = codes.get("codes").expect("no codes in input file").i(0)?; let codes = codes.get("codes").expect("no codes in input file").i(0)?;
println!("codes shape: {:?}", codes.shape()); println!("codes shape: {:?}", codes.shape());
let pcm = model.decode(&codes)?; let pcm = model.decode(&codes)?;
let pcm = pcm.i(0)?.i(0)?.to_vec1::<f32>()?; println!("pcm shape: {:?}", pcm.shape());
let pcm = if args.roundtrip {
let codes = model.encode(&pcm)?;
println!("second step codes shape: {:?}", pcm.shape());
let pcm = model.decode(&codes)?;
println!("second step pcm shape: {:?}", pcm.shape());
pcm
} else {
pcm
};
let pcm = pcm.i(0)?.i(0)?.to_vec1::<f32>()?;
let mut output = std::fs::File::create(&args.out)?; let mut output = std::fs::File::create(&args.out)?;
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?; candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;

View File

@ -283,7 +283,8 @@ impl VectorQuantization {
} }
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> { pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
self.codebook.encode_slow(xs) let xs = xs.transpose(1, 2)?;
self.codebook.encode_slow(&xs)
} }
pub fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> { pub fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {