mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Adding auto download of audio file.
This commit is contained in:
@ -172,9 +172,9 @@ struct Args {
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
/// The input to be processed, in wav formats.
|
||||
#[arg(long, default_value = "jfk.wav")]
|
||||
input: String,
|
||||
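/// The input to be processed, in wav format. When omitted and the model is fetched from the hub, `samples_jfk.wav` is downloaded from the `Narsil/candle-examples` dataset (https://huggingface.co/datasets/Narsil/candle-examples); when using a local model, `--input` must be given.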
|
||||
#[arg(long)]
|
||||
input: Option<String>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
@ -208,14 +208,19 @@ async fn main() -> Result<()> {
|
||||
(None, None) => (default_model, default_revision),
|
||||
};
|
||||
|
||||
let (config_filename, tokenizer_filename, weights_filename) = if path.exists() {
|
||||
let (config_filename, tokenizer_filename, weights_filename, input) = if path.exists() {
|
||||
let mut config_filename = path.clone();
|
||||
config_filename.push("config.json");
|
||||
let mut tokenizer_filename = path.clone();
|
||||
tokenizer_filename.push("tokenizer.json");
|
||||
let mut model_filename = path.clone();
|
||||
model_filename.push("model.safetensors");
|
||||
(config_filename, tokenizer_filename, model_filename)
|
||||
(
|
||||
config_filename,
|
||||
tokenizer_filename,
|
||||
model_filename,
|
||||
std::path::PathBuf::from(args.input.expect("You didn't specify a file to read from yet, are using a local model, please add `--input example.wav` to read some audio file")),
|
||||
)
|
||||
} else {
|
||||
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
|
||||
let api = Api::new()?;
|
||||
@ -223,9 +228,18 @@ async fn main() -> Result<()> {
|
||||
api.get(&repo, "config.json").await?,
|
||||
api.get(&repo, "tokenizer.json").await?,
|
||||
api.get(&repo, "model.safetensors").await?,
|
||||
if let Some(input) = args.input {
|
||||
std::path::PathBuf::from(input)
|
||||
} else {
|
||||
println!("No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav");
|
||||
api.get(
|
||||
&Repo::new("Narsil/candle-examples".to_string(), RepoType::Dataset),
|
||||
"samples_jfk.wav",
|
||||
)
|
||||
.await?
|
||||
},
|
||||
)
|
||||
};
|
||||
println!("Weights {weights_filename:?}");
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let mel_filters = unsafe { candle::safetensors::MmapedFile::new(args.filters)? };
|
||||
@ -234,7 +248,7 @@ async fn main() -> Result<()> {
|
||||
println!("loaded mel filters {:?}", mel_filters.shape());
|
||||
let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
|
||||
|
||||
let mut input = std::fs::File::open(args.input)?;
|
||||
let mut input = std::fs::File::open(input)?;
|
||||
let (header, data) = wav::read(&mut input)?;
|
||||
println!("loaded wav data: {header:?}");
|
||||
if header.sampling_rate != SAMPLE_RATE as u32 {
|
||||
@ -265,6 +279,7 @@ async fn main() -> Result<()> {
|
||||
let (_, _, content_frames) = mel.shape().r3()?;
|
||||
let mut seek = 0;
|
||||
let mut segments = vec![];
|
||||
let start = std::time::Instant::now();
|
||||
while seek < content_frames {
|
||||
let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
|
||||
let segment_size = usize::min(content_frames - seek, N_FRAMES);
|
||||
@ -281,7 +296,7 @@ async fn main() -> Result<()> {
|
||||
duration: segment_duration,
|
||||
dr,
|
||||
};
|
||||
println!("{seek}: {segment:?}");
|
||||
println!("{seek}: {segment:?} : Took {:?}", start.elapsed());
|
||||
segments.push(segment)
|
||||
}
|
||||
Ok(())
|
||||
|
Reference in New Issue
Block a user