mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Adding fully offline version.
This commit is contained in:
@ -1,4 +1,4 @@
|
||||
use crate::{Repo, NAME, VERSION};
|
||||
use crate::{Cache, Repo, NAME, VERSION};
|
||||
use fs2::FileExt;
|
||||
use indicatif::{ProgressBar, ProgressStyle};
|
||||
use rand::{distributions::Alphanumeric, thread_rng, Rng};
|
||||
@ -82,8 +82,8 @@ pub struct ModelInfo {
|
||||
/// Helper to create [`Api`] with all the options.
|
||||
pub struct ApiBuilder {
|
||||
endpoint: String,
|
||||
cache: Cache,
|
||||
url_template: String,
|
||||
cache_dir: PathBuf,
|
||||
token: Option<String>,
|
||||
max_files: usize,
|
||||
chunk_size: usize,
|
||||
@ -101,20 +101,12 @@ impl Default for ApiBuilder {
|
||||
impl ApiBuilder {
|
||||
/// Default api builder
|
||||
/// ```
|
||||
/// use hf_hub::api::ApiBuilder;
|
||||
/// use candle_hub::api::ApiBuilder;
|
||||
/// let api = ApiBuilder::new().build().unwrap();
|
||||
/// ```
|
||||
pub fn new() -> Self {
|
||||
let cache_dir = match std::env::var("HF_HOME") {
|
||||
Ok(home) => home.into(),
|
||||
Err(_) => {
|
||||
let mut cache = dirs::home_dir().expect("Cache directory cannot be found");
|
||||
cache.push(".cache");
|
||||
cache.push("huggingface");
|
||||
cache
|
||||
}
|
||||
};
|
||||
let mut token_filename = cache_dir.clone();
|
||||
let cache = Cache::default();
|
||||
let mut token_filename = cache.path().clone();
|
||||
token_filename.push(".token");
|
||||
let token = match std::fs::read_to_string(token_filename) {
|
||||
Ok(token_content) => {
|
||||
@ -133,7 +125,7 @@ impl ApiBuilder {
|
||||
Self {
|
||||
endpoint: "https://huggingface.co".to_string(),
|
||||
url_template: "{endpoint}/{repo_id}/resolve/{revision}/{filename}".to_string(),
|
||||
cache_dir,
|
||||
cache,
|
||||
token,
|
||||
max_files: 100,
|
||||
chunk_size: 10_000_000,
|
||||
@ -150,8 +142,8 @@ impl ApiBuilder {
|
||||
}
|
||||
|
||||
/// Changes the location of the cache directory. Defaults to `~/.cache/huggingface/`.
|
||||
pub fn with_cache_dir(mut self, cache_dir: &Path) -> Self {
|
||||
self.cache_dir = cache_dir.to_path_buf();
|
||||
pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self {
|
||||
self.cache = Cache::new(cache_dir);
|
||||
self
|
||||
}
|
||||
|
||||
@ -179,7 +171,7 @@ impl ApiBuilder {
|
||||
Ok(Api {
|
||||
endpoint: self.endpoint,
|
||||
url_template: self.url_template,
|
||||
cache_dir: self.cache_dir,
|
||||
cache: self.cache,
|
||||
client,
|
||||
|
||||
no_redirect_client,
|
||||
@ -205,7 +197,7 @@ struct Metadata {
|
||||
pub struct Api {
|
||||
endpoint: String,
|
||||
url_template: String,
|
||||
cache_dir: PathBuf,
|
||||
cache: Cache,
|
||||
client: Client,
|
||||
no_redirect_client: Client,
|
||||
max_files: usize,
|
||||
@ -293,7 +285,7 @@ impl Api {
|
||||
|
||||
/// Get the fully qualified URL of the remote filename
|
||||
/// ```
|
||||
/// # use hf_hub::{api::Api, Repo};
|
||||
/// # use candle_hub::{api::Api, Repo};
|
||||
/// let api = Api::new().unwrap();
|
||||
/// let repo = Repo::model("gpt2".to_string());
|
||||
/// let url = api.url(&repo, "model.safetensors");
|
||||
@ -459,12 +451,29 @@ impl Api {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// This will attempt the fetch the file locally first, then [`Api.download`]
|
||||
/// if the file is not present.
|
||||
/// ```no_run
|
||||
/// # use candle_hub::{api::ApiBuilder, Repo};
|
||||
/// # tokio_test::block_on(async {
|
||||
/// let api = ApiBuilder::new().build().unwrap();
|
||||
/// let repo = Repo::model("gpt2".to_string());
|
||||
/// let local_filename = api.get(&repo, "model.safetensors").await.unwrap();
|
||||
/// # })
|
||||
pub async fn get(&self, repo: &Repo, filename: &str) -> Result<PathBuf, ApiError> {
|
||||
if let Some(path) = self.cache.get(repo, filename) {
|
||||
Ok(path)
|
||||
} else {
|
||||
self.download(repo, filename).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Downloads a remote file (if not already present) into the cache directory
|
||||
/// to be used locally.
|
||||
/// This function requires internet access to verify if new versions of the file
|
||||
/// exist, even if a file is already on disk at location.
|
||||
/// ```no_run
|
||||
/// # use hf_hub::{api::ApiBuilder, Repo};
|
||||
/// # use candle_hub::{api::ApiBuilder, Repo};
|
||||
/// # tokio_test::block_on(async {
|
||||
/// let api = ApiBuilder::new().build().unwrap();
|
||||
/// let repo = Repo::model("gpt2".to_string());
|
||||
@ -475,13 +484,8 @@ impl Api {
|
||||
let url = self.url(repo, filename);
|
||||
let metadata = self.metadata(&url).await?;
|
||||
|
||||
let mut folder = self.cache_dir.clone();
|
||||
folder.push(repo.folder_name());
|
||||
|
||||
let mut blob_path = folder.clone();
|
||||
blob_path.push("blobs");
|
||||
std::fs::create_dir_all(&blob_path).ok();
|
||||
blob_path.push(&metadata.etag);
|
||||
let blob_path = self.cache.blob_path(repo, &metadata.etag);
|
||||
std::fs::create_dir_all(blob_path.parent().unwrap())?;
|
||||
|
||||
let file1 = std::fs::OpenOptions::new()
|
||||
.read(true)
|
||||
@ -515,20 +519,19 @@ impl Api {
|
||||
.await?;
|
||||
std::fs::copy(tmp_filename, &blob_path)?;
|
||||
|
||||
let mut pointer_path = folder.clone();
|
||||
pointer_path.push("snapshots");
|
||||
pointer_path.push(&metadata.commit_hash);
|
||||
let mut pointer_path = self.cache.pointer_path(repo, &metadata.commit_hash);
|
||||
pointer_path.push(filename);
|
||||
std::fs::create_dir_all(pointer_path.parent().unwrap()).ok();
|
||||
|
||||
symlink_or_rename(&blob_path, &pointer_path)?;
|
||||
self.cache.create_ref(repo, &metadata.commit_hash)?;
|
||||
|
||||
Ok(pointer_path)
|
||||
}
|
||||
|
||||
/// Get information about the Repo
|
||||
/// ```
|
||||
/// # use hf_hub::{api::Api, Repo};
|
||||
/// # use candle_hub::{api::Api, Repo};
|
||||
/// # tokio_test::block_on(async {
|
||||
/// let api = Api::new().unwrap();
|
||||
/// let repo = Repo::model("gpt2".to_string());
|
||||
@ -582,7 +585,7 @@ mod tests {
|
||||
let tmp = TempDir::new();
|
||||
let api = ApiBuilder::new()
|
||||
.with_progress(false)
|
||||
.with_cache_dir(&tmp.path)
|
||||
.with_cache_dir(tmp.path.clone())
|
||||
.build()
|
||||
.unwrap();
|
||||
let repo = Repo::new("julien-c/dummy-unknown".to_string(), RepoType::Model);
|
||||
@ -592,7 +595,11 @@ mod tests {
|
||||
assert_eq!(
|
||||
val,
|
||||
"b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32"
|
||||
)
|
||||
);
|
||||
|
||||
// Make sure the file is now seeable without connection
|
||||
let cache_path = api.cache.get(&repo, "config.json").unwrap();
|
||||
assert_eq!(cache_path, downloaded_path);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@ -600,7 +607,7 @@ mod tests {
|
||||
let tmp = TempDir::new();
|
||||
let api = ApiBuilder::new()
|
||||
.with_progress(false)
|
||||
.with_cache_dir(&tmp.path)
|
||||
.with_cache_dir(tmp.path.clone())
|
||||
.build()
|
||||
.unwrap();
|
||||
let repo = Repo::with_revision(
|
||||
@ -625,7 +632,7 @@ mod tests {
|
||||
let tmp = TempDir::new();
|
||||
let api = ApiBuilder::new()
|
||||
.with_progress(false)
|
||||
.with_cache_dir(&tmp.path)
|
||||
.with_cache_dir(tmp.path.clone())
|
||||
.build()
|
||||
.unwrap();
|
||||
let repo = Repo::with_revision(
|
||||
|
@ -8,6 +8,8 @@
|
||||
//!
|
||||
//! At this time only a limited subset of the functionality is present, the goal is to add new
|
||||
//! features over time
|
||||
use std::io::Write;
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// The actual Api to interact with the hub.
|
||||
#[cfg(feature = "online")]
|
||||
@ -30,6 +32,84 @@ pub enum RepoType {
|
||||
Space,
|
||||
}
|
||||
|
||||
/// A local struct used to fetch information from the cache folder.
///
/// It only stores the root directory of the cache, so cloning it is cheap
/// and it can be printed for diagnostics.
#[derive(Debug, Clone)]
pub struct Cache {
    /// Root directory of the cache.
    path: PathBuf,
}
|
||||
|
||||
impl Cache {
|
||||
/// Creates a new cache object location
|
||||
pub fn new(path: PathBuf) -> Self {
|
||||
Self { path }
|
||||
}
|
||||
|
||||
/// Creates a new cache object location
|
||||
pub fn path(&self) -> &PathBuf {
|
||||
&self.path
|
||||
}
|
||||
|
||||
/// This will get the location of the file within the cache for the remote
|
||||
/// `filename`. Will return `None` if file is not already present in cache.
|
||||
pub fn get(&self, repo: &Repo, filename: &str) -> Option<PathBuf> {
|
||||
let mut commit_path = self.path.clone();
|
||||
commit_path.push(repo.folder_name());
|
||||
commit_path.push("refs");
|
||||
commit_path.push(repo.revision());
|
||||
let commit_hash = std::fs::read_to_string(commit_path).ok()?;
|
||||
let mut pointer_path = self.pointer_path(repo, &commit_hash);
|
||||
pointer_path.push(filename);
|
||||
Some(pointer_path)
|
||||
}
|
||||
|
||||
/// Creates a reference in the cache directory that points branches to the correct
|
||||
/// commits within the blobs.
|
||||
pub fn create_ref(&self, repo: &Repo, commit_hash: &str) -> Result<(), std::io::Error> {
|
||||
let mut ref_path = self.path.clone();
|
||||
ref_path.push(repo.folder_name());
|
||||
ref_path.push("refs");
|
||||
ref_path.push(repo.revision());
|
||||
// Needs to be done like this because revision might contain `/` creating subfolders here.
|
||||
std::fs::create_dir_all(ref_path.parent().unwrap())?;
|
||||
let mut file1 = std::fs::OpenOptions::new()
|
||||
.write(true)
|
||||
.create(true)
|
||||
.open(&ref_path)?;
|
||||
file1.write_all(commit_hash.trim().as_bytes())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn blob_path(&self, repo: &Repo, etag: &str) -> PathBuf {
|
||||
let mut blob_path = self.path.clone();
|
||||
blob_path.push(repo.folder_name());
|
||||
blob_path.push("blobs");
|
||||
blob_path.push(etag);
|
||||
blob_path
|
||||
}
|
||||
|
||||
pub(crate) fn pointer_path(&self, repo: &Repo, commit_hash: &str) -> PathBuf {
|
||||
let mut pointer_path = self.path.clone();
|
||||
pointer_path.push(repo.folder_name());
|
||||
pointer_path.push("snapshots");
|
||||
pointer_path.push(commit_hash);
|
||||
pointer_path
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Cache {
|
||||
fn default() -> Self {
|
||||
let path = match std::env::var("HF_HOME") {
|
||||
Ok(home) => home.into(),
|
||||
Err(_) => {
|
||||
let mut cache = dirs::home_dir().expect("Cache directory cannot be found");
|
||||
cache.push(".cache");
|
||||
cache.push("huggingface");
|
||||
cache
|
||||
}
|
||||
};
|
||||
Self::new(path)
|
||||
}
|
||||
}
|
||||
|
||||
/// The representation of a repo on the hub.
|
||||
pub struct Repo {
|
||||
repo_id: String,
|
||||
@ -69,15 +149,12 @@ impl Repo {
|
||||
|
||||
/// The normalized folder name of the repo within the cache directory
|
||||
pub fn folder_name(&self) -> String {
|
||||
match self.repo_type {
|
||||
RepoType::Model => self.repo_id.replace('/', "--"),
|
||||
RepoType::Dataset => {
|
||||
format!("datasets/{}", self.repo_id.replace('/', "--"))
|
||||
}
|
||||
RepoType::Space => {
|
||||
format!("spaces/{}", self.repo_id.replace('/', "--"))
|
||||
}
|
||||
}
|
||||
self.repo_id.replace('/', "--")
|
||||
}
|
||||
|
||||
/// Returns the revision stored on this repo, borrowed as a string slice.
|
||||
pub fn revision(&self) -> &str {
|
||||
&self.revision
|
||||
}
|
||||
|
||||
/// The actual URL part of the repo
|
||||
|
Reference in New Issue
Block a user