Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 02:38:10 +00:00)
Merge pull request #78 from LaurentMazare/whisper_update
Adapting whisper for Hub use: the example now resolves config.json, tokenizer.json and model.safetensors from the Hugging Face Hub via candle_hub instead of requiring local --weights and --tokenizer-config paths, and falls back to a sample wav file when no --input is given.
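For orientation, here is a minimal sketch of the download flow introduced by the diff below, condensed from the new code. The helper name fetch_model_files is illustrative only, error handling and the local-path / default-audio fallbacks from the actual change are omitted, and it assumes candle_hub's Api::get returns a local PathBuf, as the surrounding code in the diff implies:

    use anyhow::Result;
    use candle_hub::{api::Api, Repo, RepoType};
    use std::path::PathBuf;

    // Illustrative helper condensed from the diff below: download the config,
    // tokenizer and weights for a given model id and revision from the Hub.
    async fn fetch_model_files(
        model_id: String,
        revision: String,
    ) -> Result<(PathBuf, PathBuf, PathBuf)> {
        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
        let api = Api::new()?;
        let config = api.get(&repo, "config.json").await?;
        let tokenizer = api.get(&repo, "tokenizer.json").await?;
        let weights = api.get(&repo, "model.safetensors").await?;
        Ok((config, tokenizer, weights))
    }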
@@ -7,6 +7,7 @@
 
 use anyhow::{Error as E, Result};
 use candle::{DType, Device, Tensor};
+use candle_hub::{api::Api, Repo, RepoType};
 use clap::Parser;
 use rand::{distributions::Distribution, SeedableRng};
 use tokenizers::Tokenizer;
@@ -68,7 +69,7 @@ impl Decode {
         let model = &self.model;
         let audio_features = model.encoder.forward(mel)?;
         println!("audio features: {:?}", audio_features.dims());
-        let sample_len = model.config.n_text_ctx / 2;
+        let sample_len = model.config.max_target_positions / 2;
         let mut sum_logprob = 0f64;
         let mut no_speech_prob = f64::NAN;
         let mut tokens = vec![SOT_TOKEN];
@@ -112,7 +113,7 @@ impl Decode {
                 .softmax(candle::D::Minus1)?
                 .get(next_token as usize)?
                 .to_scalar::<f32>()? as f64;
-            if next_token == EOT_TOKEN || tokens.len() > model.config.n_text_ctx {
+            if next_token == EOT_TOKEN || tokens.len() > model.config.max_target_positions {
                 break;
             }
             sum_logprob += prob.ln();
@@ -165,14 +166,15 @@ struct Args {
     cpu: bool,
 
     #[arg(long)]
-    weights: String,
+    model_id: Option<String>,
 
-    /// The input to be processed, in wav formats.
+    /// The model to use, check out available models: https://huggingface.co/models?search=whisper
     #[arg(long)]
-    input: String,
+    revision: Option<String>,
 
+    /// The input to be processed, in wav formats, will default to `jfk.wav` https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav
     #[arg(long)]
-    tokenizer_config: String,
+    input: Option<String>,
 
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
@@ -186,7 +188,8 @@ struct Args {
     filters: String,
 }
 
-fn main() -> Result<()> {
+#[tokio::main]
+async fn main() -> Result<()> {
     let args = Args::parse();
     let device = if args.cpu {
         Device::Cpu
@@ -195,7 +198,49 @@ fn main() -> Result<()> {
     };
     let rng = rand::rngs::StdRng::seed_from_u64(args.seed);
 
-    let tokenizer = Tokenizer::from_file(args.tokenizer_config).map_err(E::msg)?;
+    let default_model = "openai/whisper-tiny.en".to_string();
+    let path = std::path::PathBuf::from(default_model.clone());
+    let default_revision = "refs/pr/15".to_string();
+    let (model_id, revision) = match (args.model_id, args.revision) {
+        (Some(model_id), Some(revision)) => (model_id, revision),
+        (Some(model_id), None) => (model_id, "main".to_string()),
+        (None, Some(revision)) => (default_model, revision),
+        (None, None) => (default_model, default_revision),
+    };
+
+    let (config_filename, tokenizer_filename, weights_filename, input) = if path.exists() {
+        let mut config_filename = path.clone();
+        config_filename.push("config.json");
+        let mut tokenizer_filename = path.clone();
+        tokenizer_filename.push("tokenizer.json");
+        let mut model_filename = path.clone();
+        model_filename.push("model.safetensors");
+        (
+            config_filename,
+            tokenizer_filename,
+            model_filename,
+            std::path::PathBuf::from(args.input.expect("You didn't specify a file to read from yet, are using a local model, please add `--input example.wav` to read some audio file")),
+        )
+    } else {
+        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
+        let api = Api::new()?;
+        (
+            api.get(&repo, "config.json").await?,
+            api.get(&repo, "tokenizer.json").await?,
+            api.get(&repo, "model.safetensors").await?,
+            if let Some(input) = args.input {
+                std::path::PathBuf::from(input)
+            } else {
+                println!("No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav");
+                api.get(
+                    &Repo::new("Narsil/candle-examples".to_string(), RepoType::Dataset),
+                    "samples_jfk.wav",
+                )
+                .await?
+            },
+        )
+    };
+    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
 
     let mel_filters = unsafe { candle::safetensors::MmapedFile::new(args.filters)? };
     let mel_filters = mel_filters.deserialize()?;
@@ -203,7 +248,7 @@ fn main() -> Result<()> {
     println!("loaded mel filters {:?}", mel_filters.shape());
     let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
 
-    let mut input = std::fs::File::open(args.input)?;
+    let mut input = std::fs::File::open(input)?;
     let (header, data) = wav::read(&mut input)?;
     println!("loaded wav data: {header:?}");
     if header.sampling_rate != SAMPLE_RATE as u32 {
@@ -220,10 +265,11 @@ fn main() -> Result<()> {
     let mel = Tensor::from_vec(mel, (1, N_MELS, mel_len / N_MELS), &device)?;
     println!("loaded mel: {:?}", mel.dims());
 
-    let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };
+    let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
    let weights = weights.deserialize()?;
    let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device);
-    let model = Whisper::load(&vb, Config::tiny_en())?;
+    let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
+    let model = Whisper::load(&vb, config)?;
     let mut dc = Decode {
         model,
         rng,
@@ -233,6 +279,7 @@ fn main() -> Result<()> {
     let (_, _, content_frames) = mel.shape().r3()?;
     let mut seek = 0;
     let mut segments = vec![];
+    let start = std::time::Instant::now();
     while seek < content_frames {
         let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
         let segment_size = usize::min(content_frames - seek, N_FRAMES);
@@ -249,7 +296,7 @@ fn main() -> Result<()> {
             duration: segment_duration,
             dr,
         };
-        println!("{seek}: {segment:?}");
+        println!("{seek}: {segment:?} : Took {:?}", start.elapsed());
         segments.push(segment)
     }
     Ok(())