Make it easier to use whisper samples from the repo. (#112)

* Make it easier to use samples from the repo.

* Use f32 for accumulation in the f16/bf16 kernels.
This commit is contained in:
Laurent Mazare
2023-07-08 18:48:27 +01:00
committed by GitHub
parent eb64ad0d4d
commit c187f347bf
2 changed files with 34 additions and 25 deletions

View File

@ -197,8 +197,8 @@ impl Decoder {
let (_, _, content_frames) = mel.shape().r3()?;
let mut seek = 0;
let mut segments = vec![];
let start = std::time::Instant::now();
while seek < content_frames {
let start = std::time::Instant::now();
let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
let segment_size = usize::min(content_frames - seek, N_FRAMES);
let mel_segment = mel.narrow(2, seek, segment_size)?;
@ -214,7 +214,7 @@ impl Decoder {
duration: segment_duration,
dr,
};
println!("{seek}: {segment:?} : Took {:?}", start.elapsed());
println!("{seek}: {segment:?}, in {:?}", start.elapsed());
segments.push(segment)
}
Ok(segments)
@ -236,8 +236,9 @@ struct Args {
#[arg(long)]
revision: Option<String>,
/// The input to be processed, in wav formats, will default to `jfk.wav`
/// https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav
/// The input to be processed, in wav format, will default to `jfk.wav`. Alternatively
/// this can be set to sample:jfk, sample:gb1, ... to fetch a sample from the following
/// repo: https://huggingface.co/datasets/Narsil/candle-examples/
#[arg(long)]
input: Option<String>,
@ -286,19 +287,27 @@ fn main() -> Result<()> {
} else {
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
let api = Api::new()?;
let sample = if let Some(input) = args.input {
if let Some(sample) = input.strip_prefix("sample:") {
api.get(
&Repo::new("Narsil/candle-examples".to_string(), RepoType::Dataset),
&format!("samples_{sample}.wav"),
)?
} else {
std::path::PathBuf::from(input)
}
} else {
println!("No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav");
api.get(
&Repo::new("Narsil/candle-examples".to_string(), RepoType::Dataset),
"samples_jfk.wav",
)?
};
(
api.get(&repo, "config.json")?,
api.get(&repo, "tokenizer.json")?,
api.get(&repo, "model.safetensors")?,
if let Some(input) = args.input {
std::path::PathBuf::from(input)
} else {
println!("No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav");
api.get(
&Repo::new("Narsil/candle-examples".to_string(), RepoType::Dataset),
"samples_jfk.wav",
)?
},
sample,
)
};
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;