Read wav files.

This commit is contained in:
laurent
2023-07-05 11:53:58 +01:00
parent 26d1a7803f
commit 95f378ebb4
3 changed files with 61 additions and 41 deletions

View File

@@ -40,33 +40,6 @@ const EOT_TOKEN: u32 = 50256;
const NO_SPEECH_TOKEN: u32 = 50361;
const NO_TIMESTAMP_TOKEN: u32 = 50362;
// Command-line arguments, parsed through the `Parser` derive.
// NOTE(review): the previously undocumented fields are annotated with plain `//`
// comments rather than `///` so that the generated help output is left untouched.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
// Path to the model weights; `main` memory-maps it as a safetensors file.
#[arg(long)]
weights: String,
// Path to the input file that `main` opens and processes.
#[arg(long)]
input: String,
// Path handed to `Tokenizer::from_file` in `main`.
#[arg(long)]
tokenizer_config: String,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
// Path to the safetensors file holding the mel filters (`main` reads the
// "mel_80" tensor from it); defaults to the copy under the whisper example.
#[arg(
long,
default_value = "candle-examples/examples/whisper/mel_filters.safetensors"
)]
filters: String,
}
#[derive(Debug, Clone)]
struct DecodingResult {
tokens: Vec<u32>,
@@ -184,6 +157,35 @@ impl Decode {
}
}
// Command-line arguments, parsed through the `Parser` derive.
//
// Each `///` comment on a field doubles as that flag's `--help` text, so every
// argument carries one. A plain `//` comment is used on the struct itself so the
// `about` text configured in `#[command(...)]` is not affected.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// The model weights in safetensors format.
    #[arg(long)]
    weights: String,

    /// The input to be processed, in wav format.
    #[arg(long)]
    input: String,

    /// The tokenizer configuration file.
    #[arg(long)]
    tokenizer_config: String,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// The mel filters in safetensors format.
    #[arg(
        long,
        default_value = "candle-examples/examples/whisper/mel_filters.safetensors"
    )]
    filters: String,
}
fn main() -> Result<()> {
let args = Args::parse();
let device = if args.cpu {
@@ -195,14 +197,28 @@ fn main() -> Result<()> {
let tokenizer = Tokenizer::from_file(args.tokenizer_config).map_err(E::msg)?;
let filters = unsafe { candle::safetensors::MmapedFile::new(args.filters)? };
let filters = filters.deserialize()?;
let filters = filters.tensor("mel_80", &device)?;
println!("loaded mel filters {:?}", filters.shape());
let mel_filters = unsafe { candle::safetensors::MmapedFile::new(args.filters)? };
let mel_filters = mel_filters.deserialize()?;
let mel_filters = mel_filters.tensor("mel_80", &device)?;
println!("loaded mel filters {:?}", mel_filters.shape());
let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
let input = unsafe { candle::safetensors::MmapedFile::new(args.input)? };
let input = input.deserialize()?;
let mel = input.tensor("mel", &device)?;
let mut input = std::fs::File::open(args.input)?;
let (header, data) = wav::read(&mut input)?;
println!("loaded wav data: {header:?}");
if header.sampling_rate != audio::WHISPER_SAMPLE_RATE as u32 {
anyhow::bail!(
"wav file must have a {} sampling rate",
audio::WHISPER_SAMPLE_RATE
)
}
let data = data.as_sixteen().expect("expected 16 bit wav file");
let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
.iter()
.map(|v| *v as f32 / 32768.)
.collect();
let mel = audio::pcm_to_mel(&pcm_data, &mel_filters)?;
let mel = Tensor::new(&mel[..], &device)?;
println!("loaded mel: {:?}", mel.dims());
let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };