mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Upgrading hf-hub to 0.2.0
(Modified API to not pass the Repo around
all the time)
This commit is contained in:
@@ -22,7 +22,7 @@ clap = { version = "4.2.4", features = ["derive"] }
 cudarc = { git = "https://github.com/LaurentMazare/cudarc.git", branch = "cublas-bf16", features = ["f16"] }
 # TODO: Switch back to the official gemm implementation if we manage to upstream the changes.
 gemm = { git = "https://github.com/LaurentMazare/gemm.git" }
-hf-hub = "0.1.3"
+hf-hub = "0.2.0"
 half = { version = "2.3.1", features = ["num-traits", "rand_distr"] }
 intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
 libc = { version = "0.2.147" }
@@ -69,10 +69,11 @@ impl Args {
             )
         } else {
             let api = Api::new()?;
+            let api = api.repo(repo);
             (
-                api.get(&repo, "config.json")?,
-                api.get(&repo, "tokenizer.json")?,
-                api.get(&repo, "model.safetensors")?,
+                api.get("config.json")?,
+                api.get("tokenizer.json")?,
+                api.get("model.safetensors")?,
             )
         };
         let config = std::fs::read_to_string(config_filename)?;
@@ -123,14 +123,18 @@ fn main() -> Result<()> {
     let device = candle_examples::device(args.cpu)?;
     let start = std::time::Instant::now();
     let api = Api::new()?;
-    let repo = Repo::with_revision(args.model_id, RepoType::Model, args.revision);
-    let tokenizer_filename = api.get(&repo, "tokenizer.json")?;
+    let repo = api.repo(Repo::with_revision(
+        args.model_id,
+        RepoType::Model,
+        args.revision,
+    ));
+    let tokenizer_filename = repo.get("tokenizer.json")?;
     let mut filenames = vec![];
     for rfilename in [
         "model-00001-of-00002.safetensors",
         "model-00002-of-00002.safetensors",
     ] {
-        let filename = api.get(&repo, rfilename)?;
+        let filename = repo.get(rfilename)?;
         filenames.push(filename);
     }
     println!("retrieved the files in {:?}", start.elapsed());
@@ -18,7 +18,7 @@ use clap::Parser;
 use candle::{DType, Tensor};
 use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
-use hf_hub::{api::sync::Api, Repo, RepoType};
+use hf_hub::api::sync::Api;
 
 mod model;
 use model::{Config, Llama};
@@ -146,14 +146,14 @@ fn main() -> Result<()> {
         }
     });
     println!("loading the model weights from {model_id}");
-    let repo = Repo::new(model_id, RepoType::Model);
-    let tokenizer_filename = api.get(&repo, "tokenizer.json")?;
+    let api = api.model(model_id);
+    let tokenizer_filename = api.get("tokenizer.json")?;
     let mut filenames = vec![];
     for rfilename in [
         "model-00001-of-00002.safetensors",
         "model-00002-of-00002.safetensors",
     ] {
-        let filename = api.get(&repo, rfilename)?;
+        let filename = api.get(rfilename)?;
         filenames.push(filename);
     }
 
@@ -282,28 +282,23 @@ fn main() -> Result<()> {
             std::path::PathBuf::from(args.input.expect("You didn't specify a file to read from yet, are using a local model, please add `--input example.wav` to read some audio file")),
         )
     } else {
-        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
         let api = Api::new()?;
+        let dataset = api.dataset("Narsil/candle-examples".to_string());
+        let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
         let sample = if let Some(input) = args.input {
             if let Some(sample) = input.strip_prefix("sample:") {
-                api.get(
-                    &Repo::new("Narsil/candle-examples".to_string(), RepoType::Dataset),
-                    &format!("samples_{sample}.wav"),
-                )?
+                dataset.get(&format!("samples_{sample}.wav"))?
             } else {
                 std::path::PathBuf::from(input)
             }
         } else {
             println!("No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav");
-            api.get(
-                &Repo::new("Narsil/candle-examples".to_string(), RepoType::Dataset),
-                "samples_jfk.wav",
-            )?
+            dataset.get("samples_jfk.wav")?
         };
         (
-            api.get(&repo, "config.json")?,
-            api.get(&repo, "tokenizer.json")?,
-            api.get(&repo, "model.safetensors")?,
+            repo.get("config.json")?,
+            repo.get("tokenizer.json")?,
+            repo.get("model.safetensors")?,
             sample,
         )
     };
Reference in New Issue
Block a user