Adding fully offline version.

This commit is contained in:
Nicolas Patry
2023-06-27 13:35:57 +02:00
parent 1a82bc50c9
commit 75e0905832
2 changed files with 128 additions and 44 deletions

View File

@ -1,4 +1,4 @@
use crate::{Repo, NAME, VERSION};
use crate::{Cache, Repo, NAME, VERSION};
use fs2::FileExt;
use indicatif::{ProgressBar, ProgressStyle};
use rand::{distributions::Alphanumeric, thread_rng, Rng};
@ -82,8 +82,8 @@ pub struct ModelInfo {
/// Helper to create [`Api`] with all the options.
pub struct ApiBuilder {
endpoint: String,
cache: Cache,
url_template: String,
cache_dir: PathBuf,
token: Option<String>,
max_files: usize,
chunk_size: usize,
@ -101,20 +101,12 @@ impl Default for ApiBuilder {
impl ApiBuilder {
/// Default api builder
/// ```
/// use hf_hub::api::ApiBuilder;
/// use candle_hub::api::ApiBuilder;
/// let api = ApiBuilder::new().build().unwrap();
/// ```
pub fn new() -> Self {
let cache_dir = match std::env::var("HF_HOME") {
Ok(home) => home.into(),
Err(_) => {
let mut cache = dirs::home_dir().expect("Cache directory cannot be found");
cache.push(".cache");
cache.push("huggingface");
cache
}
};
let mut token_filename = cache_dir.clone();
let cache = Cache::default();
let mut token_filename = cache.path().clone();
token_filename.push(".token");
let token = match std::fs::read_to_string(token_filename) {
Ok(token_content) => {
@ -133,7 +125,7 @@ impl ApiBuilder {
Self {
endpoint: "https://huggingface.co".to_string(),
url_template: "{endpoint}/{repo_id}/resolve/{revision}/{filename}".to_string(),
cache_dir,
cache,
token,
max_files: 100,
chunk_size: 10_000_000,
@ -150,8 +142,8 @@ impl ApiBuilder {
}
/// Changes the location of the cache directory. Defaults is `~/.cache/huggingface/`.
pub fn with_cache_dir(mut self, cache_dir: &Path) -> Self {
self.cache_dir = cache_dir.to_path_buf();
pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self {
self.cache = Cache::new(cache_dir);
self
}
@ -179,7 +171,7 @@ impl ApiBuilder {
Ok(Api {
endpoint: self.endpoint,
url_template: self.url_template,
cache_dir: self.cache_dir,
cache: self.cache,
client,
no_redirect_client,
@ -205,7 +197,7 @@ struct Metadata {
pub struct Api {
endpoint: String,
url_template: String,
cache_dir: PathBuf,
cache: Cache,
client: Client,
no_redirect_client: Client,
max_files: usize,
@ -293,7 +285,7 @@ impl Api {
/// Get the fully qualified URL of the remote filename
/// ```
/// # use hf_hub::{api::Api, Repo};
/// # use candle_hub::{api::Api, Repo};
/// let api = Api::new().unwrap();
/// let repo = Repo::model("gpt2".to_string());
/// let url = api.url(&repo, "model.safetensors");
@ -459,12 +451,29 @@ impl Api {
Ok(())
}
/// This will attempt the fetch the file locally first, then [`Api.download`]
/// if the file is not present.
/// ```no_run
/// # use candle_hub::{api::ApiBuilder, Repo};
/// # tokio_test::block_on(async {
/// let api = ApiBuilder::new().build().unwrap();
/// let repo = Repo::model("gpt2".to_string());
/// let local_filename = api.get(&repo, "model.safetensors").await.unwrap();
/// # })
pub async fn get(&self, repo: &Repo, filename: &str) -> Result<PathBuf, ApiError> {
if let Some(path) = self.cache.get(repo, filename) {
Ok(path)
} else {
self.download(repo, filename).await
}
}
/// Downloads a remote file (if not already present) into the cache directory
/// to be used locally.
/// This functions require internet access to verify if new versions of the file
/// exist, even if a file is already on disk at location.
/// ```no_run
/// # use hf_hub::{api::ApiBuilder, Repo};
/// # use candle_hub::{api::ApiBuilder, Repo};
/// # tokio_test::block_on(async {
/// let api = ApiBuilder::new().build().unwrap();
/// let repo = Repo::model("gpt2".to_string());
@ -475,13 +484,8 @@ impl Api {
let url = self.url(repo, filename);
let metadata = self.metadata(&url).await?;
let mut folder = self.cache_dir.clone();
folder.push(repo.folder_name());
let mut blob_path = folder.clone();
blob_path.push("blobs");
std::fs::create_dir_all(&blob_path).ok();
blob_path.push(&metadata.etag);
let blob_path = self.cache.blob_path(repo, &metadata.etag);
std::fs::create_dir_all(blob_path.parent().unwrap())?;
let file1 = std::fs::OpenOptions::new()
.read(true)
@ -515,20 +519,19 @@ impl Api {
.await?;
std::fs::copy(tmp_filename, &blob_path)?;
let mut pointer_path = folder.clone();
pointer_path.push("snapshots");
pointer_path.push(&metadata.commit_hash);
let mut pointer_path = self.cache.pointer_path(repo, &metadata.commit_hash);
pointer_path.push(filename);
std::fs::create_dir_all(pointer_path.parent().unwrap()).ok();
symlink_or_rename(&blob_path, &pointer_path)?;
self.cache.create_ref(repo, &metadata.commit_hash)?;
Ok(pointer_path)
}
/// Get information about the Repo
/// ```
/// # use hf_hub::{api::Api, Repo};
/// # use candle_hub::{api::Api, Repo};
/// # tokio_test::block_on(async {
/// let api = Api::new().unwrap();
/// let repo = Repo::model("gpt2".to_string());
@ -582,7 +585,7 @@ mod tests {
let tmp = TempDir::new();
let api = ApiBuilder::new()
.with_progress(false)
.with_cache_dir(&tmp.path)
.with_cache_dir(tmp.path.clone())
.build()
.unwrap();
let repo = Repo::new("julien-c/dummy-unknown".to_string(), RepoType::Model);
@ -592,7 +595,11 @@ mod tests {
assert_eq!(
val,
"b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32"
)
);
// Make sure the file is now seeable without connection
let cache_path = api.cache.get(&repo, "config.json").unwrap();
assert_eq!(cache_path, downloaded_path);
}
#[tokio::test]
@ -600,7 +607,7 @@ mod tests {
let tmp = TempDir::new();
let api = ApiBuilder::new()
.with_progress(false)
.with_cache_dir(&tmp.path)
.with_cache_dir(tmp.path.clone())
.build()
.unwrap();
let repo = Repo::with_revision(
@ -625,7 +632,7 @@ mod tests {
let tmp = TempDir::new();
let api = ApiBuilder::new()
.with_progress(false)
.with_cache_dir(&tmp.path)
.with_cache_dir(tmp.path.clone())
.build()
.unwrap();
let repo = Repo::with_revision(

View File

@ -8,6 +8,8 @@
//!
//! At this time only a limited subset of the functionality is present, the goal is to add new
//! features over time
use std::io::Write;
use std::path::PathBuf;
/// The actual Api to interact with the hub.
#[cfg(feature = "online")]
@ -30,6 +32,84 @@ pub enum RepoType {
Space,
}
/// A local struct used to fetch information from the cache folder.
pub struct Cache {
path: PathBuf,
}
impl Cache {
/// Creates a new cache object location
pub fn new(path: PathBuf) -> Self {
Self { path }
}
/// Creates a new cache object location
pub fn path(&self) -> &PathBuf {
&self.path
}
/// This will get the location of the file within the cache for the remote
/// `filename`. Will return `None` if file is not already present in cache.
pub fn get(&self, repo: &Repo, filename: &str) -> Option<PathBuf> {
let mut commit_path = self.path.clone();
commit_path.push(repo.folder_name());
commit_path.push("refs");
commit_path.push(repo.revision());
let commit_hash = std::fs::read_to_string(commit_path).ok()?;
let mut pointer_path = self.pointer_path(repo, &commit_hash);
pointer_path.push(filename);
Some(pointer_path)
}
/// Creates a reference in the cache directory that points branches to the correct
/// commits within the blobs.
pub fn create_ref(&self, repo: &Repo, commit_hash: &str) -> Result<(), std::io::Error> {
let mut ref_path = self.path.clone();
ref_path.push(repo.folder_name());
ref_path.push("refs");
ref_path.push(repo.revision());
// Needs to be done like this because revision might contain `/` creating subfolders here.
std::fs::create_dir_all(ref_path.parent().unwrap())?;
let mut file1 = std::fs::OpenOptions::new()
.write(true)
.create(true)
.open(&ref_path)?;
file1.write_all(commit_hash.trim().as_bytes())?;
Ok(())
}
pub(crate) fn blob_path(&self, repo: &Repo, etag: &str) -> PathBuf {
let mut blob_path = self.path.clone();
blob_path.push(repo.folder_name());
blob_path.push("blobs");
blob_path.push(etag);
blob_path
}
pub(crate) fn pointer_path(&self, repo: &Repo, commit_hash: &str) -> PathBuf {
let mut pointer_path = self.path.clone();
pointer_path.push(repo.folder_name());
pointer_path.push("snapshots");
pointer_path.push(commit_hash);
pointer_path
}
}
impl Default for Cache {
fn default() -> Self {
let path = match std::env::var("HF_HOME") {
Ok(home) => home.into(),
Err(_) => {
let mut cache = dirs::home_dir().expect("Cache directory cannot be found");
cache.push(".cache");
cache.push("huggingface");
cache
}
};
Self::new(path)
}
}
/// The representation of a repo on the hub.
pub struct Repo {
repo_id: String,
@ -69,15 +149,12 @@ impl Repo {
/// The normalized folder nameof the repo within the cache directory
pub fn folder_name(&self) -> String {
match self.repo_type {
RepoType::Model => self.repo_id.replace('/', "--"),
RepoType::Dataset => {
format!("datasets/{}", self.repo_id.replace('/', "--"))
}
RepoType::Space => {
format!("spaces/{}", self.repo_id.replace('/', "--"))
}
}
self.repo_id.replace('/', "--")
}
/// The revision
pub fn revision(&self) -> &str {
&self.revision
}
/// The actual URL part of the repo