From 439321745a0110223bd20d3b5cacb98f97aad6ce Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 19 Jul 2023 15:04:38 +0200
Subject: [PATCH] Removing `candle-hub` internal to extract into `hf-hub` standalone.

---
 Cargo.toml                               |  12 +-
 candle-examples/Cargo.toml               |   2 +-
 candle-examples/examples/bert/main.rs    |   2 +-
 candle-examples/examples/falcon/main.rs  |   2 +-
 candle-examples/examples/llama/main.rs   |   2 +-
 candle-examples/examples/whisper/main.rs |   2 +-
 candle-hub/.gitignore                    |   2 -
 candle-hub/Cargo.toml                    |  29 -
 candle-hub/src/api/mod.rs                |   6 -
 candle-hub/src/api/sync.rs               | 686 ---------------------
 candle-hub/src/api/tokio.rs              | 723 -----------------------
 candle-hub/src/lib.rs                    | 197 ------
 candle-transformers/Cargo.toml           |   2 +-
 13 files changed, 9 insertions(+), 1658 deletions(-)
 delete mode 100644 candle-hub/.gitignore
 delete mode 100644 candle-hub/Cargo.toml
 delete mode 100644 candle-hub/src/api/mod.rs
 delete mode 100644 candle-hub/src/api/sync.rs
 delete mode 100644 candle-hub/src/api/tokio.rs
 delete mode 100644 candle-hub/src/lib.rs

diff --git a/Cargo.toml b/Cargo.toml
index 9c8b5682..6f435ba8 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -2,7 +2,6 @@
 members = [
     "candle-core",
     "candle-examples",
-    "candle-hub",
     "candle-nn",
     "candle-pyo3",
     "candle-transformers",
@@ -19,29 +18,24 @@ clap = { version = "4.2.4", features = ["derive"] }
 # Re-enable this once 0.9.13 has been released as it would include the cublas-f16 changes
 # cudarc = { version = "0.9.13", optional = true, features = ["f16"] }
 cudarc = { git = "https://github.com/LaurentMazare/cudarc.git", branch = "cublas-bf16", features = ["f16"] }
-futures = "0.3.28"
 # TODO: Switch back to the official gemm implementation once the following are available.
 # https://github.com/sarah-ek/gemm/pull/8.
 # https://github.com/sarah-ek/gemm/pull/9.
 gemm = { git = "https://github.com/LaurentMazare/gemm.git", branch = "f16-vec-plus-wasm-simd" }
+hf-hub = "0.1.0"
 half = { version = "2.3.1", features = ["num-traits"] }
-indicatif = "0.17.5"
-intel-mkl-src = { version = "0.8.1", features = ["mkl-dynamic-lp64-iomp"] }
+intel-mkl-src = { version = "0.8.1", features = ["mkl-static-ilp64-iomp"] }
 libc = { version = "0.2.147" }
 log = "0.4"
 memmap2 = "0.7.1"
 num_cpus = "1.15.0"
 num-traits = "0.2.15"
 rand = "0.8.5"
-reqwest = "0.11.18"
 safetensors = "0.3.1"
-serde = { version = "1.0.166", features = ["derive"] }
+serde = { version = "1.0.171", features = ["derive"] }
 serde_json = "1.0.99"
-sha256 = "=1.1.4"
 thiserror = "1"
 tokenizers = { version = "0.13.3", default-features = false }
-tokio = "1.28.2"
-tokio-test = "0.4.2"
 tracing = "0.1.37"
 tracing-chrome = "0.7.1"
 tracing-subscriber = "0.3.7"
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index 114997b9..24435e81 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -21,7 +21,7 @@ intel-mkl-src = { workspace = true, optional = true }
 
 [dev-dependencies]
 anyhow = { workspace = true }
-candle-hub = { path = "../candle-hub" }
+hf-hub = { workspace = true}
 clap = { workspace = true }
 rand = { workspace = true }
 tokenizers = { workspace = true, features = ["onig"] }
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index 8ef8b5ce..33f0a1fe 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -4,9 +4,9 @@ mod model;
 
 use anyhow::{anyhow, Error as E, Result};
 use candle::Tensor;
-use candle_hub::{api::sync::Api, Cache, Repo, RepoType};
 use candle_nn::VarBuilder;
 use clap::Parser;
+use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
 use model::{BertModel, Config, DTYPE};
 use tokenizers::{PaddingParams, Tokenizer};
diff --git a/candle-examples/examples/falcon/main.rs b/candle-examples/examples/falcon/main.rs
index 7d5eaa52..3a284c86 100644
--- a/candle-examples/examples/falcon/main.rs
+++ b/candle-examples/examples/falcon/main.rs
@@ -5,10 +5,10 @@ extern crate intel_mkl_src;
 
 use anyhow::{Error as E, Result};
 use candle::{DType, Device, Tensor};
-use candle_hub::{api::sync::Api, Repo, RepoType};
 use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
 use clap::Parser;
+use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::Tokenizer;
 
 mod model;
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index aa02299d..40f1af06 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -16,9 +16,9 @@ use anyhow::{Error as E, Result};
 use clap::Parser;
 
 use candle::{DType, Device, Tensor, D};
-use candle_hub::{api::sync::Api, Repo, RepoType};
 use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
+use hf_hub::{api::sync::Api, Repo, RepoType};
 
 mod model;
 use model::{Config, Llama};
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index d01fb605..bfd7f05c 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -11,9 +11,9 @@ extern crate intel_mkl_src;
 
 use anyhow::{Error as E, Result};
 use candle::{DType, Device, Tensor};
-use candle_hub::{api::sync::Api, Repo, RepoType};
 use candle_nn::VarBuilder;
 use clap::Parser;
+use hf_hub::{api::sync::Api, Repo, RepoType};
 use rand::{distributions::Distribution, SeedableRng};
 use tokenizers::Tokenizer;
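
The example diffs above only swap the import path; the call sites keep the same `Api`/`Repo` surface. A minimal sketch of that pattern, assuming the extracted `hf-hub` 0.1.0 crate preserves the `candle-hub` API it was split from (the repo id and filenames here are placeholders):

```rust
use hf_hub::{api::sync::Api, Repo, RepoType};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Build a client with the default cache location.
    let api = Api::new()?;
    let repo = Repo::new("bert-base-uncased".to_string(), RepoType::Model);
    // `get` returns the cached path if present, otherwise downloads first.
    let tokenizer = api.get(&repo, "tokenizer.json")?;
    let weights = api.get(&repo, "model.safetensors")?;
    println!("tokenizer: {}", tokenizer.display());
    println!("weights: {}", weights.display());
    Ok(())
}
```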
diff --git a/candle-hub/.gitignore b/candle-hub/.gitignore
deleted file mode 100644
index 4fffb2f8..00000000
--- a/candle-hub/.gitignore
+++ /dev/null
@@ -1,2 +0,0 @@
-/target
-/Cargo.lock
diff --git a/candle-hub/Cargo.toml b/candle-hub/Cargo.toml
deleted file mode 100644
index 2b091642..00000000
--- a/candle-hub/Cargo.toml
+++ /dev/null
@@ -1,29 +0,0 @@
-[package]
-name = "candle-hub"
-version = "0.1.0"
-edition = "2021"
-
-# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
-
-[dependencies]
-dirs = "5.0.1"
-rand = { workspace = true }
-thiserror = { workspace = true }
-futures = { workspace = true, optional = true }
-reqwest = { workspace = true, optional = true, features = ["json"] }
-tokio = { workspace = true, features = ["fs"], optional = true }
-serde = { workspace = true, optional = true }
-serde_json = { workspace = true, optional = true }
-indicatif = { workspace = true, optional = true }
-num_cpus = { workspace = true, optional = true }
-
-[dev-dependencies]
-rand = { workspace = true }
-sha256 = { workspace = true }
-tokio = { workspace = true, features = ["macros"] }
-tokio-test = { workspace = true }
-
-[features]
-default = ["online"]
-online = ["reqwest/blocking", "dep:serde", "dep:serde_json", "dep:indicatif", "dep:num_cpus"]
-tokio = ["online", "dep:tokio", "dep:futures"]
diff --git a/candle-hub/src/api/mod.rs b/candle-hub/src/api/mod.rs
deleted file mode 100644
index 779dc4f9..00000000
--- a/candle-hub/src/api/mod.rs
+++ /dev/null
@@ -1,6 +0,0 @@
-/// The asynchronous version of the API
-#[cfg(feature = "tokio")]
-pub mod tokio;
-
-/// The synchronous version of the API
-pub mod sync;
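
The deleted `api/mod.rs` and the `[features]` table above show how the crate split its two clients: the blocking client ships under the default `online` feature, while the async one sits behind an opt-in `tokio` feature. A sketch of consuming both behind the same kind of gate, assuming the extracted `hf-hub` keeps this feature layout (the helper names are illustrative):

```rust
// Compiled only when the consumer enables the `tokio` feature,
// mirroring the `#[cfg(feature = "tokio")]` gate in the deleted mod.rs.
#[cfg(feature = "tokio")]
pub async fn fetch_config_async() -> Result<std::path::PathBuf, Box<dyn std::error::Error>> {
    use hf_hub::{api::tokio::Api, Repo};
    let api = Api::new()?;
    Ok(api.get(&Repo::model("gpt2".to_string()), "config.json").await?)
}

// Always available: the sync client is part of the default `online` feature.
pub fn fetch_config_blocking() -> Result<std::path::PathBuf, Box<dyn std::error::Error>> {
    use hf_hub::{api::sync::Api, Repo};
    let api = Api::new()?;
    Ok(api.get(&Repo::model("gpt2".to_string()), "config.json")?)
}
```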
- #[error("Header {0} is invalid")] - InvalidHeader(HeaderName), - - /// The value cannot be used as a header during request header construction - #[error("Invalid header value {0}")] - InvalidHeaderValue(#[from] InvalidHeaderValue), - - /// The header value is not valid utf-8 - #[error("header value is not a string")] - ToStr(#[from] ToStrError), - - /// Error in the request - #[error("request error: {0}")] - RequestError(#[from] ReqwestError), - - /// Error parsing some range value - #[error("Cannot parse int")] - ParseIntError(#[from] ParseIntError), - - /// I/O Error - #[error("I/O error {0}")] - IoError(#[from] std::io::Error), - - /// We tried to download chunk too many times - #[error("Too many retries: {0}")] - TooManyRetries(Box), -} - -/// Siblings are simplified file descriptions of remote files on the hub -#[derive(Debug, Clone, Deserialize, PartialEq)] -pub struct Siblings { - /// The path within the repo. - pub rfilename: String, -} - -/// The description of the repo given by the hub -#[derive(Debug, Clone, Deserialize, PartialEq)] -pub struct ModelInfo { - /// See [`Siblings`] - pub siblings: Vec, -} - -/// Helper to create [`Api`] with all the options. -pub struct ApiBuilder { - endpoint: String, - cache: Cache, - url_template: String, - token: Option, - chunk_size: usize, - parallel_failures: usize, - max_retries: usize, - progress: bool, -} - -impl Default for ApiBuilder { - fn default() -> Self { - Self::new() - } -} - -impl ApiBuilder { - /// Default api builder - /// ``` - /// use candle_hub::api::sync::ApiBuilder; - /// let api = ApiBuilder::new().build().unwrap(); - /// ``` - pub fn new() -> Self { - let cache = Cache::default(); - let mut token_filename = cache.path().clone(); - token_filename.push(".token"); - let token = match std::fs::read_to_string(token_filename) { - Ok(token_content) => { - let token_content = token_content.trim(); - if !token_content.is_empty() { - Some(token_content.to_string()) - } else { - None - } - } - Err(_) => None, - }; - - let progress = true; - - Self { - endpoint: "https://huggingface.co".to_string(), - url_template: "{endpoint}/{repo_id}/resolve/{revision}/{filename}".to_string(), - cache, - token, - chunk_size: 10_000_000, - parallel_failures: 0, - max_retries: 0, - progress, - } - } - - /// Wether to show a progressbar - pub fn with_progress(mut self, progress: bool) -> Self { - self.progress = progress; - self - } - - /// Changes the location of the cache directory. Defaults is `~/.cache/huggingface/`. 
-    pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self {
-        self.cache = Cache::new(cache_dir);
-        self
-    }
-
-    fn build_headers(&self) -> Result<HeaderMap, ApiError> {
-        let mut headers = HeaderMap::new();
-        let user_agent = format!("unkown/None; {NAME}/{VERSION}; rust/unknown");
-        headers.insert(USER_AGENT, HeaderValue::from_str(&user_agent)?);
-        if let Some(token) = &self.token {
-            headers.insert(
-                AUTHORIZATION,
-                HeaderValue::from_str(&format!("Bearer {token}"))?,
-            );
-        }
-        Ok(headers)
-    }
-
-    /// Consumes the builder and builds the final [`Api`]
-    pub fn build(self) -> Result<Api, ApiError> {
-        let headers = self.build_headers()?;
-        let client = Client::builder().default_headers(headers.clone()).build()?;
-        let no_redirect_client = Client::builder()
-            .redirect(Policy::none())
-            .default_headers(headers)
-            .build()?;
-        Ok(Api {
-            endpoint: self.endpoint,
-            url_template: self.url_template,
-            cache: self.cache,
-            client,
-
-            no_redirect_client,
-            chunk_size: self.chunk_size,
-            parallel_failures: self.parallel_failures,
-            max_retries: self.max_retries,
-            progress: self.progress,
-        })
-    }
-}
-
-#[derive(Debug)]
-struct Metadata {
-    commit_hash: String,
-    etag: String,
-    size: usize,
-}
-
-/// The actual Api used to interact with the hub.
-/// You can inspect repos with [`Api::info`]
-/// or download files with [`Api::download`]
-pub struct Api {
-    endpoint: String,
-    url_template: String,
-    cache: Cache,
-    client: Client,
-    no_redirect_client: Client,
-    chunk_size: usize,
-    parallel_failures: usize,
-    max_retries: usize,
-    progress: bool,
-}
-
-fn temp_filename() -> PathBuf {
-    let s: String = rand::thread_rng()
-        .sample_iter(&Alphanumeric)
-        .take(7)
-        .map(char::from)
-        .collect();
-    let mut path = std::env::temp_dir();
-    path.push(s);
-    path
-}
-
-fn make_relative(src: &Path, dst: &Path) -> PathBuf {
-    let path = src;
-    let base = dst;
-
-    if path.is_absolute() != base.is_absolute() {
-        panic!("This function is made to look at absolute paths only");
-    }
-    let mut ita = path.components();
-    let mut itb = base.components();
-
-    loop {
-        match (ita.next(), itb.next()) {
-            (Some(a), Some(b)) if a == b => (),
-            (some_a, _) => {
-                // Ignoring b, because 1 component is the filename
-                // for which we don't need to go back up for relative
-                // filename to work.
-                let mut new_path = PathBuf::new();
-                for _ in itb {
-                    new_path.push(Component::ParentDir);
-                }
-                if let Some(a) = some_a {
-                    new_path.push(a);
-                    for comp in ita {
-                        new_path.push(comp);
-                    }
-                }
-                return new_path;
-            }
-        }
-    }
-}
-
-fn symlink_or_rename(src: &Path, dst: &Path) -> Result<(), std::io::Error> {
-    if dst.exists() {
-        return Ok(());
-    }
-
-    let src = make_relative(src, dst);
-    #[cfg(target_os = "windows")]
-    std::os::windows::fs::symlink_file(src, dst)?;
-
-    #[cfg(target_family = "unix")]
-    std::os::unix::fs::symlink(src, dst)?;
-
-    #[cfg(not(any(target_family = "unix", target_os = "windows")))]
-    std::fs::rename(src, dst)?;
-
-    Ok(())
-}
-
-fn jitter() -> usize {
-    thread_rng().gen_range(0..=500)
-}
-
-fn exponential_backoff(base_wait_time: usize, n: usize, max: usize) -> usize {
-    (base_wait_time + n.pow(2) + jitter()).min(max)
-}
-
-impl Api {
-    /// Creates a default Api, for Api options see [`ApiBuilder`]
-    pub fn new() -> Result<Self, ApiError> {
-        ApiBuilder::new().build()
-    }
-
-    /// Get the fully qualified URL of the remote filename
-    /// ```
-    /// # use candle_hub::{api::sync::Api, Repo};
-    /// let api = Api::new().unwrap();
-    /// let repo = Repo::model("gpt2".to_string());
-    /// let url = api.url(&repo, "model.safetensors");
-    /// assert_eq!(url, "https://huggingface.co/gpt2/resolve/main/model.safetensors");
-    /// ```
-    pub fn url(&self, repo: &Repo, filename: &str) -> String {
-        let endpoint = &self.endpoint;
-        let revision = &repo.url_revision();
-        self.url_template
-            .replace("{endpoint}", endpoint)
-            .replace("{repo_id}", &repo.url())
-            .replace("{revision}", revision)
-            .replace("{filename}", filename)
-    }
-
-    /// Get the underlying api client
-    /// Allows for lower level access
-    pub fn client(&self) -> &Client {
-        &self.client
-    }
-
-    fn metadata(&self, url: &str) -> Result<Metadata, ApiError> {
-        let response = self
-            .no_redirect_client
-            .get(url)
-            .header(RANGE, "bytes=0-0")
-            .send()?;
-        let response = response.error_for_status()?;
-        let headers = response.headers();
-        let header_commit = HeaderName::from_static("x-repo-commit");
-        let header_linked_etag = HeaderName::from_static("x-linked-etag");
-        let header_etag = HeaderName::from_static("etag");
-
-        let etag = match headers.get(&header_linked_etag) {
-            Some(etag) => etag,
-            None => headers
-                .get(&header_etag)
-                .ok_or(ApiError::MissingHeader(header_etag))?,
-        };
-        // Cleaning extra quotes
-        let etag = etag.to_str()?.to_string().replace('"', "");
-        let commit_hash = headers
-            .get(&header_commit)
-            .ok_or(ApiError::MissingHeader(header_commit))?
-            .to_str()?
-            .to_string();
-
-        // The response was most likely redirected to S3, which will
-        // know about the size of the file
-        let response = if response.status().is_redirection() {
-            self.client
-                .get(headers.get(LOCATION).unwrap().to_str()?.to_string())
-                .header(RANGE, "bytes=0-0")
-                .send()?
-        } else {
-            response
-        };
-        let headers = response.headers();
-        let content_range = headers
-            .get(CONTENT_RANGE)
-            .ok_or(ApiError::MissingHeader(CONTENT_RANGE))?
-            .to_str()?;
-
-        let size = content_range
-            .split('/')
-            .last()
-            .ok_or(ApiError::InvalidHeader(CONTENT_RANGE))?
-            .parse()?;
-        Ok(Metadata {
-            commit_hash,
-            etag,
-            size,
-        })
-    }
-
-    fn download_tempfile(
-        &self,
-        url: &str,
-        length: usize,
-        progressbar: Option<ProgressBar>,
-    ) -> Result<PathBuf, ApiError> {
-        let filename = temp_filename();
-
-        // Create the file and set everything properly
-        std::fs::File::create(&filename)?.set_len(length as u64)?;
-
-        let chunk_size = self.chunk_size;
-
-        let n_chunks = (length + chunk_size - 1) / chunk_size;
-        let n_threads = num_cpus::get();
-        let chunks_per_thread = (n_chunks + n_threads - 1) / n_threads;
-        let handles = (0..n_threads).map(|thread_id| {
-            let url = url.to_string();
-            let filename = filename.clone();
-            let client = self.client.clone();
-            let parallel_failures = self.parallel_failures;
-            let max_retries = self.max_retries;
-            let progress = progressbar.clone();
-            std::thread::spawn(move || {
-                for chunk_id in chunks_per_thread * thread_id
-                    ..std::cmp::min(chunks_per_thread * (thread_id + 1), n_chunks)
-                {
-                    let start = chunk_id * chunk_size;
-                    let stop = std::cmp::min(start + chunk_size - 1, length);
-                    let mut chunk = Self::download_chunk(&client, &url, &filename, start, stop);
-                    let mut i = 0;
-                    if parallel_failures > 0 {
-                        while let Err(dlerr) = chunk {
-                            let wait_time = exponential_backoff(300, i, 10_000);
-                            std::thread::sleep(std::time::Duration::from_millis(wait_time as u64));
-
-                            chunk = Self::download_chunk(&client, &url, &filename, start, stop);
-                            i += 1;
-                            if i > max_retries {
-                                return Err(ApiError::TooManyRetries(dlerr.into()));
-                            }
-                        }
-                    }
-                    if let Some(p) = &progress {
-                        p.inc((stop - start) as u64);
-                    }
-                    chunk?
-                }
-                Ok(())
-            })
-        });
-
-        let results: Result<Vec<()>, ApiError> =
-            handles.into_iter().flat_map(|h| h.join()).collect();
-
-        results?;
-        if let Some(p) = progressbar {
-            p.finish()
-        }
-        Ok(filename)
-    }
-
-    fn download_chunk(
-        client: &Client,
-        url: &str,
-        filename: &PathBuf,
-        start: usize,
-        stop: usize,
-    ) -> Result<(), ApiError> {
-        // Process each socket concurrently.
-        let range = format!("bytes={start}-{stop}");
-        let mut file = std::fs::OpenOptions::new().write(true).open(filename)?;
-        file.seek(SeekFrom::Start(start as u64))?;
-        let response = client
-            .get(url)
-            .header(RANGE, range)
-            .send()?
-            .error_for_status()?;
-        let content = response.bytes()?;
-        file.write_all(&content)?;
-        Ok(())
-    }
-
-    /// This will attempt to fetch the file locally first, then use [`Api::download`]
-    /// if the file is not present.
-    /// ```no_run
-    /// use candle_hub::{api::sync::ApiBuilder, Repo};
-    /// let api = ApiBuilder::new().build().unwrap();
-    /// let repo = Repo::model("gpt2".to_string());
-    /// let local_filename = api.get(&repo, "model.safetensors").unwrap();
-    /// ```
-    pub fn get(&self, repo: &Repo, filename: &str) -> Result<PathBuf, ApiError> {
-        if let Some(path) = self.cache.get(repo, filename) {
-            Ok(path)
-        } else {
-            self.download(repo, filename)
-        }
-    }
-
-    /// Downloads a remote file (if not already present) into the cache directory
-    /// to be used locally.
-    /// This function requires internet access to verify if new versions of the file
-    /// exist, even if a file is already on disk at that location.
-    /// ```no_run
-    /// # use candle_hub::{api::sync::ApiBuilder, Repo};
-    /// let api = ApiBuilder::new().build().unwrap();
-    /// let repo = Repo::model("gpt2".to_string());
-    /// let local_filename = api.download(&repo, "model.safetensors").unwrap();
-    /// ```
-    pub fn download(&self, repo: &Repo, filename: &str) -> Result<PathBuf, ApiError> {
-        let url = self.url(repo, filename);
-        let metadata = self.metadata(&url)?;
-
-        let blob_path = self.cache.blob_path(repo, &metadata.etag);
-        std::fs::create_dir_all(blob_path.parent().unwrap())?;
-
-        let progressbar = if self.progress {
-            let progress = ProgressBar::new(metadata.size as u64);
-            progress.set_style(
-                ProgressStyle::with_template(
-                    "{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec} ({eta})",
-                )
-                .unwrap(), // .progress_chars("━ "),
-            );
-            let maxlength = 30;
-            let message = if filename.len() > maxlength {
-                format!("..{}", &filename[filename.len() - maxlength..])
-            } else {
-                filename.to_string()
-            };
-            progress.set_message(message);
-            Some(progress)
-        } else {
-            None
-        };
-
-        let tmp_filename = self.download_tempfile(&url, metadata.size, progressbar)?;
-
-        if std::fs::rename(&tmp_filename, &blob_path).is_err() {
-            // Renaming may fail if locations are different mount points
-            std::fs::File::create(&blob_path)?;
-            std::fs::copy(tmp_filename, &blob_path)?;
-        }
-
-        let mut pointer_path = self.cache.pointer_path(repo, &metadata.commit_hash);
-        pointer_path.push(filename);
-        std::fs::create_dir_all(pointer_path.parent().unwrap()).ok();
-
-        symlink_or_rename(&blob_path, &pointer_path)?;
-        self.cache.create_ref(repo, &metadata.commit_hash)?;
-
-        Ok(pointer_path)
-    }
-
-    /// Get information about the Repo
-    /// ```
-    /// use candle_hub::{api::sync::Api, Repo};
-    /// let api = Api::new().unwrap();
-    /// let repo = Repo::model("gpt2".to_string());
-    /// api.info(&repo);
-    /// ```
-    pub fn info(&self, repo: &Repo) -> Result<ModelInfo, ApiError> {
-        let url = format!("{}/api/{}", self.endpoint, repo.api_url());
-        let response = self.client.get(url).send()?;
-        let response = response.error_for_status()?;
-
-        let model_info = response.json()?;
-
-        Ok(model_info)
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use crate::RepoType;
-    use rand::{distributions::Alphanumeric, Rng};
-    use sha256::try_digest;
-
-    struct TempDir {
-        path: PathBuf,
-    }
-
-    impl TempDir {
-        pub fn new() -> Self {
-            let s: String = rand::thread_rng()
-                .sample_iter(&Alphanumeric)
-                .take(7)
-                .map(char::from)
-                .collect();
-            let mut path = std::env::temp_dir();
-            path.push(s);
-            std::fs::create_dir(&path).unwrap();
-            Self { path }
-        }
-    }
-
-    impl Drop for TempDir {
-        fn drop(&mut self) {
-            std::fs::remove_dir_all(&self.path).unwrap()
-        }
-    }
-
-    #[test]
-    fn simple() {
-        let tmp = TempDir::new();
-        let api = ApiBuilder::new()
-            .with_progress(false)
-            .with_cache_dir(tmp.path.clone())
-            .build()
-            .unwrap();
-        let repo = Repo::new("julien-c/dummy-unknown".to_string(), RepoType::Model);
-        let downloaded_path = api.download(&repo, "config.json").unwrap();
-        assert!(downloaded_path.exists());
-        let val = try_digest(&*downloaded_path).unwrap();
-        assert_eq!(
-            val,
-            "b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32"
-        );
-
-        // Make sure the file is now visible without a connection
-        let cache_path = api.cache.get(&repo, "config.json").unwrap();
-        assert_eq!(cache_path, downloaded_path);
-    }
-
-    #[test]
-    fn dataset() {
-        let tmp = TempDir::new();
-        let api = ApiBuilder::new()
-            .with_progress(false)
-            .with_cache_dir(tmp.path.clone())
-            .build()
-            .unwrap();
-        let repo = Repo::with_revision(
-            "wikitext".to_string(),
-            RepoType::Dataset,
-            "refs/convert/parquet".to_string(),
-        );
-        let downloaded_path = api
-            .download(&repo, "wikitext-103-v1/wikitext-test.parquet")
-            .unwrap();
-        assert!(downloaded_path.exists());
-        let val = try_digest(&*downloaded_path).unwrap();
-        assert_eq!(
-            val,
-            "59ce09415ad8aa45a9e34f88cec2548aeb9de9a73fcda9f6b33a86a065f32b90"
-        )
-    }
-
-    #[test]
-    fn info() {
-        let tmp = TempDir::new();
-        let api = ApiBuilder::new()
-            .with_progress(false)
-            .with_cache_dir(tmp.path.clone())
-            .build()
-            .unwrap();
-        let repo = Repo::with_revision(
-            "wikitext".to_string(),
-            RepoType::Dataset,
-            "refs/convert/parquet".to_string(),
-        );
-        let model_info = api.info(&repo).unwrap();
-        assert_eq!(
-            model_info,
-            ModelInfo {
-                siblings: vec![
-                    Siblings {
-                        rfilename: ".gitattributes".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-103-raw-v1/wikitext-test.parquet".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-103-raw-v1/wikitext-train-00000-of-00002.parquet"
-                            .to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-103-raw-v1/wikitext-train-00001-of-00002.parquet"
-                            .to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-103-raw-v1/wikitext-validation.parquet".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-103-v1/test/index.duckdb".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-103-v1/validation/index.duckdb".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-103-v1/wikitext-test.parquet".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-103-v1/wikitext-train-00000-of-00002.parquet"
-                            .to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-103-v1/wikitext-train-00001-of-00002.parquet"
-                            .to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-103-v1/wikitext-validation.parquet".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-2-raw-v1/test/index.duckdb".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-2-raw-v1/train/index.duckdb".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-2-raw-v1/validation/index.duckdb".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-2-raw-v1/wikitext-test.parquet".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-2-raw-v1/wikitext-train.parquet".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-2-raw-v1/wikitext-validation.parquet".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-2-v1/wikitext-test.parquet".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-2-v1/wikitext-train.parquet".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-2-v1/wikitext-validation.parquet".to_string()
-                    }
-                ],
-            }
-        )
-    }
-}
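
Both clients retry failed chunks with the same quadratic backoff: retry `i` sleeps `base + i² + jitter` milliseconds, where the jitter is up to 500 ms and the whole wait is capped at `max`. Pulled out of the file above as a standalone, runnable illustration (the loop in `main` is added here just to print sample delays):

```rust
use rand::{thread_rng, Rng};

fn jitter() -> usize {
    thread_rng().gen_range(0..=500)
}

fn exponential_backoff(base_wait_time: usize, n: usize, max: usize) -> usize {
    (base_wait_time + n.pow(2) + jitter()).min(max)
}

fn main() {
    // The download loops call this as exponential_backoff(300, i, 10_000):
    // roughly 300 + i^2 ms per retry, never more than 10 seconds.
    for i in 0..5 {
        println!("retry {i}: wait ~{} ms", exponential_backoff(300, i, 10_000));
    }
}
```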
diff --git a/candle-hub/src/api/tokio.rs b/candle-hub/src/api/tokio.rs
deleted file mode 100644
index dc8f682e..00000000
--- a/candle-hub/src/api/tokio.rs
+++ /dev/null
@@ -1,723 +0,0 @@
-use crate::{Cache, Repo};
-use indicatif::{ProgressBar, ProgressStyle};
-use rand::{distributions::Alphanumeric, thread_rng, Rng};
-use reqwest::{
-    header::{
-        HeaderMap, HeaderName, HeaderValue, InvalidHeaderValue, ToStrError, AUTHORIZATION,
-        CONTENT_RANGE, LOCATION, RANGE, USER_AGENT,
-    },
-    redirect::Policy,
-    Client, Error as ReqwestError,
-};
-use serde::Deserialize;
-use std::num::ParseIntError;
-use std::path::{Component, Path, PathBuf};
-use std::sync::Arc;
-use thiserror::Error;
-use tokio::io::{AsyncSeekExt, AsyncWriteExt, SeekFrom};
-use tokio::sync::{AcquireError, Semaphore, TryAcquireError};
-
-/// Current version (used in user-agent)
-const VERSION: &str = env!("CARGO_PKG_VERSION");
-/// Current name (used in user-agent)
-const NAME: &str = env!("CARGO_PKG_NAME");
-
-#[derive(Debug, Error)]
-/// All errors the API can throw
-pub enum ApiError {
-    /// The Api expects a certain header to be present in the results to derive some information
-    #[error("Header {0} is missing")]
-    MissingHeader(HeaderName),
-
-    /// The header exists, but the value does not conform to what the Api expects.
-    #[error("Header {0} is invalid")]
-    InvalidHeader(HeaderName),
-
-    /// The value cannot be used as a header during request header construction
-    #[error("Invalid header value {0}")]
-    InvalidHeaderValue(#[from] InvalidHeaderValue),
-
-    /// The header value is not valid utf-8
-    #[error("header value is not a string")]
-    ToStr(#[from] ToStrError),
-
-    /// Error in the request
-    #[error("request error: {0}")]
-    RequestError(#[from] ReqwestError),
-
-    /// Error parsing some range value
-    #[error("Cannot parse int")]
-    ParseIntError(#[from] ParseIntError),
-
-    /// I/O Error
-    #[error("I/O error {0}")]
-    IoError(#[from] std::io::Error),
-
-    /// We tried to download a chunk too many times
-    #[error("Too many retries: {0}")]
-    TooManyRetries(Box<ApiError>),
-
-    /// Semaphore cannot be acquired
-    #[error("Try acquire: {0}")]
-    TryAcquireError(#[from] TryAcquireError),
-
-    /// Semaphore cannot be acquired
-    #[error("Acquire: {0}")]
-    AcquireError(#[from] AcquireError),
-    // /// Semaphore cannot be acquired
-    // #[error("Invalid Response: {0:?}")]
-    // InvalidResponse(Response),
-}
-
-/// Siblings are simplified file descriptions of remote files on the hub
-#[derive(Debug, Clone, Deserialize, PartialEq)]
-pub struct Siblings {
-    /// The path within the repo.
-    pub rfilename: String,
-}
-
-/// The description of the repo given by the hub
-#[derive(Debug, Clone, Deserialize, PartialEq)]
-pub struct ModelInfo {
-    /// See [`Siblings`]
-    pub siblings: Vec<Siblings>,
-}
-
-/// Helper to create [`Api`] with all the options.
-pub struct ApiBuilder {
-    endpoint: String,
-    cache: Cache,
-    url_template: String,
-    token: Option<String>,
-    max_files: usize,
-    chunk_size: usize,
-    parallel_failures: usize,
-    max_retries: usize,
-    progress: bool,
-}
-
-impl Default for ApiBuilder {
-    fn default() -> Self {
-        Self::new()
-    }
-}
-
-impl ApiBuilder {
-    /// Default api builder
-    /// ```
-    /// use candle_hub::api::tokio::ApiBuilder;
-    /// let api = ApiBuilder::new().build().unwrap();
-    /// ```
-    pub fn new() -> Self {
-        let cache = Cache::default();
-        let mut token_filename = cache.path().clone();
-        token_filename.push(".token");
-        let token = match std::fs::read_to_string(token_filename) {
-            Ok(token_content) => {
-                let token_content = token_content.trim();
-                if !token_content.is_empty() {
-                    Some(token_content.to_string())
-                } else {
-                    None
-                }
-            }
-            Err(_) => None,
-        };
-
-        let progress = true;
-
-        Self {
-            endpoint: "https://huggingface.co".to_string(),
-            url_template: "{endpoint}/{repo_id}/resolve/{revision}/{filename}".to_string(),
-            cache,
-            token,
-            max_files: num_cpus::get(),
-            chunk_size: 10_000_000,
-            parallel_failures: 0,
-            max_retries: 0,
-            progress,
-        }
-    }
-
-    /// Whether to show a progressbar
-    pub fn with_progress(mut self, progress: bool) -> Self {
-        self.progress = progress;
-        self
-    }
-
-    /// Changes the location of the cache directory. The default is `~/.cache/huggingface/`.
-    pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self {
-        self.cache = Cache::new(cache_dir);
-        self
-    }
-
-    fn build_headers(&self) -> Result<HeaderMap, ApiError> {
-        let mut headers = HeaderMap::new();
-        let user_agent = format!("unkown/None; {NAME}/{VERSION}; rust/unknown");
-        headers.insert(USER_AGENT, HeaderValue::from_str(&user_agent)?);
-        if let Some(token) = &self.token {
-            headers.insert(
-                AUTHORIZATION,
-                HeaderValue::from_str(&format!("Bearer {token}"))?,
-            );
-        }
-        Ok(headers)
-    }
-
-    /// Consumes the builder and builds the final [`Api`]
-    pub fn build(self) -> Result<Api, ApiError> {
-        let headers = self.build_headers()?;
-        let client = Client::builder().default_headers(headers.clone()).build()?;
-        let no_redirect_client = Client::builder()
-            .redirect(Policy::none())
-            .default_headers(headers)
-            .build()?;
-        Ok(Api {
-            endpoint: self.endpoint,
-            url_template: self.url_template,
-            cache: self.cache,
-            client,
-
-            no_redirect_client,
-            max_files: self.max_files,
-            chunk_size: self.chunk_size,
-            parallel_failures: self.parallel_failures,
-            max_retries: self.max_retries,
-            progress: self.progress,
-        })
-    }
-}
-
-#[derive(Debug)]
-struct Metadata {
-    commit_hash: String,
-    etag: String,
-    size: usize,
-}
-
-/// The actual Api used to interact with the hub.
-/// You can inspect repos with [`Api::info`]
-/// or download files with [`Api::download`]
-pub struct Api {
-    endpoint: String,
-    url_template: String,
-    cache: Cache,
-    client: Client,
-    no_redirect_client: Client,
-    max_files: usize,
-    chunk_size: usize,
-    parallel_failures: usize,
-    max_retries: usize,
-    progress: bool,
-}
-
-fn temp_filename() -> PathBuf {
-    let s: String = rand::thread_rng()
-        .sample_iter(&Alphanumeric)
-        .take(7)
-        .map(char::from)
-        .collect();
-    let mut path = std::env::temp_dir();
-    path.push(s);
-    path
-}
-
-fn make_relative(src: &Path, dst: &Path) -> PathBuf {
-    let path = src;
-    let base = dst;
-
-    if path.is_absolute() != base.is_absolute() {
-        panic!("This function is made to look at absolute paths only");
-    }
-    let mut ita = path.components();
-    let mut itb = base.components();
-
-    loop {
-        match (ita.next(), itb.next()) {
-            (Some(a), Some(b)) if a == b => (),
-            (some_a, _) => {
-                // Ignoring b, because 1 component is the filename
-                // for which we don't need to go back up for relative
-                // filename to work.
-                let mut new_path = PathBuf::new();
-                for _ in itb {
-                    new_path.push(Component::ParentDir);
-                }
-                if let Some(a) = some_a {
-                    new_path.push(a);
-                    for comp in ita {
-                        new_path.push(comp);
-                    }
-                }
-                return new_path;
-            }
-        }
-    }
-}
-
-fn symlink_or_rename(src: &Path, dst: &Path) -> Result<(), std::io::Error> {
-    if dst.exists() {
-        return Ok(());
-    }
-
-    let src = make_relative(src, dst);
-    #[cfg(target_os = "windows")]
-    std::os::windows::fs::symlink_file(src, dst)?;
-
-    #[cfg(target_family = "unix")]
-    std::os::unix::fs::symlink(src, dst)?;
-
-    #[cfg(not(any(target_family = "unix", target_os = "windows")))]
-    std::fs::rename(src, dst)?;
-
-    Ok(())
-}
-
-fn jitter() -> usize {
-    thread_rng().gen_range(0..=500)
-}
-
-fn exponential_backoff(base_wait_time: usize, n: usize, max: usize) -> usize {
-    (base_wait_time + n.pow(2) + jitter()).min(max)
-}
-
-impl Api {
-    /// Creates a default Api, for Api options see [`ApiBuilder`]
-    pub fn new() -> Result<Self, ApiError> {
-        ApiBuilder::new().build()
-    }
-
-    /// Get the fully qualified URL of the remote filename
-    /// ```
-    /// # use candle_hub::{api::tokio::Api, Repo};
-    /// let api = Api::new().unwrap();
-    /// let repo = Repo::model("gpt2".to_string());
-    /// let url = api.url(&repo, "model.safetensors");
-    /// assert_eq!(url, "https://huggingface.co/gpt2/resolve/main/model.safetensors");
-    /// ```
-    pub fn url(&self, repo: &Repo, filename: &str) -> String {
-        let endpoint = &self.endpoint;
-        let revision = &repo.url_revision();
-        self.url_template
-            .replace("{endpoint}", endpoint)
-            .replace("{repo_id}", &repo.url())
-            .replace("{revision}", revision)
-            .replace("{filename}", filename)
-    }
-
-    /// Get the underlying api client
-    /// Allows for lower level access
-    pub fn client(&self) -> &Client {
-        &self.client
-    }
-
-    async fn metadata(&self, url: &str) -> Result<Metadata, ApiError> {
-        let response = self
-            .no_redirect_client
-            .get(url)
-            .header(RANGE, "bytes=0-0")
-            .send()
-            .await?;
-        let response = response.error_for_status()?;
-        let headers = response.headers();
-        let header_commit = HeaderName::from_static("x-repo-commit");
-        let header_linked_etag = HeaderName::from_static("x-linked-etag");
-        let header_etag = HeaderName::from_static("etag");
-
-        let etag = match headers.get(&header_linked_etag) {
-            Some(etag) => etag,
-            None => headers
-                .get(&header_etag)
-                .ok_or(ApiError::MissingHeader(header_etag))?,
-        };
-        // Cleaning extra quotes
-        let etag = etag.to_str()?.to_string().replace('"', "");
-        let commit_hash = headers
-            .get(&header_commit)
-            .ok_or(ApiError::MissingHeader(header_commit))?
-            .to_str()?
-            .to_string();
-
-        // The response was most likely redirected to S3, which will
-        // know about the size of the file
-        let response = if response.status().is_redirection() {
-            self.client
-                .get(headers.get(LOCATION).unwrap().to_str()?.to_string())
-                .header(RANGE, "bytes=0-0")
-                .send()
-                .await?
-        } else {
-            response
-        };
-        let headers = response.headers();
-        let content_range = headers
-            .get(CONTENT_RANGE)
-            .ok_or(ApiError::MissingHeader(CONTENT_RANGE))?
-            .to_str()?;
-
-        let size = content_range
-            .split('/')
-            .last()
-            .ok_or(ApiError::InvalidHeader(CONTENT_RANGE))?
-            .parse()?;
-        Ok(Metadata {
-            commit_hash,
-            etag,
-            size,
-        })
-    }
-
-    async fn download_tempfile(
-        &self,
-        url: &str,
-        length: usize,
-        progressbar: Option<ProgressBar>,
-    ) -> Result<PathBuf, ApiError> {
-        let mut handles = vec![];
-        let semaphore = Arc::new(Semaphore::new(self.max_files));
-        let parallel_failures_semaphore = Arc::new(Semaphore::new(self.parallel_failures));
-        let filename = temp_filename();
-
-        // Create the file and set everything properly
-        tokio::fs::File::create(&filename)
-            .await?
-            .set_len(length as u64)
-            .await?;
-
-        let chunk_size = self.chunk_size;
-        for start in (0..length).step_by(chunk_size) {
-            let url = url.to_string();
-            let filename = filename.clone();
-            let client = self.client.clone();
-
-            let stop = std::cmp::min(start + chunk_size - 1, length);
-            let permit = semaphore.clone().acquire_owned().await?;
-            let parallel_failures = self.parallel_failures;
-            let max_retries = self.max_retries;
-            let parallel_failures_semaphore = parallel_failures_semaphore.clone();
-            let progress = progressbar.clone();
-            handles.push(tokio::spawn(async move {
-                let mut chunk = Self::download_chunk(&client, &url, &filename, start, stop).await;
-                let mut i = 0;
-                if parallel_failures > 0 {
-                    while let Err(dlerr) = chunk {
-                        let parallel_failure_permit =
-                            parallel_failures_semaphore.clone().try_acquire_owned()?;
-
-                        let wait_time = exponential_backoff(300, i, 10_000);
-                        tokio::time::sleep(tokio::time::Duration::from_millis(wait_time as u64))
-                            .await;
-
-                        chunk = Self::download_chunk(&client, &url, &filename, start, stop).await;
-                        i += 1;
-                        if i > max_retries {
-                            return Err(ApiError::TooManyRetries(dlerr.into()));
-                        }
-                        drop(parallel_failure_permit);
-                    }
-                }
-                drop(permit);
-                if let Some(p) = progress {
-                    p.inc((stop - start) as u64);
-                }
-                chunk
-            }));
-        }
-
-        // Output the chained result
-        let results: Vec<Result<Result<(), ApiError>, tokio::task::JoinError>> =
-            futures::future::join_all(handles).await;
-        let results: Result<(), ApiError> = results.into_iter().flatten().collect();
-        results?;
-        if let Some(p) = progressbar {
-            p.finish()
-        }
-        Ok(filename)
-    }
-
-    async fn download_chunk(
-        client: &reqwest::Client,
-        url: &str,
-        filename: &PathBuf,
-        start: usize,
-        stop: usize,
-    ) -> Result<(), ApiError> {
-        // Process each socket concurrently.
-        let range = format!("bytes={start}-{stop}");
-        let mut file = tokio::fs::OpenOptions::new()
-            .write(true)
-            .open(filename)
-            .await?;
-        file.seek(SeekFrom::Start(start as u64)).await?;
-        let response = client
-            .get(url)
-            .header(RANGE, range)
-            .send()
-            .await?
-            .error_for_status()?;
-        let content = response.bytes().await?;
-        file.write_all(&content).await?;
-        Ok(())
-    }
-
-    /// This will attempt to fetch the file locally first, then use [`Api::download`]
-    /// if the file is not present.
-    /// ```no_run
-    /// # use candle_hub::{api::tokio::ApiBuilder, Repo};
-    /// # tokio_test::block_on(async {
-    /// let api = ApiBuilder::new().build().unwrap();
-    /// let repo = Repo::model("gpt2".to_string());
-    /// let local_filename = api.get(&repo, "model.safetensors").await.unwrap();
-    /// # })
-    /// ```
-    pub async fn get(&self, repo: &Repo, filename: &str) -> Result<PathBuf, ApiError> {
-        if let Some(path) = self.cache.get(repo, filename) {
-            Ok(path)
-        } else {
-            self.download(repo, filename).await
-        }
-    }
-
-    /// Downloads a remote file (if not already present) into the cache directory
-    /// to be used locally.
-    /// This function requires internet access to verify if new versions of the file
-    /// exist, even if a file is already on disk at that location.
-    /// ```no_run
-    /// # use candle_hub::{api::tokio::ApiBuilder, Repo};
-    /// # tokio_test::block_on(async {
-    /// let api = ApiBuilder::new().build().unwrap();
-    /// let repo = Repo::model("gpt2".to_string());
-    /// let local_filename = api.download(&repo, "model.safetensors").await.unwrap();
-    /// # })
-    /// ```
-    pub async fn download(&self, repo: &Repo, filename: &str) -> Result<PathBuf, ApiError> {
-        let url = self.url(repo, filename);
-        let metadata = self.metadata(&url).await?;
-
-        let blob_path = self.cache.blob_path(repo, &metadata.etag);
-        std::fs::create_dir_all(blob_path.parent().unwrap())?;
-
-        let progressbar = if self.progress {
-            let progress = ProgressBar::new(metadata.size as u64);
-            progress.set_style(
-                ProgressStyle::with_template(
-                    "{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec} ({eta})",
-                )
-                .unwrap(), // .progress_chars("━ "),
-            );
-            let maxlength = 30;
-            let message = if filename.len() > maxlength {
-                format!("..{}", &filename[filename.len() - maxlength..])
-            } else {
-                filename.to_string()
-            };
-            progress.set_message(message);
-            Some(progress)
-        } else {
-            None
-        };
-
-        let tmp_filename = self
-            .download_tempfile(&url, metadata.size, progressbar)
-            .await?;
-
-        if tokio::fs::rename(&tmp_filename, &blob_path).await.is_err() {
-            // Renaming may fail if locations are different mount points
-            std::fs::File::create(&blob_path)?;
-            tokio::fs::copy(tmp_filename, &blob_path).await?;
-        }
-
-        let mut pointer_path = self.cache.pointer_path(repo, &metadata.commit_hash);
-        pointer_path.push(filename);
-        std::fs::create_dir_all(pointer_path.parent().unwrap()).ok();
-
-        symlink_or_rename(&blob_path, &pointer_path)?;
-        self.cache.create_ref(repo, &metadata.commit_hash)?;
-
-        Ok(pointer_path)
-    }
-
-    /// Get information about the Repo
-    /// ```
-    /// # use candle_hub::{api::tokio::Api, Repo};
-    /// # tokio_test::block_on(async {
-    /// let api = Api::new().unwrap();
-    /// let repo = Repo::model("gpt2".to_string());
-    /// api.info(&repo);
-    /// # })
-    /// ```
-    pub async fn info(&self, repo: &Repo) -> Result<ModelInfo, ApiError> {
-        let url = format!("{}/api/{}", self.endpoint, repo.api_url());
-        let response = self.client.get(url).send().await?;
-        let response = response.error_for_status()?;
-
-        let model_info = response.json().await?;
-
-        Ok(model_info)
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use crate::RepoType;
-    use rand::{distributions::Alphanumeric, Rng};
-    use sha256::try_digest;
-
-    struct TempDir {
-        path: PathBuf,
-    }
-
-    impl TempDir {
-        pub fn new() -> Self {
-            let s: String = rand::thread_rng()
-                .sample_iter(&Alphanumeric)
-                .take(7)
-                .map(char::from)
-                .collect();
-            let mut path = std::env::temp_dir();
-            path.push(s);
-            std::fs::create_dir(&path).unwrap();
-            Self { path }
-        }
-    }
-
-    impl Drop for TempDir {
-        fn drop(&mut self) {
-            std::fs::remove_dir_all(&self.path).unwrap()
-        }
-    }
-
-    #[tokio::test]
-    async fn simple() {
-        let tmp = TempDir::new();
-        let api = ApiBuilder::new()
-            .with_progress(false)
-            .with_cache_dir(tmp.path.clone())
-            .build()
-            .unwrap();
-        let repo = Repo::new("julien-c/dummy-unknown".to_string(), RepoType::Model);
-        let downloaded_path = api.download(&repo, "config.json").await.unwrap();
-        assert!(downloaded_path.exists());
-        let val = try_digest(&*downloaded_path).unwrap();
-        assert_eq!(
-            val,
-            "b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32"
-        );
-
-        // Make sure the file is now visible without a connection
-        let cache_path = api.cache.get(&repo, "config.json").unwrap();
-        assert_eq!(cache_path, downloaded_path);
-    }
-
-    #[tokio::test]
-    async fn dataset() {
-        let tmp = TempDir::new();
-        let api = ApiBuilder::new()
-            .with_progress(false)
-            .with_cache_dir(tmp.path.clone())
-            .build()
-            .unwrap();
-        let repo = Repo::with_revision(
-            "wikitext".to_string(),
-            RepoType::Dataset,
-            "refs/convert/parquet".to_string(),
-        );
-        let downloaded_path = api
-            .download(&repo, "wikitext-103-v1/wikitext-test.parquet")
-            .await
-            .unwrap();
-        assert!(downloaded_path.exists());
-        let val = try_digest(&*downloaded_path).unwrap();
-        assert_eq!(
-            val,
-            "59ce09415ad8aa45a9e34f88cec2548aeb9de9a73fcda9f6b33a86a065f32b90"
-        )
-    }
-
-    #[tokio::test]
-    async fn info() {
-        let tmp = TempDir::new();
-        let api = ApiBuilder::new()
-            .with_progress(false)
-            .with_cache_dir(tmp.path.clone())
-            .build()
-            .unwrap();
-        let repo = Repo::with_revision(
-            "wikitext".to_string(),
-            RepoType::Dataset,
-            "refs/convert/parquet".to_string(),
-        );
-        let model_info = api.info(&repo).await.unwrap();
-        assert_eq!(
-            model_info,
-            ModelInfo {
-                siblings: vec![
-                    Siblings {
-                        rfilename: ".gitattributes".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-103-raw-v1/wikitext-test.parquet".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-103-raw-v1/wikitext-train-00000-of-00002.parquet"
-                            .to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-103-raw-v1/wikitext-train-00001-of-00002.parquet"
-                            .to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-103-raw-v1/wikitext-validation.parquet".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-103-v1/test/index.duckdb".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-103-v1/validation/index.duckdb".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-103-v1/wikitext-test.parquet".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-103-v1/wikitext-train-00000-of-00002.parquet"
-                            .to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-103-v1/wikitext-train-00001-of-00002.parquet"
-                            .to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-103-v1/wikitext-validation.parquet".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-2-raw-v1/test/index.duckdb".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-2-raw-v1/train/index.duckdb".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-2-raw-v1/validation/index.duckdb".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-2-raw-v1/wikitext-test.parquet".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-2-raw-v1/wikitext-train.parquet".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-2-raw-v1/wikitext-validation.parquet".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-2-v1/wikitext-test.parquet".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-2-v1/wikitext-train.parquet".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-2-v1/wikitext-validation.parquet".to_string()
-                    }
-                ],
-            }
-        )
-    }
-}
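
The async `download_tempfile` above schedules one task per `chunk_size` byte range and bounds concurrency with a semaphore sized to `max_files`. The scheduling skeleton, reduced to a runnable sketch (the byte ranges are printed instead of fetched, and the sizes here are made up):

```rust
use std::sync::Arc;
use tokio::sync::Semaphore;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let length: usize = 25_000_000; // pretend file size from the metadata call
    let chunk_size: usize = 10_000_000; // same default as the builder above
    let semaphore = Arc::new(Semaphore::new(4)); // stands in for max_files

    let mut handles = vec![];
    for start in (0..length).step_by(chunk_size) {
        let stop = std::cmp::min(start + chunk_size - 1, length);
        // Acquiring before spawning caps how many chunks are in flight.
        let permit = semaphore.clone().acquire_owned().await?;
        handles.push(tokio::spawn(async move {
            // The real code sends `Range: bytes={start}-{stop}` and writes
            // the body at offset `start` in the preallocated temp file.
            println!("would fetch bytes {start}-{stop}");
            drop(permit);
        }));
    }
    for handle in handles {
        handle.await?;
    }
    Ok(())
}
```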
diff --git a/candle-hub/src/lib.rs b/candle-hub/src/lib.rs
deleted file mode 100644
index 0de2006a..00000000
--- a/candle-hub/src/lib.rs
+++ /dev/null
@@ -1,197 +0,0 @@
-#![deny(missing_docs)]
-//! This crate aims to emulate and be compatible with the
-//! [huggingface_hub](https://github.com/huggingface/huggingface_hub/) python package.
-//!
-//! Compatible means the Api should reuse the same files, skipping downloads if
-//! they are already present, and whenever this crate downloads or modifies this cache
-//! it should be consistent with [huggingface_hub](https://github.com/huggingface/huggingface_hub/).
-//!
-//! At this time only a limited subset of the functionality is present, the goal is to add new
-//! features over time.
-use std::io::Write;
-use std::path::PathBuf;
-
-/// The actual Api to interact with the hub.
-#[cfg(feature = "online")]
-pub mod api;
-
-/// The type of repo to interact with
-#[derive(Debug, Clone, Copy)]
-pub enum RepoType {
-    /// This is a model, usually it consists of weight files and some configuration
-    /// files
-    Model,
-    /// This is a dataset, usually contains data within parquet files
-    Dataset,
-    /// This is a space, usually a demo showcasing a given model or dataset
-    Space,
-}
-
-/// A local struct used to fetch information from the cache folder.
-pub struct Cache {
-    path: PathBuf,
-}
-
-impl Cache {
-    /// Creates a new cache object location
-    pub fn new(path: PathBuf) -> Self {
-        Self { path }
-    }
-
-    /// Returns the location of the cache directory
-    pub fn path(&self) -> &PathBuf {
-        &self.path
-    }
-
-    /// This will get the location of the file within the cache for the remote
-    /// `filename`. Will return `None` if file is not already present in cache.
-    pub fn get(&self, repo: &Repo, filename: &str) -> Option<PathBuf> {
-        let mut commit_path = self.path.clone();
-        commit_path.push(repo.folder_name());
-        commit_path.push("refs");
-        commit_path.push(repo.revision());
-        let commit_hash = std::fs::read_to_string(commit_path).ok()?;
-        let mut pointer_path = self.pointer_path(repo, &commit_hash);
-        pointer_path.push(filename);
-        if pointer_path.exists() {
-            Some(pointer_path)
-        } else {
-            None
-        }
-    }
-
-    /// Creates a reference in the cache directory that points branches to the correct
-    /// commits within the blobs.
-    pub fn create_ref(&self, repo: &Repo, commit_hash: &str) -> Result<(), std::io::Error> {
-        let mut ref_path = self.path.clone();
-        ref_path.push(repo.folder_name());
-        ref_path.push("refs");
-        ref_path.push(repo.revision());
-        // Needs to be done like this because revision might contain `/` creating subfolders here.
-        std::fs::create_dir_all(ref_path.parent().unwrap())?;
-        let mut file1 = std::fs::OpenOptions::new()
-            .write(true)
-            .create(true)
-            .open(&ref_path)?;
-        file1.write_all(commit_hash.trim().as_bytes())?;
-        Ok(())
-    }
-
-    #[cfg(feature = "online")]
-    pub(crate) fn blob_path(&self, repo: &Repo, etag: &str) -> PathBuf {
-        let mut blob_path = self.path.clone();
-        blob_path.push(repo.folder_name());
-        blob_path.push("blobs");
-        blob_path.push(etag);
-        blob_path
-    }
-
-    pub(crate) fn pointer_path(&self, repo: &Repo, commit_hash: &str) -> PathBuf {
-        let mut pointer_path = self.path.clone();
-        pointer_path.push(repo.folder_name());
-        pointer_path.push("snapshots");
-        pointer_path.push(commit_hash);
-        pointer_path
-    }
-}
-
-impl Default for Cache {
-    fn default() -> Self {
-        let path = match std::env::var("HF_HOME") {
-            Ok(home) => home.into(),
-            Err(_) => {
-                let mut cache = dirs::home_dir().expect("Cache directory cannot be found");
-                cache.push(".cache");
-                cache.push("huggingface");
-                cache.push("hub");
-                cache
-            }
-        };
-        Self::new(path)
-    }
-}
-
-/// The representation of a repo on the hub.
-#[allow(dead_code)] // Repo type unused in offline mode
-pub struct Repo {
-    repo_id: String,
-    repo_type: RepoType,
-    revision: String,
-}
-
-impl Repo {
-    /// Repo with the default branch ("main").
-    pub fn new(repo_id: String, repo_type: RepoType) -> Self {
-        Self::with_revision(repo_id, repo_type, "main".to_string())
-    }
-
-    /// Fully qualified Repo
-    pub fn with_revision(repo_id: String, repo_type: RepoType, revision: String) -> Self {
-        Self {
-            repo_id,
-            repo_type,
-            revision,
-        }
-    }
-
-    /// Shortcut for [`Repo::new`] with [`RepoType::Model`]
-    pub fn model(repo_id: String) -> Self {
-        Self::new(repo_id, RepoType::Model)
-    }
-
-    /// Shortcut for [`Repo::new`] with [`RepoType::Dataset`]
-    pub fn dataset(repo_id: String) -> Self {
-        Self::new(repo_id, RepoType::Dataset)
-    }
-
-    /// Shortcut for [`Repo::new`] with [`RepoType::Space`]
-    pub fn space(repo_id: String) -> Self {
-        Self::new(repo_id, RepoType::Space)
-    }
-
-    /// The normalized folder name of the repo within the cache directory
-    pub fn folder_name(&self) -> String {
-        let prefix = match self.repo_type {
-            RepoType::Model => "models",
-            RepoType::Dataset => "datasets",
-            RepoType::Space => "spaces",
-        };
-        format!("{prefix}--{}", self.repo_id).replace('/', "--")
-    }
-
-    /// The revision
-    pub fn revision(&self) -> &str {
-        &self.revision
-    }
-
-    /// The actual URL part of the repo
-    #[cfg(feature = "online")]
-    pub fn url(&self) -> String {
-        match self.repo_type {
-            RepoType::Model => self.repo_id.to_string(),
-            RepoType::Dataset => {
-                format!("datasets/{}", self.repo_id)
-            }
-            RepoType::Space => {
-                format!("spaces/{}", self.repo_id)
-            }
-        }
-    }
-
-    /// Revision needs to be url escaped before being used in a URL
-    #[cfg(feature = "online")]
-    pub fn url_revision(&self) -> String {
-        self.revision.replace('/', "%2F")
-    }
-
-    /// Used to compute the repo's url part when accessing the metadata of the repo
-    #[cfg(feature = "online")]
-    pub fn api_url(&self) -> String {
-        let prefix = match self.repo_type {
-            RepoType::Model => "models",
-            RepoType::Dataset => "datasets",
-            RepoType::Space => "spaces",
-        };
-        format!("{prefix}/{}/revision/{}", self.repo_id, self.url_revision())
-    }
-}
diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml
index 01b41763..46847703 100644
--- a/candle-transformers/Cargo.toml
+++ b/candle-transformers/Cargo.toml
@@ -12,7 +12,7 @@ readme = "README.md"
 
 [dependencies]
 candle = { path = "../candle-core" }
-candle-hub = { path = "../candle-hub" }
+hf-hub = { workspace = true}
 candle-nn = { path = "../candle-nn" }
 intel-mkl-src = { workspace = true, optional = true, features = ["mkl-dynamic-lp64-iomp"]}
 tokenizers = { workspace = true, features = ["onig"] }
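
For completeness, the cache layout implemented by the deleted lib.rs (and which `hf-hub` is expected to keep, since both mirror `huggingface_hub`): `refs/<revision>` stores a commit hash, and `snapshots/<commit>/<file>` symlinks into content-addressed `blobs/`. A sketch of resolving a file offline through that layout, assuming `hf_hub::Cache` keeps the surface shown above:

```rust
use hf_hub::{Cache, Repo, RepoType};

fn main() {
    // HF_HOME if set, otherwise ~/.cache/huggingface/hub.
    let cache = Cache::default();
    let repo = Repo::new("julien-c/dummy-unknown".to_string(), RepoType::Model);
    // Walks <folder>/refs/main -> commit hash -> snapshots/<commit>/config.json;
    // returns None without touching the network if any step is missing.
    match cache.get(&repo, "config.json") {
        Some(path) => println!("cached at {}", path.display()),
        None => println!("not cached yet; an Api client would download it"),
    }
}
```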