From 75e090583254a70111b51f54560b762874bc9dd8 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 27 Jun 2023 13:35:57 +0200 Subject: [PATCH] Adding fully offline version. --- candle-hub/src/api.rs | 77 +++++++++++++++++++---------------- candle-hub/src/lib.rs | 95 +++++++++++++++++++++++++++++++++++++++---- 2 files changed, 128 insertions(+), 44 deletions(-) diff --git a/candle-hub/src/api.rs b/candle-hub/src/api.rs index 5136f631..1df54e28 100644 --- a/candle-hub/src/api.rs +++ b/candle-hub/src/api.rs @@ -1,4 +1,4 @@ -use crate::{Repo, NAME, VERSION}; +use crate::{Cache, Repo, NAME, VERSION}; use fs2::FileExt; use indicatif::{ProgressBar, ProgressStyle}; use rand::{distributions::Alphanumeric, thread_rng, Rng}; @@ -82,8 +82,8 @@ pub struct ModelInfo { /// Helper to create [`Api`] with all the options. pub struct ApiBuilder { endpoint: String, + cache: Cache, url_template: String, - cache_dir: PathBuf, token: Option, max_files: usize, chunk_size: usize, @@ -101,20 +101,12 @@ impl Default for ApiBuilder { impl ApiBuilder { /// Default api builder /// ``` - /// use hf_hub::api::ApiBuilder; + /// use candle_hub::api::ApiBuilder; /// let api = ApiBuilder::new().build().unwrap(); /// ``` pub fn new() -> Self { - let cache_dir = match std::env::var("HF_HOME") { - Ok(home) => home.into(), - Err(_) => { - let mut cache = dirs::home_dir().expect("Cache directory cannot be found"); - cache.push(".cache"); - cache.push("huggingface"); - cache - } - }; - let mut token_filename = cache_dir.clone(); + let cache = Cache::default(); + let mut token_filename = cache.path().clone(); token_filename.push(".token"); let token = match std::fs::read_to_string(token_filename) { Ok(token_content) => { @@ -133,7 +125,7 @@ impl ApiBuilder { Self { endpoint: "https://huggingface.co".to_string(), url_template: "{endpoint}/{repo_id}/resolve/{revision}/{filename}".to_string(), - cache_dir, + cache, token, max_files: 100, chunk_size: 10_000_000, @@ -150,8 +142,8 @@ impl ApiBuilder { } /// Changes the location of the cache directory. Defaults is `~/.cache/huggingface/`. - pub fn with_cache_dir(mut self, cache_dir: &Path) -> Self { - self.cache_dir = cache_dir.to_path_buf(); + pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self { + self.cache = Cache::new(cache_dir); self } @@ -179,7 +171,7 @@ impl ApiBuilder { Ok(Api { endpoint: self.endpoint, url_template: self.url_template, - cache_dir: self.cache_dir, + cache: self.cache, client, no_redirect_client, @@ -205,7 +197,7 @@ struct Metadata { pub struct Api { endpoint: String, url_template: String, - cache_dir: PathBuf, + cache: Cache, client: Client, no_redirect_client: Client, max_files: usize, @@ -293,7 +285,7 @@ impl Api { /// Get the fully qualified URL of the remote filename /// ``` - /// # use hf_hub::{api::Api, Repo}; + /// # use candle_hub::{api::Api, Repo}; /// let api = Api::new().unwrap(); /// let repo = Repo::model("gpt2".to_string()); /// let url = api.url(&repo, "model.safetensors"); @@ -459,12 +451,29 @@ impl Api { Ok(()) } + /// This will attempt the fetch the file locally first, then [`Api.download`] + /// if the file is not present. + /// ```no_run + /// # use candle_hub::{api::ApiBuilder, Repo}; + /// # tokio_test::block_on(async { + /// let api = ApiBuilder::new().build().unwrap(); + /// let repo = Repo::model("gpt2".to_string()); + /// let local_filename = api.get(&repo, "model.safetensors").await.unwrap(); + /// # }) + pub async fn get(&self, repo: &Repo, filename: &str) -> Result { + if let Some(path) = self.cache.get(repo, filename) { + Ok(path) + } else { + self.download(repo, filename).await + } + } + /// Downloads a remote file (if not already present) into the cache directory /// to be used locally. /// This functions require internet access to verify if new versions of the file /// exist, even if a file is already on disk at location. /// ```no_run - /// # use hf_hub::{api::ApiBuilder, Repo}; + /// # use candle_hub::{api::ApiBuilder, Repo}; /// # tokio_test::block_on(async { /// let api = ApiBuilder::new().build().unwrap(); /// let repo = Repo::model("gpt2".to_string()); @@ -475,13 +484,8 @@ impl Api { let url = self.url(repo, filename); let metadata = self.metadata(&url).await?; - let mut folder = self.cache_dir.clone(); - folder.push(repo.folder_name()); - - let mut blob_path = folder.clone(); - blob_path.push("blobs"); - std::fs::create_dir_all(&blob_path).ok(); - blob_path.push(&metadata.etag); + let blob_path = self.cache.blob_path(repo, &metadata.etag); + std::fs::create_dir_all(blob_path.parent().unwrap())?; let file1 = std::fs::OpenOptions::new() .read(true) @@ -515,20 +519,19 @@ impl Api { .await?; std::fs::copy(tmp_filename, &blob_path)?; - let mut pointer_path = folder.clone(); - pointer_path.push("snapshots"); - pointer_path.push(&metadata.commit_hash); + let mut pointer_path = self.cache.pointer_path(repo, &metadata.commit_hash); pointer_path.push(filename); std::fs::create_dir_all(pointer_path.parent().unwrap()).ok(); symlink_or_rename(&blob_path, &pointer_path)?; + self.cache.create_ref(repo, &metadata.commit_hash)?; Ok(pointer_path) } /// Get information about the Repo /// ``` - /// # use hf_hub::{api::Api, Repo}; + /// # use candle_hub::{api::Api, Repo}; /// # tokio_test::block_on(async { /// let api = Api::new().unwrap(); /// let repo = Repo::model("gpt2".to_string()); @@ -582,7 +585,7 @@ mod tests { let tmp = TempDir::new(); let api = ApiBuilder::new() .with_progress(false) - .with_cache_dir(&tmp.path) + .with_cache_dir(tmp.path.clone()) .build() .unwrap(); let repo = Repo::new("julien-c/dummy-unknown".to_string(), RepoType::Model); @@ -592,7 +595,11 @@ mod tests { assert_eq!( val, "b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32" - ) + ); + + // Make sure the file is now seeable without connection + let cache_path = api.cache.get(&repo, "config.json").unwrap(); + assert_eq!(cache_path, downloaded_path); } #[tokio::test] @@ -600,7 +607,7 @@ mod tests { let tmp = TempDir::new(); let api = ApiBuilder::new() .with_progress(false) - .with_cache_dir(&tmp.path) + .with_cache_dir(tmp.path.clone()) .build() .unwrap(); let repo = Repo::with_revision( @@ -625,7 +632,7 @@ mod tests { let tmp = TempDir::new(); let api = ApiBuilder::new() .with_progress(false) - .with_cache_dir(&tmp.path) + .with_cache_dir(tmp.path.clone()) .build() .unwrap(); let repo = Repo::with_revision( diff --git a/candle-hub/src/lib.rs b/candle-hub/src/lib.rs index f72b1a0b..ceac9718 100644 --- a/candle-hub/src/lib.rs +++ b/candle-hub/src/lib.rs @@ -8,6 +8,8 @@ //! //! At this time only a limited subset of the functionality is present, the goal is to add new //! features over time +use std::io::Write; +use std::path::PathBuf; /// The actual Api to interact with the hub. #[cfg(feature = "online")] @@ -30,6 +32,84 @@ pub enum RepoType { Space, } +/// A local struct used to fetch information from the cache folder. +pub struct Cache { + path: PathBuf, +} + +impl Cache { + /// Creates a new cache object location + pub fn new(path: PathBuf) -> Self { + Self { path } + } + + /// Creates a new cache object location + pub fn path(&self) -> &PathBuf { + &self.path + } + + /// This will get the location of the file within the cache for the remote + /// `filename`. Will return `None` if file is not already present in cache. + pub fn get(&self, repo: &Repo, filename: &str) -> Option { + let mut commit_path = self.path.clone(); + commit_path.push(repo.folder_name()); + commit_path.push("refs"); + commit_path.push(repo.revision()); + let commit_hash = std::fs::read_to_string(commit_path).ok()?; + let mut pointer_path = self.pointer_path(repo, &commit_hash); + pointer_path.push(filename); + Some(pointer_path) + } + + /// Creates a reference in the cache directory that points branches to the correct + /// commits within the blobs. + pub fn create_ref(&self, repo: &Repo, commit_hash: &str) -> Result<(), std::io::Error> { + let mut ref_path = self.path.clone(); + ref_path.push(repo.folder_name()); + ref_path.push("refs"); + ref_path.push(repo.revision()); + // Needs to be done like this because revision might contain `/` creating subfolders here. + std::fs::create_dir_all(ref_path.parent().unwrap())?; + let mut file1 = std::fs::OpenOptions::new() + .write(true) + .create(true) + .open(&ref_path)?; + file1.write_all(commit_hash.trim().as_bytes())?; + Ok(()) + } + + pub(crate) fn blob_path(&self, repo: &Repo, etag: &str) -> PathBuf { + let mut blob_path = self.path.clone(); + blob_path.push(repo.folder_name()); + blob_path.push("blobs"); + blob_path.push(etag); + blob_path + } + + pub(crate) fn pointer_path(&self, repo: &Repo, commit_hash: &str) -> PathBuf { + let mut pointer_path = self.path.clone(); + pointer_path.push(repo.folder_name()); + pointer_path.push("snapshots"); + pointer_path.push(commit_hash); + pointer_path + } +} + +impl Default for Cache { + fn default() -> Self { + let path = match std::env::var("HF_HOME") { + Ok(home) => home.into(), + Err(_) => { + let mut cache = dirs::home_dir().expect("Cache directory cannot be found"); + cache.push(".cache"); + cache.push("huggingface"); + cache + } + }; + Self::new(path) + } +} + /// The representation of a repo on the hub. pub struct Repo { repo_id: String, @@ -69,15 +149,12 @@ impl Repo { /// The normalized folder nameof the repo within the cache directory pub fn folder_name(&self) -> String { - match self.repo_type { - RepoType::Model => self.repo_id.replace('/', "--"), - RepoType::Dataset => { - format!("datasets/{}", self.repo_id.replace('/', "--")) - } - RepoType::Space => { - format!("spaces/{}", self.repo_id.replace('/', "--")) - } - } + self.repo_id.replace('/', "--") + } + + /// The revision + pub fn revision(&self) -> &str { + &self.revision } /// The actual URL part of the repo