diff --git a/candle-examples/examples/encodec/main.rs b/candle-examples/examples/encodec/main.rs index 47a9ba59..42f2b3f9 100644 --- a/candle-examples/examples/encodec/main.rs +++ b/candle-examples/examples/encodec/main.rs @@ -29,6 +29,10 @@ struct Args { /// Output file that will be generated in wav format. #[arg(long)] out: String, + + /// Do another step of encoding the PCM data and and decoding the resulting codes. + #[arg(long)] + roundtrip: bool, } fn main() -> Result<()> { @@ -48,8 +52,19 @@ fn main() -> Result<()> { let codes = codes.get("codes").expect("no codes in input file").i(0)?; println!("codes shape: {:?}", codes.shape()); let pcm = model.decode(&codes)?; - let pcm = pcm.i(0)?.i(0)?.to_vec1::()?; + println!("pcm shape: {:?}", pcm.shape()); + let pcm = if args.roundtrip { + let codes = model.encode(&pcm)?; + println!("second step codes shape: {:?}", pcm.shape()); + let pcm = model.decode(&codes)?; + println!("second step pcm shape: {:?}", pcm.shape()); + pcm + } else { + pcm + }; + + let pcm = pcm.i(0)?.i(0)?.to_vec1::()?; let mut output = std::fs::File::create(&args.out)?; candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?; diff --git a/candle-transformers/src/models/encodec.rs b/candle-transformers/src/models/encodec.rs index d3b26e1e..1316098b 100644 --- a/candle-transformers/src/models/encodec.rs +++ b/candle-transformers/src/models/encodec.rs @@ -283,7 +283,8 @@ impl VectorQuantization { } pub fn encode(&self, xs: &Tensor) -> Result { - self.codebook.encode_slow(xs) + let xs = xs.transpose(1, 2)?; + self.codebook.encode_slow(&xs) } pub fn decode(&self, embed_ind: &Tensor) -> Result {