Adding auto download of audio file.

This commit is contained in:
Nicolas Patry
2023-07-05 15:20:33 +00:00
parent e85573a4bd
commit 653c5049f8

View File

@ -172,9 +172,9 @@ struct Args {
#[arg(long)]
revision: Option<String>,
/// The input to be processed, in wav formats.
#[arg(long, default_value = "jfk.wav")]
input: String,
/// The input to be processed, in wav formats, will default to `jfk.wav` https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav
#[arg(long)]
input: Option<String>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
@ -208,14 +208,19 @@ async fn main() -> Result<()> {
(None, None) => (default_model, default_revision),
};
let (config_filename, tokenizer_filename, weights_filename) = if path.exists() {
let (config_filename, tokenizer_filename, weights_filename, input) = if path.exists() {
let mut config_filename = path.clone();
config_filename.push("config.json");
let mut tokenizer_filename = path.clone();
tokenizer_filename.push("tokenizer.json");
let mut model_filename = path.clone();
model_filename.push("model.safetensors");
(config_filename, tokenizer_filename, model_filename)
(
config_filename,
tokenizer_filename,
model_filename,
std::path::PathBuf::from(args.input.expect("You didn't specify a file to read from yet, are using a local model, please add `--input example.wav` to read some audio file")),
)
} else {
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
let api = Api::new()?;
@ -223,9 +228,18 @@ async fn main() -> Result<()> {
api.get(&repo, "config.json").await?,
api.get(&repo, "tokenizer.json").await?,
api.get(&repo, "model.safetensors").await?,
if let Some(input) = args.input {
std::path::PathBuf::from(input)
} else {
println!("No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav");
api.get(
&Repo::new("Narsil/candle-examples".to_string(), RepoType::Dataset),
"samples_jfk.wav",
)
.await?
},
)
};
println!("Weights {weights_filename:?}");
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let mel_filters = unsafe { candle::safetensors::MmapedFile::new(args.filters)? };
@ -234,7 +248,7 @@ async fn main() -> Result<()> {
println!("loaded mel filters {:?}", mel_filters.shape());
let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
let mut input = std::fs::File::open(args.input)?;
let mut input = std::fs::File::open(input)?;
let (header, data) = wav::read(&mut input)?;
println!("loaded wav data: {header:?}");
if header.sampling_rate != SAMPLE_RATE as u32 {
@ -265,6 +279,7 @@ async fn main() -> Result<()> {
let (_, _, content_frames) = mel.shape().r3()?;
let mut seek = 0;
let mut segments = vec![];
let start = std::time::Instant::now();
while seek < content_frames {
let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
let segment_size = usize::min(content_frames - seek, N_FRAMES);
@ -281,7 +296,7 @@ async fn main() -> Result<()> {
duration: segment_duration,
dr,
};
println!("{seek}: {segment:?}");
println!("{seek}: {segment:?} : Took {:?}", start.elapsed());
segments.push(segment)
}
Ok(())