diff --git a/Cargo.toml b/Cargo.toml index 0dec835b..05c6240b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,7 +22,7 @@ clap = { version = "4.2.4", features = ["derive"] } cudarc = { git = "https://github.com/LaurentMazare/cudarc.git", branch = "cublas-bf16", features = ["f16"] } # TODO: Switch back to the official gemm implementation if we manage to upstream the changes. gemm = { git = "https://github.com/LaurentMazare/gemm.git" } -hf-hub = "0.1.3" +hf-hub = "0.2.0" half = { version = "2.3.1", features = ["num-traits", "rand_distr"] } intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] } libc = { version = "0.2.147" } diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs index 6672ad09..79c78968 100644 --- a/candle-examples/examples/bert/main.rs +++ b/candle-examples/examples/bert/main.rs @@ -69,10 +69,11 @@ impl Args { ) } else { let api = Api::new()?; + let api = api.repo(repo); ( - api.get(&repo, "config.json")?, - api.get(&repo, "tokenizer.json")?, - api.get(&repo, "model.safetensors")?, + api.get("config.json")?, + api.get("tokenizer.json")?, + api.get("model.safetensors")?, ) }; let config = std::fs::read_to_string(config_filename)?; diff --git a/candle-examples/examples/falcon/main.rs b/candle-examples/examples/falcon/main.rs index 3a284c86..a01191a5 100644 --- a/candle-examples/examples/falcon/main.rs +++ b/candle-examples/examples/falcon/main.rs @@ -123,14 +123,18 @@ fn main() -> Result<()> { let device = candle_examples::device(args.cpu)?; let start = std::time::Instant::now(); let api = Api::new()?; - let repo = Repo::with_revision(args.model_id, RepoType::Model, args.revision); - let tokenizer_filename = api.get(&repo, "tokenizer.json")?; + let repo = api.repo(Repo::with_revision( + args.model_id, + RepoType::Model, + args.revision, + )); + let tokenizer_filename = repo.get("tokenizer.json")?; let mut filenames = vec![]; for rfilename in [ "model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors", ] { - let filename = api.get(&repo, rfilename)?; + let filename = repo.get(rfilename)?; filenames.push(filename); } println!("retrieved the files in {:?}", start.elapsed()); diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs index 582ac3f8..d9d1e21a 100644 --- a/candle-examples/examples/llama/main.rs +++ b/candle-examples/examples/llama/main.rs @@ -18,7 +18,7 @@ use clap::Parser; use candle::{DType, Tensor}; use candle_nn::VarBuilder; use candle_transformers::generation::LogitsProcessor; -use hf_hub::{api::sync::Api, Repo, RepoType}; +use hf_hub::api::sync::Api; mod model; use model::{Config, Llama}; @@ -146,14 +146,14 @@ fn main() -> Result<()> { } }); println!("loading the model weights from {model_id}"); - let repo = Repo::new(model_id, RepoType::Model); - let tokenizer_filename = api.get(&repo, "tokenizer.json")?; + let api = api.model(model_id); + let tokenizer_filename = api.get("tokenizer.json")?; let mut filenames = vec![]; for rfilename in [ "model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors", ] { - let filename = api.get(&repo, rfilename)?; + let filename = api.get(rfilename)?; filenames.push(filename); } diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index 079424e3..c03779e7 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -282,28 +282,23 @@ fn main() -> Result<()> { std::path::PathBuf::from(args.input.expect("You didn't specify a file to read from yet, are using a local model, please add `--input example.wav` to read some audio file")), ) } else { - let repo = Repo::with_revision(model_id, RepoType::Model, revision); let api = Api::new()?; + let dataset = api.dataset("Narsil/candle-examples".to_string()); + let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision)); let sample = if let Some(input) = args.input { if let Some(sample) = input.strip_prefix("sample:") { - api.get( - &Repo::new("Narsil/candle-examples".to_string(), RepoType::Dataset), - &format!("samples_{sample}.wav"), - )? + dataset.get(&format!("samples_{sample}.wav"))? } else { std::path::PathBuf::from(input) } } else { println!("No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav"); - api.get( - &Repo::new("Narsil/candle-examples".to_string(), RepoType::Dataset), - "samples_jfk.wav", - )? + dataset.get("samples_jfk.wav")? }; ( - api.get(&repo, "config.json")?, - api.get(&repo, "tokenizer.json")?, - api.get(&repo, "model.safetensors")?, + repo.get("config.json")?, + repo.get("tokenizer.json")?, + repo.get("model.safetensors")?, sample, ) };