mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Encodec encoding demo. (#1775)
This commit is contained in:
@ -29,6 +29,10 @@ struct Args {
|
|||||||
/// Output file that will be generated in wav format.
|
/// Output file that will be generated in wav format.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
out: String,
|
out: String,
|
||||||
|
|
||||||
|
/// Do another step of encoding the PCM data and and decoding the resulting codes.
|
||||||
|
#[arg(long)]
|
||||||
|
roundtrip: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
@ -48,8 +52,19 @@ fn main() -> Result<()> {
|
|||||||
let codes = codes.get("codes").expect("no codes in input file").i(0)?;
|
let codes = codes.get("codes").expect("no codes in input file").i(0)?;
|
||||||
println!("codes shape: {:?}", codes.shape());
|
println!("codes shape: {:?}", codes.shape());
|
||||||
let pcm = model.decode(&codes)?;
|
let pcm = model.decode(&codes)?;
|
||||||
let pcm = pcm.i(0)?.i(0)?.to_vec1::<f32>()?;
|
println!("pcm shape: {:?}", pcm.shape());
|
||||||
|
|
||||||
|
let pcm = if args.roundtrip {
|
||||||
|
let codes = model.encode(&pcm)?;
|
||||||
|
println!("second step codes shape: {:?}", pcm.shape());
|
||||||
|
let pcm = model.decode(&codes)?;
|
||||||
|
println!("second step pcm shape: {:?}", pcm.shape());
|
||||||
|
pcm
|
||||||
|
} else {
|
||||||
|
pcm
|
||||||
|
};
|
||||||
|
|
||||||
|
let pcm = pcm.i(0)?.i(0)?.to_vec1::<f32>()?;
|
||||||
let mut output = std::fs::File::create(&args.out)?;
|
let mut output = std::fs::File::create(&args.out)?;
|
||||||
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
|
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
|
||||||
|
|
||||||
|
@ -283,7 +283,8 @@ impl VectorQuantization {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
|
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
self.codebook.encode_slow(xs)
|
let xs = xs.transpose(1, 2)?;
|
||||||
|
self.codebook.encode_slow(&xs)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
|
pub fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
|
||||||
|
Reference in New Issue
Block a user