From 1a82bc50c9476517a04ed68dbc58f0bd1fcf8c2a Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 27 Jun 2023 12:07:34 +0200 Subject: [PATCH 1/3] [Tmp] Adding candle-hub --- Cargo.toml | 1 + candle-hub/.gitignore | 2 + candle-hub/Cargo.toml | 30 ++ candle-hub/src/api.rs | 694 ++++++++++++++++++++++++++++++++++++++++++ candle-hub/src/lib.rs | 113 +++++++ 5 files changed, 840 insertions(+) create mode 100644 candle-hub/.gitignore create mode 100644 candle-hub/Cargo.toml create mode 100644 candle-hub/src/api.rs create mode 100644 candle-hub/src/lib.rs diff --git a/Cargo.toml b/Cargo.toml index 4b00fc60..ce8cbb12 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,7 @@ members = [ "candle-core", + "candle-hub", "candle-kernels", ] diff --git a/candle-hub/.gitignore b/candle-hub/.gitignore new file mode 100644 index 00000000..4fffb2f8 --- /dev/null +++ b/candle-hub/.gitignore @@ -0,0 +1,2 @@ +/target +/Cargo.lock diff --git a/candle-hub/Cargo.toml b/candle-hub/Cargo.toml new file mode 100644 index 00000000..47c19b17 --- /dev/null +++ b/candle-hub/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "candle-hub" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +dirs = "5.0.1" +fs2 = "0.4.3" +rand = "0.8.5" +thiserror = "1.0.40" +futures = { version = "0.3.28", optional = true } +reqwest = { version = "0.11.18", optional = true, features = ["json"] } +tokio = { version = "1.28.2", features = ["fs"], optional = true } +serde = { version = "1.0.164", features = ["derive"], optional = true } +serde_json = { version = "1.0.97", optional = true } +indicatif = { version = "0.17.5", optional = true } + +[dev-dependencies] +rand = "0.8.5" +sha256 = "1.1.4" +tokio = { version = "1.28.2", features = ["macros"] } +tokio-test = "0.4.2" + +[features] +default = ["online"] +online = ["reqwest", "tokio", "futures", "serde", "serde_json", "indicatif"] + + diff --git a/candle-hub/src/api.rs b/candle-hub/src/api.rs new file mode 100644 index 00000000..5136f631 --- /dev/null +++ b/candle-hub/src/api.rs @@ -0,0 +1,694 @@ +use crate::{Repo, NAME, VERSION}; +use fs2::FileExt; +use indicatif::{ProgressBar, ProgressStyle}; +use rand::{distributions::Alphanumeric, thread_rng, Rng}; +use reqwest::{ + header::{ + HeaderMap, HeaderName, HeaderValue, InvalidHeaderValue, ToStrError, AUTHORIZATION, + CONTENT_RANGE, LOCATION, RANGE, USER_AGENT, + }, + redirect::Policy, + Client, Error as ReqwestError, +}; +use serde::Deserialize; +use std::num::ParseIntError; +use std::path::{Component, Path, PathBuf}; +use std::sync::Arc; +use thiserror::Error; +use tokio::io::{AsyncSeekExt, AsyncWriteExt, SeekFrom}; +use tokio::sync::{AcquireError, Semaphore, TryAcquireError}; + +#[derive(Debug, Error)] +/// All errors the API can throw +pub enum ApiError { + /// Api expects certain header to be present in the results to derive some information + #[error("Header {0} is missing")] + MissingHeader(HeaderName), + + /// The header exists, but the value is not conform to what the Api expects. + #[error("Header {0} is invalid")] + InvalidHeader(HeaderName), + + /// The value cannot be used as a header during request header construction + #[error("Invalid header value {0}")] + InvalidHeaderValue(#[from] InvalidHeaderValue), + + /// The header value is not valid utf-8 + #[error("header value is not a string")] + ToStr(#[from] ToStrError), + + /// Error in the request + #[error("request error: {0}")] + RequestError(#[from] ReqwestError), + + /// Error parsing some range value + #[error("Cannot parse int")] + ParseIntError(#[from] ParseIntError), + + /// I/O Error + #[error("I/O error {0}")] + IoError(#[from] std::io::Error), + + /// We tried to download chunk too many times + #[error("Too many retries: {0}")] + TooManyRetries(Box), + + /// Semaphore cannot be acquired + #[error("Try acquire: {0}")] + TryAcquireError(#[from] TryAcquireError), + + /// Semaphore cannot be acquired + #[error("Acquire: {0}")] + AcquireError(#[from] AcquireError), + // /// Semaphore cannot be acquired + // #[error("Invalid Response: {0:?}")] + // InvalidResponse(Response), +} + +/// Siblings are simplified file descriptions of remote files on the hub +#[derive(Debug, Clone, Deserialize, PartialEq)] +pub struct Siblings { + /// The path within the repo. + pub rfilename: String, +} + +/// The description of the repo given by the hub +#[derive(Debug, Clone, Deserialize, PartialEq)] +pub struct ModelInfo { + /// See [`Siblings`] + pub siblings: Vec, +} + +/// Helper to create [`Api`] with all the options. +pub struct ApiBuilder { + endpoint: String, + url_template: String, + cache_dir: PathBuf, + token: Option, + max_files: usize, + chunk_size: usize, + parallel_failures: usize, + max_retries: usize, + progress: bool, +} + +impl Default for ApiBuilder { + fn default() -> Self { + Self::new() + } +} + +impl ApiBuilder { + /// Default api builder + /// ``` + /// use hf_hub::api::ApiBuilder; + /// let api = ApiBuilder::new().build().unwrap(); + /// ``` + pub fn new() -> Self { + let cache_dir = match std::env::var("HF_HOME") { + Ok(home) => home.into(), + Err(_) => { + let mut cache = dirs::home_dir().expect("Cache directory cannot be found"); + cache.push(".cache"); + cache.push("huggingface"); + cache + } + }; + let mut token_filename = cache_dir.clone(); + token_filename.push(".token"); + let token = match std::fs::read_to_string(token_filename) { + Ok(token_content) => { + let token_content = token_content.trim(); + if !token_content.is_empty() { + Some(token_content.to_string()) + } else { + None + } + } + Err(_) => None, + }; + + let progress = true; + + Self { + endpoint: "https://huggingface.co".to_string(), + url_template: "{endpoint}/{repo_id}/resolve/{revision}/{filename}".to_string(), + cache_dir, + token, + max_files: 100, + chunk_size: 10_000_000, + parallel_failures: 0, + max_retries: 0, + progress, + } + } + + /// Wether to show a progressbar + pub fn with_progress(mut self, progress: bool) -> Self { + self.progress = progress; + self + } + + /// Changes the location of the cache directory. Defaults is `~/.cache/huggingface/`. + pub fn with_cache_dir(mut self, cache_dir: &Path) -> Self { + self.cache_dir = cache_dir.to_path_buf(); + self + } + + fn build_headers(&self) -> Result { + let mut headers = HeaderMap::new(); + let user_agent = format!("unkown/None; {NAME}/{VERSION}; rust/unknown"); + headers.insert(USER_AGENT, HeaderValue::from_str(&user_agent)?); + if let Some(token) = &self.token { + headers.insert( + AUTHORIZATION, + HeaderValue::from_str(&format!("Bearer {token}"))?, + ); + } + Ok(headers) + } + + /// Consumes the builder and buids the final [`Api`] + pub fn build(self) -> Result { + let headers = self.build_headers()?; + let client = Client::builder().default_headers(headers.clone()).build()?; + let no_redirect_client = Client::builder() + .redirect(Policy::none()) + .default_headers(headers) + .build()?; + Ok(Api { + endpoint: self.endpoint, + url_template: self.url_template, + cache_dir: self.cache_dir, + client, + + no_redirect_client, + max_files: self.max_files, + chunk_size: self.chunk_size, + parallel_failures: self.parallel_failures, + max_retries: self.max_retries, + progress: self.progress, + }) + } +} + +#[derive(Debug)] +struct Metadata { + commit_hash: String, + etag: String, + size: usize, +} + +/// The actual Api used to interacto with the hub. +/// You can inspect repos with [`Api::info`] +/// or download files with [`Api::download`] +pub struct Api { + endpoint: String, + url_template: String, + cache_dir: PathBuf, + client: Client, + no_redirect_client: Client, + max_files: usize, + chunk_size: usize, + parallel_failures: usize, + max_retries: usize, + progress: bool, +} + +fn temp_filename() -> PathBuf { + let s: String = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(7) + .map(char::from) + .collect(); + let mut path = std::env::temp_dir(); + path.push(s); + path +} + +fn make_relative(src: &Path, dst: &Path) -> PathBuf { + let path = src; + let base = dst; + + if path.is_absolute() != base.is_absolute() { + panic!("This function is made to look at absolute paths only"); + } + let mut ita = path.components(); + let mut itb = base.components(); + + loop { + match (ita.next(), itb.next()) { + (Some(a), Some(b)) if a == b => (), + (some_a, _) => { + // Ignoring b, because 1 component is the filename + // for which we don't need to go back up for relative + // filename to work. + let mut new_path = PathBuf::new(); + for _ in itb { + new_path.push(Component::ParentDir); + } + if let Some(a) = some_a { + new_path.push(a); + for comp in ita { + new_path.push(comp); + } + } + return new_path; + } + } + } +} + +fn symlink_or_rename(src: &Path, dst: &Path) -> Result<(), std::io::Error> { + if dst.exists() { + return Ok(()); + } + + let src = make_relative(src, dst); + #[cfg(target_os = "windows")] + std::os::windows::fs::symlink_file(src, dst)?; + + #[cfg(target_family = "unix")] + std::os::unix::fs::symlink(src, dst)?; + + #[cfg(not(any(target_family = "unix", target_os = "windows")))] + std::fs::rename(src, dst)?; + + Ok(()) +} + +fn jitter() -> usize { + thread_rng().gen_range(0..=500) +} + +fn exponential_backoff(base_wait_time: usize, n: usize, max: usize) -> usize { + (base_wait_time + n.pow(2) + jitter()).min(max) +} + +impl Api { + /// Creates a default Api, for Api options See [`ApiBuilder`] + pub fn new() -> Result { + ApiBuilder::new().build() + } + + /// Get the fully qualified URL of the remote filename + /// ``` + /// # use hf_hub::{api::Api, Repo}; + /// let api = Api::new().unwrap(); + /// let repo = Repo::model("gpt2".to_string()); + /// let url = api.url(&repo, "model.safetensors"); + /// assert_eq!(url, "https://huggingface.co/gpt2/resolve/main/model.safetensors"); + /// ``` + pub fn url(&self, repo: &Repo, filename: &str) -> String { + let endpoint = &self.endpoint; + let revision = &repo.url_revision(); + self.url_template + .replace("{endpoint}", endpoint) + .replace("{repo_id}", &repo.url()) + .replace("{revision}", revision) + .replace("{filename}", filename) + } + + /// Get the underlying api client + /// Allows for lower level access + pub fn client(&self) -> &Client { + &self.client + } + + async fn metadata(&self, url: &str) -> Result { + let response = self + .no_redirect_client + .get(url) + .header(RANGE, "bytes=0-0") + .send() + .await?; + let response = response.error_for_status()?; + let headers = response.headers(); + let header_commit = HeaderName::from_static("x-repo-commit"); + let header_linked_etag = HeaderName::from_static("x-linked-etag"); + let header_etag = HeaderName::from_static("etag"); + + let etag = match headers.get(&header_linked_etag) { + Some(etag) => etag, + None => headers + .get(&header_etag) + .ok_or(ApiError::MissingHeader(header_etag))?, + }; + // Cleaning extra quotes + let etag = etag.to_str()?.to_string().replace('"', ""); + let commit_hash = headers + .get(&header_commit) + .ok_or(ApiError::MissingHeader(header_commit))? + .to_str()? + .to_string(); + + // The response was redirected o S3 most likely which will + // know about the size of the file + let response = if response.status().is_redirection() { + self.client + .get(headers.get(LOCATION).unwrap().to_str()?.to_string()) + .header(RANGE, "bytes=0-0") + .send() + .await? + } else { + response + }; + let headers = response.headers(); + let content_range = headers + .get(CONTENT_RANGE) + .ok_or(ApiError::MissingHeader(CONTENT_RANGE))? + .to_str()?; + + let size = content_range + .split('/') + .last() + .ok_or(ApiError::InvalidHeader(CONTENT_RANGE))? + .parse()?; + Ok(Metadata { + commit_hash, + etag, + size, + }) + } + + async fn download_tempfile( + &self, + url: &str, + length: usize, + progressbar: Option, + ) -> Result { + let mut handles = vec![]; + let semaphore = Arc::new(Semaphore::new(self.max_files)); + let parallel_failures_semaphore = Arc::new(Semaphore::new(self.parallel_failures)); + let filename = temp_filename(); + + let chunk_size = self.chunk_size; + for start in (0..length).step_by(chunk_size) { + let url = url.to_string(); + let filename = filename.clone(); + let client = self.client.clone(); + + let stop = std::cmp::min(start + chunk_size - 1, length); + let permit = semaphore.clone().acquire_owned().await?; + let parallel_failures = self.parallel_failures; + let max_retries = self.max_retries; + let parallel_failures_semaphore = parallel_failures_semaphore.clone(); + let progress = progressbar.clone(); + handles.push(tokio::spawn(async move { + let mut chunk = Self::download_chunk(&client, &url, &filename, start, stop).await; + let mut i = 0; + if parallel_failures > 0 { + while let Err(dlerr) = chunk { + let parallel_failure_permit = + parallel_failures_semaphore.clone().try_acquire_owned()?; + + let wait_time = exponential_backoff(300, i, 10_000); + tokio::time::sleep(tokio::time::Duration::from_millis(wait_time as u64)) + .await; + + chunk = Self::download_chunk(&client, &url, &filename, start, stop).await; + i += 1; + if i > max_retries { + return Err(ApiError::TooManyRetries(dlerr.into())); + } + drop(parallel_failure_permit); + } + } + drop(permit); + if let Some(p) = progress { + p.inc((stop - start) as u64); + } + chunk + })); + } + + // Output the chained result + let results: Vec, tokio::task::JoinError>> = + futures::future::join_all(handles).await; + let results: Result<(), ApiError> = results.into_iter().flatten().collect(); + results?; + if let Some(p) = progressbar { + p.finish() + } + Ok(filename) + } + + async fn download_chunk( + client: &reqwest::Client, + url: &str, + filename: &PathBuf, + start: usize, + stop: usize, + ) -> Result<(), ApiError> { + // Process each socket concurrently. + let range = format!("bytes={start}-{stop}"); + let mut file = tokio::fs::OpenOptions::new() + .write(true) + .create(true) + .open(filename) + .await?; + file.seek(SeekFrom::Start(start as u64)).await?; + let response = client + .get(url) + .header(RANGE, range) + .send() + .await? + .error_for_status()?; + let content = response.bytes().await?; + file.write_all(&content).await?; + Ok(()) + } + + /// Downloads a remote file (if not already present) into the cache directory + /// to be used locally. + /// This functions require internet access to verify if new versions of the file + /// exist, even if a file is already on disk at location. + /// ```no_run + /// # use hf_hub::{api::ApiBuilder, Repo}; + /// # tokio_test::block_on(async { + /// let api = ApiBuilder::new().build().unwrap(); + /// let repo = Repo::model("gpt2".to_string()); + /// let local_filename = api.download(&repo, "model.safetensors").await.unwrap(); + /// # }) + /// ``` + pub async fn download(&self, repo: &Repo, filename: &str) -> Result { + let url = self.url(repo, filename); + let metadata = self.metadata(&url).await?; + + let mut folder = self.cache_dir.clone(); + folder.push(repo.folder_name()); + + let mut blob_path = folder.clone(); + blob_path.push("blobs"); + std::fs::create_dir_all(&blob_path).ok(); + blob_path.push(&metadata.etag); + + let file1 = std::fs::OpenOptions::new() + .read(true) + .write(true) + .create(true) + .open(&blob_path)?; + file1.lock_exclusive()?; + + let progressbar = if self.progress { + let progress = ProgressBar::new(metadata.size as u64); + progress.set_style( + ProgressStyle::with_template( + "{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec} ({eta})", + ) + .unwrap(), // .progress_chars("━ "), + ); + let maxlength = 30; + let message = if filename.len() > maxlength { + format!("..{}", &filename[filename.len() - maxlength..]) + } else { + filename.to_string() + }; + progress.set_message(message); + Some(progress) + } else { + None + }; + + let tmp_filename = self + .download_tempfile(&url, metadata.size, progressbar) + .await?; + std::fs::copy(tmp_filename, &blob_path)?; + + let mut pointer_path = folder.clone(); + pointer_path.push("snapshots"); + pointer_path.push(&metadata.commit_hash); + pointer_path.push(filename); + std::fs::create_dir_all(pointer_path.parent().unwrap()).ok(); + + symlink_or_rename(&blob_path, &pointer_path)?; + + Ok(pointer_path) + } + + /// Get information about the Repo + /// ``` + /// # use hf_hub::{api::Api, Repo}; + /// # tokio_test::block_on(async { + /// let api = Api::new().unwrap(); + /// let repo = Repo::model("gpt2".to_string()); + /// api.info(&repo); + /// # }) + /// ``` + pub async fn info(&self, repo: &Repo) -> Result { + let url = format!("{}/api/{}", self.endpoint, repo.api_url()); + let response = self.client.get(url).send().await?; + let response = response.error_for_status()?; + + let model_info = response.json().await?; + + Ok(model_info) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::RepoType; + use rand::{distributions::Alphanumeric, Rng}; + use sha256::try_digest; + + struct TempDir { + path: PathBuf, + } + + impl TempDir { + pub fn new() -> Self { + let s: String = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(7) + .map(char::from) + .collect(); + let mut path = std::env::temp_dir(); + path.push(s); + std::fs::create_dir(&path).unwrap(); + Self { path } + } + } + + impl Drop for TempDir { + fn drop(&mut self) { + std::fs::remove_dir_all(&self.path).unwrap() + } + } + + #[tokio::test] + async fn simple() { + let tmp = TempDir::new(); + let api = ApiBuilder::new() + .with_progress(false) + .with_cache_dir(&tmp.path) + .build() + .unwrap(); + let repo = Repo::new("julien-c/dummy-unknown".to_string(), RepoType::Model); + let downloaded_path = api.download(&repo, "config.json").await.unwrap(); + assert!(downloaded_path.exists()); + let val = try_digest(&*downloaded_path).unwrap(); + assert_eq!( + val, + "b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32" + ) + } + + #[tokio::test] + async fn dataset() { + let tmp = TempDir::new(); + let api = ApiBuilder::new() + .with_progress(false) + .with_cache_dir(&tmp.path) + .build() + .unwrap(); + let repo = Repo::with_revision( + "wikitext".to_string(), + RepoType::Dataset, + "refs/convert/parquet".to_string(), + ); + let downloaded_path = api + .download(&repo, "wikitext-103-v1/wikitext-test.parquet") + .await + .unwrap(); + assert!(downloaded_path.exists()); + let val = try_digest(&*downloaded_path).unwrap(); + assert_eq!( + val, + "59ce09415ad8aa45a9e34f88cec2548aeb9de9a73fcda9f6b33a86a065f32b90" + ) + } + + #[tokio::test] + async fn info() { + let tmp = TempDir::new(); + let api = ApiBuilder::new() + .with_progress(false) + .with_cache_dir(&tmp.path) + .build() + .unwrap(); + let repo = Repo::with_revision( + "wikitext".to_string(), + RepoType::Dataset, + "refs/convert/parquet".to_string(), + ); + let model_info = api.info(&repo).await.unwrap(); + assert_eq!( + model_info, + ModelInfo { + siblings: vec![ + Siblings { + rfilename: ".gitattributes".to_string(), + }, + Siblings { + rfilename: "wikitext-103-raw-v1/wikitext-test.parquet".to_string(), + }, + Siblings { + rfilename: "wikitext-103-raw-v1/wikitext-train-00000-of-00002.parquet" + .to_string(), + }, + Siblings { + rfilename: "wikitext-103-raw-v1/wikitext-train-00001-of-00002.parquet" + .to_string(), + }, + Siblings { + rfilename: "wikitext-103-raw-v1/wikitext-validation.parquet".to_string(), + }, + Siblings { + rfilename: "wikitext-103-v1/wikitext-test.parquet".to_string(), + }, + Siblings { + rfilename: "wikitext-103-v1/wikitext-train-00000-of-00002.parquet" + .to_string(), + }, + Siblings { + rfilename: "wikitext-103-v1/wikitext-train-00001-of-00002.parquet" + .to_string(), + }, + Siblings { + rfilename: "wikitext-103-v1/wikitext-validation.parquet".to_string(), + }, + Siblings { + rfilename: "wikitext-2-raw-v1/wikitext-test.parquet".to_string(), + }, + Siblings { + rfilename: "wikitext-2-raw-v1/wikitext-train.parquet".to_string(), + }, + Siblings { + rfilename: "wikitext-2-raw-v1/wikitext-validation.parquet".to_string(), + }, + Siblings { + rfilename: "wikitext-2-v1/wikitext-test.parquet".to_string(), + }, + Siblings { + rfilename: "wikitext-2-v1/wikitext-train.parquet".to_string(), + }, + Siblings { + rfilename: "wikitext-2-v1/wikitext-validation.parquet".to_string(), + }, + ], + } + ) + } +} diff --git a/candle-hub/src/lib.rs b/candle-hub/src/lib.rs new file mode 100644 index 00000000..f72b1a0b --- /dev/null +++ b/candle-hub/src/lib.rs @@ -0,0 +1,113 @@ +#![deny(missing_docs)] +//! This crates aims to emulate and be compatible with the +//! [huggingface_hub](https://github.com/huggingface/huggingface_hub/) python package. +//! +//! compatible means the Api should reuse the same files skipping downloads if +//! they are already present and whenever this crate downloads or modifies this cache +//! it should be consistent with [huggingface_hub](https://github.com/huggingface/huggingface_hub/) +//! +//! At this time only a limited subset of the functionality is present, the goal is to add new +//! features over time + +/// The actual Api to interact with the hub. +#[cfg(feature = "online")] +pub mod api; + +/// Current version (used in user-agent) +const VERSION: &str = env!("CARGO_PKG_VERSION"); +/// Current name (used in user-agent) +const NAME: &str = env!("CARGO_PKG_NAME"); + +/// The type of repo to interact with +#[derive(Debug, Clone, Copy)] +pub enum RepoType { + /// This is a model, usually it consists of weight files and some configuration + /// files + Model, + /// This is a dataset, usually contains data within parquet files + Dataset, + /// This is a space, usually a demo showcashing a given model or dataset + Space, +} + +/// The representation of a repo on the hub. +pub struct Repo { + repo_id: String, + repo_type: RepoType, + revision: String, +} + +impl Repo { + /// Repo with the default branch ("main"). + pub fn new(repo_id: String, repo_type: RepoType) -> Self { + Self::with_revision(repo_id, repo_type, "main".to_string()) + } + + /// fully qualified Repo + pub fn with_revision(repo_id: String, repo_type: RepoType, revision: String) -> Self { + Self { + repo_id, + repo_type, + revision, + } + } + + /// Shortcut for [`Repo::new`] with [`RepoType::Model`] + pub fn model(repo_id: String) -> Self { + Self::new(repo_id, RepoType::Model) + } + + /// Shortcut for [`Repo::new`] with [`RepoType::Dataset`] + pub fn dataset(repo_id: String) -> Self { + Self::new(repo_id, RepoType::Dataset) + } + + /// Shortcut for [`Repo::new`] with [`RepoType::Space`] + pub fn space(repo_id: String) -> Self { + Self::new(repo_id, RepoType::Space) + } + + /// The normalized folder nameof the repo within the cache directory + pub fn folder_name(&self) -> String { + match self.repo_type { + RepoType::Model => self.repo_id.replace('/', "--"), + RepoType::Dataset => { + format!("datasets/{}", self.repo_id.replace('/', "--")) + } + RepoType::Space => { + format!("spaces/{}", self.repo_id.replace('/', "--")) + } + } + } + + /// The actual URL part of the repo + #[cfg(feature = "online")] + pub fn url(&self) -> String { + match self.repo_type { + RepoType::Model => self.repo_id.to_string(), + RepoType::Dataset => { + format!("datasets/{}", self.repo_id) + } + RepoType::Space => { + format!("spaces/{}", self.repo_id) + } + } + } + + /// Revision needs to be url escaped before being used in a URL + #[cfg(feature = "online")] + pub fn url_revision(&self) -> String { + self.revision.replace('/', "%2F") + } + + /// Used to compute the repo's url part when accessing the metadata of the repo + #[cfg(feature = "online")] + pub fn api_url(&self) -> String { + let prefix = match self.repo_type { + RepoType::Model => "models", + RepoType::Dataset => "datasets", + RepoType::Space => "spaces", + }; + format!("{prefix}/{}/revision/{}", self.repo_id, self.url_revision()) + } +} From 75e090583254a70111b51f54560b762874bc9dd8 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 27 Jun 2023 13:35:57 +0200 Subject: [PATCH 2/3] Adding fully offline version. --- candle-hub/src/api.rs | 77 +++++++++++++++++++---------------- candle-hub/src/lib.rs | 95 +++++++++++++++++++++++++++++++++++++++---- 2 files changed, 128 insertions(+), 44 deletions(-) diff --git a/candle-hub/src/api.rs b/candle-hub/src/api.rs index 5136f631..1df54e28 100644 --- a/candle-hub/src/api.rs +++ b/candle-hub/src/api.rs @@ -1,4 +1,4 @@ -use crate::{Repo, NAME, VERSION}; +use crate::{Cache, Repo, NAME, VERSION}; use fs2::FileExt; use indicatif::{ProgressBar, ProgressStyle}; use rand::{distributions::Alphanumeric, thread_rng, Rng}; @@ -82,8 +82,8 @@ pub struct ModelInfo { /// Helper to create [`Api`] with all the options. pub struct ApiBuilder { endpoint: String, + cache: Cache, url_template: String, - cache_dir: PathBuf, token: Option, max_files: usize, chunk_size: usize, @@ -101,20 +101,12 @@ impl Default for ApiBuilder { impl ApiBuilder { /// Default api builder /// ``` - /// use hf_hub::api::ApiBuilder; + /// use candle_hub::api::ApiBuilder; /// let api = ApiBuilder::new().build().unwrap(); /// ``` pub fn new() -> Self { - let cache_dir = match std::env::var("HF_HOME") { - Ok(home) => home.into(), - Err(_) => { - let mut cache = dirs::home_dir().expect("Cache directory cannot be found"); - cache.push(".cache"); - cache.push("huggingface"); - cache - } - }; - let mut token_filename = cache_dir.clone(); + let cache = Cache::default(); + let mut token_filename = cache.path().clone(); token_filename.push(".token"); let token = match std::fs::read_to_string(token_filename) { Ok(token_content) => { @@ -133,7 +125,7 @@ impl ApiBuilder { Self { endpoint: "https://huggingface.co".to_string(), url_template: "{endpoint}/{repo_id}/resolve/{revision}/{filename}".to_string(), - cache_dir, + cache, token, max_files: 100, chunk_size: 10_000_000, @@ -150,8 +142,8 @@ impl ApiBuilder { } /// Changes the location of the cache directory. Defaults is `~/.cache/huggingface/`. - pub fn with_cache_dir(mut self, cache_dir: &Path) -> Self { - self.cache_dir = cache_dir.to_path_buf(); + pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self { + self.cache = Cache::new(cache_dir); self } @@ -179,7 +171,7 @@ impl ApiBuilder { Ok(Api { endpoint: self.endpoint, url_template: self.url_template, - cache_dir: self.cache_dir, + cache: self.cache, client, no_redirect_client, @@ -205,7 +197,7 @@ struct Metadata { pub struct Api { endpoint: String, url_template: String, - cache_dir: PathBuf, + cache: Cache, client: Client, no_redirect_client: Client, max_files: usize, @@ -293,7 +285,7 @@ impl Api { /// Get the fully qualified URL of the remote filename /// ``` - /// # use hf_hub::{api::Api, Repo}; + /// # use candle_hub::{api::Api, Repo}; /// let api = Api::new().unwrap(); /// let repo = Repo::model("gpt2".to_string()); /// let url = api.url(&repo, "model.safetensors"); @@ -459,12 +451,29 @@ impl Api { Ok(()) } + /// This will attempt the fetch the file locally first, then [`Api.download`] + /// if the file is not present. + /// ```no_run + /// # use candle_hub::{api::ApiBuilder, Repo}; + /// # tokio_test::block_on(async { + /// let api = ApiBuilder::new().build().unwrap(); + /// let repo = Repo::model("gpt2".to_string()); + /// let local_filename = api.get(&repo, "model.safetensors").await.unwrap(); + /// # }) + pub async fn get(&self, repo: &Repo, filename: &str) -> Result { + if let Some(path) = self.cache.get(repo, filename) { + Ok(path) + } else { + self.download(repo, filename).await + } + } + /// Downloads a remote file (if not already present) into the cache directory /// to be used locally. /// This functions require internet access to verify if new versions of the file /// exist, even if a file is already on disk at location. /// ```no_run - /// # use hf_hub::{api::ApiBuilder, Repo}; + /// # use candle_hub::{api::ApiBuilder, Repo}; /// # tokio_test::block_on(async { /// let api = ApiBuilder::new().build().unwrap(); /// let repo = Repo::model("gpt2".to_string()); @@ -475,13 +484,8 @@ impl Api { let url = self.url(repo, filename); let metadata = self.metadata(&url).await?; - let mut folder = self.cache_dir.clone(); - folder.push(repo.folder_name()); - - let mut blob_path = folder.clone(); - blob_path.push("blobs"); - std::fs::create_dir_all(&blob_path).ok(); - blob_path.push(&metadata.etag); + let blob_path = self.cache.blob_path(repo, &metadata.etag); + std::fs::create_dir_all(blob_path.parent().unwrap())?; let file1 = std::fs::OpenOptions::new() .read(true) @@ -515,20 +519,19 @@ impl Api { .await?; std::fs::copy(tmp_filename, &blob_path)?; - let mut pointer_path = folder.clone(); - pointer_path.push("snapshots"); - pointer_path.push(&metadata.commit_hash); + let mut pointer_path = self.cache.pointer_path(repo, &metadata.commit_hash); pointer_path.push(filename); std::fs::create_dir_all(pointer_path.parent().unwrap()).ok(); symlink_or_rename(&blob_path, &pointer_path)?; + self.cache.create_ref(repo, &metadata.commit_hash)?; Ok(pointer_path) } /// Get information about the Repo /// ``` - /// # use hf_hub::{api::Api, Repo}; + /// # use candle_hub::{api::Api, Repo}; /// # tokio_test::block_on(async { /// let api = Api::new().unwrap(); /// let repo = Repo::model("gpt2".to_string()); @@ -582,7 +585,7 @@ mod tests { let tmp = TempDir::new(); let api = ApiBuilder::new() .with_progress(false) - .with_cache_dir(&tmp.path) + .with_cache_dir(tmp.path.clone()) .build() .unwrap(); let repo = Repo::new("julien-c/dummy-unknown".to_string(), RepoType::Model); @@ -592,7 +595,11 @@ mod tests { assert_eq!( val, "b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32" - ) + ); + + // Make sure the file is now seeable without connection + let cache_path = api.cache.get(&repo, "config.json").unwrap(); + assert_eq!(cache_path, downloaded_path); } #[tokio::test] @@ -600,7 +607,7 @@ mod tests { let tmp = TempDir::new(); let api = ApiBuilder::new() .with_progress(false) - .with_cache_dir(&tmp.path) + .with_cache_dir(tmp.path.clone()) .build() .unwrap(); let repo = Repo::with_revision( @@ -625,7 +632,7 @@ mod tests { let tmp = TempDir::new(); let api = ApiBuilder::new() .with_progress(false) - .with_cache_dir(&tmp.path) + .with_cache_dir(tmp.path.clone()) .build() .unwrap(); let repo = Repo::with_revision( diff --git a/candle-hub/src/lib.rs b/candle-hub/src/lib.rs index f72b1a0b..ceac9718 100644 --- a/candle-hub/src/lib.rs +++ b/candle-hub/src/lib.rs @@ -8,6 +8,8 @@ //! //! At this time only a limited subset of the functionality is present, the goal is to add new //! features over time +use std::io::Write; +use std::path::PathBuf; /// The actual Api to interact with the hub. #[cfg(feature = "online")] @@ -30,6 +32,84 @@ pub enum RepoType { Space, } +/// A local struct used to fetch information from the cache folder. +pub struct Cache { + path: PathBuf, +} + +impl Cache { + /// Creates a new cache object location + pub fn new(path: PathBuf) -> Self { + Self { path } + } + + /// Creates a new cache object location + pub fn path(&self) -> &PathBuf { + &self.path + } + + /// This will get the location of the file within the cache for the remote + /// `filename`. Will return `None` if file is not already present in cache. + pub fn get(&self, repo: &Repo, filename: &str) -> Option { + let mut commit_path = self.path.clone(); + commit_path.push(repo.folder_name()); + commit_path.push("refs"); + commit_path.push(repo.revision()); + let commit_hash = std::fs::read_to_string(commit_path).ok()?; + let mut pointer_path = self.pointer_path(repo, &commit_hash); + pointer_path.push(filename); + Some(pointer_path) + } + + /// Creates a reference in the cache directory that points branches to the correct + /// commits within the blobs. + pub fn create_ref(&self, repo: &Repo, commit_hash: &str) -> Result<(), std::io::Error> { + let mut ref_path = self.path.clone(); + ref_path.push(repo.folder_name()); + ref_path.push("refs"); + ref_path.push(repo.revision()); + // Needs to be done like this because revision might contain `/` creating subfolders here. + std::fs::create_dir_all(ref_path.parent().unwrap())?; + let mut file1 = std::fs::OpenOptions::new() + .write(true) + .create(true) + .open(&ref_path)?; + file1.write_all(commit_hash.trim().as_bytes())?; + Ok(()) + } + + pub(crate) fn blob_path(&self, repo: &Repo, etag: &str) -> PathBuf { + let mut blob_path = self.path.clone(); + blob_path.push(repo.folder_name()); + blob_path.push("blobs"); + blob_path.push(etag); + blob_path + } + + pub(crate) fn pointer_path(&self, repo: &Repo, commit_hash: &str) -> PathBuf { + let mut pointer_path = self.path.clone(); + pointer_path.push(repo.folder_name()); + pointer_path.push("snapshots"); + pointer_path.push(commit_hash); + pointer_path + } +} + +impl Default for Cache { + fn default() -> Self { + let path = match std::env::var("HF_HOME") { + Ok(home) => home.into(), + Err(_) => { + let mut cache = dirs::home_dir().expect("Cache directory cannot be found"); + cache.push(".cache"); + cache.push("huggingface"); + cache + } + }; + Self::new(path) + } +} + /// The representation of a repo on the hub. pub struct Repo { repo_id: String, @@ -69,15 +149,12 @@ impl Repo { /// The normalized folder nameof the repo within the cache directory pub fn folder_name(&self) -> String { - match self.repo_type { - RepoType::Model => self.repo_id.replace('/', "--"), - RepoType::Dataset => { - format!("datasets/{}", self.repo_id.replace('/', "--")) - } - RepoType::Space => { - format!("spaces/{}", self.repo_id.replace('/', "--")) - } - } + self.repo_id.replace('/', "--") + } + + /// The revision + pub fn revision(&self) -> &str { + &self.revision } /// The actual URL part of the repo From 70a90a14655638cc23d92c9dae91a35271bf4016 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 27 Jun 2023 14:04:20 +0200 Subject: [PATCH 3/3] Clippy without features. --- candle-hub/src/api.rs | 7 ++++++- candle-hub/src/lib.rs | 7 ++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/candle-hub/src/api.rs b/candle-hub/src/api.rs index 1df54e28..30af19c6 100644 --- a/candle-hub/src/api.rs +++ b/candle-hub/src/api.rs @@ -1,4 +1,4 @@ -use crate::{Cache, Repo, NAME, VERSION}; +use crate::{Cache, Repo}; use fs2::FileExt; use indicatif::{ProgressBar, ProgressStyle}; use rand::{distributions::Alphanumeric, thread_rng, Rng}; @@ -18,6 +18,11 @@ use thiserror::Error; use tokio::io::{AsyncSeekExt, AsyncWriteExt, SeekFrom}; use tokio::sync::{AcquireError, Semaphore, TryAcquireError}; +/// Current version (used in user-agent) +const VERSION: &str = env!("CARGO_PKG_VERSION"); +/// Current name (used in user-agent) +const NAME: &str = env!("CARGO_PKG_NAME"); + #[derive(Debug, Error)] /// All errors the API can throw pub enum ApiError { diff --git a/candle-hub/src/lib.rs b/candle-hub/src/lib.rs index ceac9718..fbfef5e7 100644 --- a/candle-hub/src/lib.rs +++ b/candle-hub/src/lib.rs @@ -15,11 +15,6 @@ use std::path::PathBuf; #[cfg(feature = "online")] pub mod api; -/// Current version (used in user-agent) -const VERSION: &str = env!("CARGO_PKG_VERSION"); -/// Current name (used in user-agent) -const NAME: &str = env!("CARGO_PKG_NAME"); - /// The type of repo to interact with #[derive(Debug, Clone, Copy)] pub enum RepoType { @@ -78,6 +73,7 @@ impl Cache { Ok(()) } + #[cfg(feature = "online")] pub(crate) fn blob_path(&self, repo: &Repo, etag: &str) -> PathBuf { let mut blob_path = self.path.clone(); blob_path.push(repo.folder_name()); @@ -111,6 +107,7 @@ impl Default for Cache { } /// The representation of a repo on the hub. +#[allow(dead_code)] // Repo type unused in offline mode pub struct Repo { repo_id: String, repo_type: RepoType,