mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
members = [
|
members = [
|
||||||
"candle-core",
|
"candle-core",
|
||||||
|
"candle-hub",
|
||||||
"candle-kernels",
|
"candle-kernels",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
2
candle-hub/.gitignore
vendored
Normal file
2
candle-hub/.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
/target
|
||||||
|
/Cargo.lock
|
30
candle-hub/Cargo.toml
Normal file
30
candle-hub/Cargo.toml
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
[package]
|
||||||
|
name = "candle-hub"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
dirs = "5.0.1"
|
||||||
|
fs2 = "0.4.3"
|
||||||
|
rand = "0.8.5"
|
||||||
|
thiserror = "1.0.40"
|
||||||
|
futures = { version = "0.3.28", optional = true }
|
||||||
|
reqwest = { version = "0.11.18", optional = true, features = ["json"] }
|
||||||
|
tokio = { version = "1.28.2", features = ["fs"], optional = true }
|
||||||
|
serde = { version = "1.0.164", features = ["derive"], optional = true }
|
||||||
|
serde_json = { version = "1.0.97", optional = true }
|
||||||
|
indicatif = { version = "0.17.5", optional = true }
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
rand = "0.8.5"
|
||||||
|
sha256 = "1.1.4"
|
||||||
|
tokio = { version = "1.28.2", features = ["macros"] }
|
||||||
|
tokio-test = "0.4.2"
|
||||||
|
|
||||||
|
[features]
|
||||||
|
default = ["online"]
|
||||||
|
online = ["reqwest", "tokio", "futures", "serde", "serde_json", "indicatif"]
|
||||||
|
|
||||||
|
|
706
candle-hub/src/api.rs
Normal file
706
candle-hub/src/api.rs
Normal file
@ -0,0 +1,706 @@
|
|||||||
|
use crate::{Cache, Repo};
|
||||||
|
use fs2::FileExt;
|
||||||
|
use indicatif::{ProgressBar, ProgressStyle};
|
||||||
|
use rand::{distributions::Alphanumeric, thread_rng, Rng};
|
||||||
|
use reqwest::{
|
||||||
|
header::{
|
||||||
|
HeaderMap, HeaderName, HeaderValue, InvalidHeaderValue, ToStrError, AUTHORIZATION,
|
||||||
|
CONTENT_RANGE, LOCATION, RANGE, USER_AGENT,
|
||||||
|
},
|
||||||
|
redirect::Policy,
|
||||||
|
Client, Error as ReqwestError,
|
||||||
|
};
|
||||||
|
use serde::Deserialize;
|
||||||
|
use std::num::ParseIntError;
|
||||||
|
use std::path::{Component, Path, PathBuf};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use thiserror::Error;
|
||||||
|
use tokio::io::{AsyncSeekExt, AsyncWriteExt, SeekFrom};
|
||||||
|
use tokio::sync::{AcquireError, Semaphore, TryAcquireError};
|
||||||
|
|
||||||
|
/// Current version (used in user-agent)
|
||||||
|
const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||||
|
/// Current name (used in user-agent)
|
||||||
|
const NAME: &str = env!("CARGO_PKG_NAME");
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
/// All errors the API can throw
|
||||||
|
pub enum ApiError {
|
||||||
|
/// Api expects certain header to be present in the results to derive some information
|
||||||
|
#[error("Header {0} is missing")]
|
||||||
|
MissingHeader(HeaderName),
|
||||||
|
|
||||||
|
/// The header exists, but the value is not conform to what the Api expects.
|
||||||
|
#[error("Header {0} is invalid")]
|
||||||
|
InvalidHeader(HeaderName),
|
||||||
|
|
||||||
|
/// The value cannot be used as a header during request header construction
|
||||||
|
#[error("Invalid header value {0}")]
|
||||||
|
InvalidHeaderValue(#[from] InvalidHeaderValue),
|
||||||
|
|
||||||
|
/// The header value is not valid utf-8
|
||||||
|
#[error("header value is not a string")]
|
||||||
|
ToStr(#[from] ToStrError),
|
||||||
|
|
||||||
|
/// Error in the request
|
||||||
|
#[error("request error: {0}")]
|
||||||
|
RequestError(#[from] ReqwestError),
|
||||||
|
|
||||||
|
/// Error parsing some range value
|
||||||
|
#[error("Cannot parse int")]
|
||||||
|
ParseIntError(#[from] ParseIntError),
|
||||||
|
|
||||||
|
/// I/O Error
|
||||||
|
#[error("I/O error {0}")]
|
||||||
|
IoError(#[from] std::io::Error),
|
||||||
|
|
||||||
|
/// We tried to download chunk too many times
|
||||||
|
#[error("Too many retries: {0}")]
|
||||||
|
TooManyRetries(Box<ApiError>),
|
||||||
|
|
||||||
|
/// Semaphore cannot be acquired
|
||||||
|
#[error("Try acquire: {0}")]
|
||||||
|
TryAcquireError(#[from] TryAcquireError),
|
||||||
|
|
||||||
|
/// Semaphore cannot be acquired
|
||||||
|
#[error("Acquire: {0}")]
|
||||||
|
AcquireError(#[from] AcquireError),
|
||||||
|
// /// Semaphore cannot be acquired
|
||||||
|
// #[error("Invalid Response: {0:?}")]
|
||||||
|
// InvalidResponse(Response),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Siblings are simplified file descriptions of remote files on the hub
|
||||||
|
#[derive(Debug, Clone, Deserialize, PartialEq)]
|
||||||
|
pub struct Siblings {
|
||||||
|
/// The path within the repo.
|
||||||
|
pub rfilename: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The description of the repo given by the hub
|
||||||
|
#[derive(Debug, Clone, Deserialize, PartialEq)]
|
||||||
|
pub struct ModelInfo {
|
||||||
|
/// See [`Siblings`]
|
||||||
|
pub siblings: Vec<Siblings>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper to create [`Api`] with all the options.
|
||||||
|
pub struct ApiBuilder {
|
||||||
|
endpoint: String,
|
||||||
|
cache: Cache,
|
||||||
|
url_template: String,
|
||||||
|
token: Option<String>,
|
||||||
|
max_files: usize,
|
||||||
|
chunk_size: usize,
|
||||||
|
parallel_failures: usize,
|
||||||
|
max_retries: usize,
|
||||||
|
progress: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ApiBuilder {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ApiBuilder {
|
||||||
|
/// Default api builder
|
||||||
|
/// ```
|
||||||
|
/// use candle_hub::api::ApiBuilder;
|
||||||
|
/// let api = ApiBuilder::new().build().unwrap();
|
||||||
|
/// ```
|
||||||
|
pub fn new() -> Self {
|
||||||
|
let cache = Cache::default();
|
||||||
|
let mut token_filename = cache.path().clone();
|
||||||
|
token_filename.push(".token");
|
||||||
|
let token = match std::fs::read_to_string(token_filename) {
|
||||||
|
Ok(token_content) => {
|
||||||
|
let token_content = token_content.trim();
|
||||||
|
if !token_content.is_empty() {
|
||||||
|
Some(token_content.to_string())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(_) => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let progress = true;
|
||||||
|
|
||||||
|
Self {
|
||||||
|
endpoint: "https://huggingface.co".to_string(),
|
||||||
|
url_template: "{endpoint}/{repo_id}/resolve/{revision}/{filename}".to_string(),
|
||||||
|
cache,
|
||||||
|
token,
|
||||||
|
max_files: 100,
|
||||||
|
chunk_size: 10_000_000,
|
||||||
|
parallel_failures: 0,
|
||||||
|
max_retries: 0,
|
||||||
|
progress,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Wether to show a progressbar
|
||||||
|
pub fn with_progress(mut self, progress: bool) -> Self {
|
||||||
|
self.progress = progress;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Changes the location of the cache directory. Defaults is `~/.cache/huggingface/`.
|
||||||
|
pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self {
|
||||||
|
self.cache = Cache::new(cache_dir);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_headers(&self) -> Result<HeaderMap, ApiError> {
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
let user_agent = format!("unkown/None; {NAME}/{VERSION}; rust/unknown");
|
||||||
|
headers.insert(USER_AGENT, HeaderValue::from_str(&user_agent)?);
|
||||||
|
if let Some(token) = &self.token {
|
||||||
|
headers.insert(
|
||||||
|
AUTHORIZATION,
|
||||||
|
HeaderValue::from_str(&format!("Bearer {token}"))?,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Ok(headers)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Consumes the builder and buids the final [`Api`]
|
||||||
|
pub fn build(self) -> Result<Api, ApiError> {
|
||||||
|
let headers = self.build_headers()?;
|
||||||
|
let client = Client::builder().default_headers(headers.clone()).build()?;
|
||||||
|
let no_redirect_client = Client::builder()
|
||||||
|
.redirect(Policy::none())
|
||||||
|
.default_headers(headers)
|
||||||
|
.build()?;
|
||||||
|
Ok(Api {
|
||||||
|
endpoint: self.endpoint,
|
||||||
|
url_template: self.url_template,
|
||||||
|
cache: self.cache,
|
||||||
|
client,
|
||||||
|
|
||||||
|
no_redirect_client,
|
||||||
|
max_files: self.max_files,
|
||||||
|
chunk_size: self.chunk_size,
|
||||||
|
parallel_failures: self.parallel_failures,
|
||||||
|
max_retries: self.max_retries,
|
||||||
|
progress: self.progress,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct Metadata {
|
||||||
|
commit_hash: String,
|
||||||
|
etag: String,
|
||||||
|
size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The actual Api used to interacto with the hub.
|
||||||
|
/// You can inspect repos with [`Api::info`]
|
||||||
|
/// or download files with [`Api::download`]
|
||||||
|
pub struct Api {
|
||||||
|
endpoint: String,
|
||||||
|
url_template: String,
|
||||||
|
cache: Cache,
|
||||||
|
client: Client,
|
||||||
|
no_redirect_client: Client,
|
||||||
|
max_files: usize,
|
||||||
|
chunk_size: usize,
|
||||||
|
parallel_failures: usize,
|
||||||
|
max_retries: usize,
|
||||||
|
progress: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn temp_filename() -> PathBuf {
|
||||||
|
let s: String = rand::thread_rng()
|
||||||
|
.sample_iter(&Alphanumeric)
|
||||||
|
.take(7)
|
||||||
|
.map(char::from)
|
||||||
|
.collect();
|
||||||
|
let mut path = std::env::temp_dir();
|
||||||
|
path.push(s);
|
||||||
|
path
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_relative(src: &Path, dst: &Path) -> PathBuf {
|
||||||
|
let path = src;
|
||||||
|
let base = dst;
|
||||||
|
|
||||||
|
if path.is_absolute() != base.is_absolute() {
|
||||||
|
panic!("This function is made to look at absolute paths only");
|
||||||
|
}
|
||||||
|
let mut ita = path.components();
|
||||||
|
let mut itb = base.components();
|
||||||
|
|
||||||
|
loop {
|
||||||
|
match (ita.next(), itb.next()) {
|
||||||
|
(Some(a), Some(b)) if a == b => (),
|
||||||
|
(some_a, _) => {
|
||||||
|
// Ignoring b, because 1 component is the filename
|
||||||
|
// for which we don't need to go back up for relative
|
||||||
|
// filename to work.
|
||||||
|
let mut new_path = PathBuf::new();
|
||||||
|
for _ in itb {
|
||||||
|
new_path.push(Component::ParentDir);
|
||||||
|
}
|
||||||
|
if let Some(a) = some_a {
|
||||||
|
new_path.push(a);
|
||||||
|
for comp in ita {
|
||||||
|
new_path.push(comp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return new_path;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn symlink_or_rename(src: &Path, dst: &Path) -> Result<(), std::io::Error> {
|
||||||
|
if dst.exists() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let src = make_relative(src, dst);
|
||||||
|
#[cfg(target_os = "windows")]
|
||||||
|
std::os::windows::fs::symlink_file(src, dst)?;
|
||||||
|
|
||||||
|
#[cfg(target_family = "unix")]
|
||||||
|
std::os::unix::fs::symlink(src, dst)?;
|
||||||
|
|
||||||
|
#[cfg(not(any(target_family = "unix", target_os = "windows")))]
|
||||||
|
std::fs::rename(src, dst)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn jitter() -> usize {
|
||||||
|
thread_rng().gen_range(0..=500)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn exponential_backoff(base_wait_time: usize, n: usize, max: usize) -> usize {
|
||||||
|
(base_wait_time + n.pow(2) + jitter()).min(max)
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Api {
|
||||||
|
/// Creates a default Api, for Api options See [`ApiBuilder`]
|
||||||
|
pub fn new() -> Result<Self, ApiError> {
|
||||||
|
ApiBuilder::new().build()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the fully qualified URL of the remote filename
|
||||||
|
/// ```
|
||||||
|
/// # use candle_hub::{api::Api, Repo};
|
||||||
|
/// let api = Api::new().unwrap();
|
||||||
|
/// let repo = Repo::model("gpt2".to_string());
|
||||||
|
/// let url = api.url(&repo, "model.safetensors");
|
||||||
|
/// assert_eq!(url, "https://huggingface.co/gpt2/resolve/main/model.safetensors");
|
||||||
|
/// ```
|
||||||
|
pub fn url(&self, repo: &Repo, filename: &str) -> String {
|
||||||
|
let endpoint = &self.endpoint;
|
||||||
|
let revision = &repo.url_revision();
|
||||||
|
self.url_template
|
||||||
|
.replace("{endpoint}", endpoint)
|
||||||
|
.replace("{repo_id}", &repo.url())
|
||||||
|
.replace("{revision}", revision)
|
||||||
|
.replace("{filename}", filename)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the underlying api client
|
||||||
|
/// Allows for lower level access
|
||||||
|
pub fn client(&self) -> &Client {
|
||||||
|
&self.client
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn metadata(&self, url: &str) -> Result<Metadata, ApiError> {
|
||||||
|
let response = self
|
||||||
|
.no_redirect_client
|
||||||
|
.get(url)
|
||||||
|
.header(RANGE, "bytes=0-0")
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
let response = response.error_for_status()?;
|
||||||
|
let headers = response.headers();
|
||||||
|
let header_commit = HeaderName::from_static("x-repo-commit");
|
||||||
|
let header_linked_etag = HeaderName::from_static("x-linked-etag");
|
||||||
|
let header_etag = HeaderName::from_static("etag");
|
||||||
|
|
||||||
|
let etag = match headers.get(&header_linked_etag) {
|
||||||
|
Some(etag) => etag,
|
||||||
|
None => headers
|
||||||
|
.get(&header_etag)
|
||||||
|
.ok_or(ApiError::MissingHeader(header_etag))?,
|
||||||
|
};
|
||||||
|
// Cleaning extra quotes
|
||||||
|
let etag = etag.to_str()?.to_string().replace('"', "");
|
||||||
|
let commit_hash = headers
|
||||||
|
.get(&header_commit)
|
||||||
|
.ok_or(ApiError::MissingHeader(header_commit))?
|
||||||
|
.to_str()?
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
// The response was redirected o S3 most likely which will
|
||||||
|
// know about the size of the file
|
||||||
|
let response = if response.status().is_redirection() {
|
||||||
|
self.client
|
||||||
|
.get(headers.get(LOCATION).unwrap().to_str()?.to_string())
|
||||||
|
.header(RANGE, "bytes=0-0")
|
||||||
|
.send()
|
||||||
|
.await?
|
||||||
|
} else {
|
||||||
|
response
|
||||||
|
};
|
||||||
|
let headers = response.headers();
|
||||||
|
let content_range = headers
|
||||||
|
.get(CONTENT_RANGE)
|
||||||
|
.ok_or(ApiError::MissingHeader(CONTENT_RANGE))?
|
||||||
|
.to_str()?;
|
||||||
|
|
||||||
|
let size = content_range
|
||||||
|
.split('/')
|
||||||
|
.last()
|
||||||
|
.ok_or(ApiError::InvalidHeader(CONTENT_RANGE))?
|
||||||
|
.parse()?;
|
||||||
|
Ok(Metadata {
|
||||||
|
commit_hash,
|
||||||
|
etag,
|
||||||
|
size,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn download_tempfile(
|
||||||
|
&self,
|
||||||
|
url: &str,
|
||||||
|
length: usize,
|
||||||
|
progressbar: Option<ProgressBar>,
|
||||||
|
) -> Result<PathBuf, ApiError> {
|
||||||
|
let mut handles = vec![];
|
||||||
|
let semaphore = Arc::new(Semaphore::new(self.max_files));
|
||||||
|
let parallel_failures_semaphore = Arc::new(Semaphore::new(self.parallel_failures));
|
||||||
|
let filename = temp_filename();
|
||||||
|
|
||||||
|
let chunk_size = self.chunk_size;
|
||||||
|
for start in (0..length).step_by(chunk_size) {
|
||||||
|
let url = url.to_string();
|
||||||
|
let filename = filename.clone();
|
||||||
|
let client = self.client.clone();
|
||||||
|
|
||||||
|
let stop = std::cmp::min(start + chunk_size - 1, length);
|
||||||
|
let permit = semaphore.clone().acquire_owned().await?;
|
||||||
|
let parallel_failures = self.parallel_failures;
|
||||||
|
let max_retries = self.max_retries;
|
||||||
|
let parallel_failures_semaphore = parallel_failures_semaphore.clone();
|
||||||
|
let progress = progressbar.clone();
|
||||||
|
handles.push(tokio::spawn(async move {
|
||||||
|
let mut chunk = Self::download_chunk(&client, &url, &filename, start, stop).await;
|
||||||
|
let mut i = 0;
|
||||||
|
if parallel_failures > 0 {
|
||||||
|
while let Err(dlerr) = chunk {
|
||||||
|
let parallel_failure_permit =
|
||||||
|
parallel_failures_semaphore.clone().try_acquire_owned()?;
|
||||||
|
|
||||||
|
let wait_time = exponential_backoff(300, i, 10_000);
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(wait_time as u64))
|
||||||
|
.await;
|
||||||
|
|
||||||
|
chunk = Self::download_chunk(&client, &url, &filename, start, stop).await;
|
||||||
|
i += 1;
|
||||||
|
if i > max_retries {
|
||||||
|
return Err(ApiError::TooManyRetries(dlerr.into()));
|
||||||
|
}
|
||||||
|
drop(parallel_failure_permit);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
drop(permit);
|
||||||
|
if let Some(p) = progress {
|
||||||
|
p.inc((stop - start) as u64);
|
||||||
|
}
|
||||||
|
chunk
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Output the chained result
|
||||||
|
let results: Vec<Result<Result<(), ApiError>, tokio::task::JoinError>> =
|
||||||
|
futures::future::join_all(handles).await;
|
||||||
|
let results: Result<(), ApiError> = results.into_iter().flatten().collect();
|
||||||
|
results?;
|
||||||
|
if let Some(p) = progressbar {
|
||||||
|
p.finish()
|
||||||
|
}
|
||||||
|
Ok(filename)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn download_chunk(
|
||||||
|
client: &reqwest::Client,
|
||||||
|
url: &str,
|
||||||
|
filename: &PathBuf,
|
||||||
|
start: usize,
|
||||||
|
stop: usize,
|
||||||
|
) -> Result<(), ApiError> {
|
||||||
|
// Process each socket concurrently.
|
||||||
|
let range = format!("bytes={start}-{stop}");
|
||||||
|
let mut file = tokio::fs::OpenOptions::new()
|
||||||
|
.write(true)
|
||||||
|
.create(true)
|
||||||
|
.open(filename)
|
||||||
|
.await?;
|
||||||
|
file.seek(SeekFrom::Start(start as u64)).await?;
|
||||||
|
let response = client
|
||||||
|
.get(url)
|
||||||
|
.header(RANGE, range)
|
||||||
|
.send()
|
||||||
|
.await?
|
||||||
|
.error_for_status()?;
|
||||||
|
let content = response.bytes().await?;
|
||||||
|
file.write_all(&content).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// This will attempt the fetch the file locally first, then [`Api.download`]
|
||||||
|
/// if the file is not present.
|
||||||
|
/// ```no_run
|
||||||
|
/// # use candle_hub::{api::ApiBuilder, Repo};
|
||||||
|
/// # tokio_test::block_on(async {
|
||||||
|
/// let api = ApiBuilder::new().build().unwrap();
|
||||||
|
/// let repo = Repo::model("gpt2".to_string());
|
||||||
|
/// let local_filename = api.get(&repo, "model.safetensors").await.unwrap();
|
||||||
|
/// # })
|
||||||
|
pub async fn get(&self, repo: &Repo, filename: &str) -> Result<PathBuf, ApiError> {
|
||||||
|
if let Some(path) = self.cache.get(repo, filename) {
|
||||||
|
Ok(path)
|
||||||
|
} else {
|
||||||
|
self.download(repo, filename).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Downloads a remote file (if not already present) into the cache directory
|
||||||
|
/// to be used locally.
|
||||||
|
/// This functions require internet access to verify if new versions of the file
|
||||||
|
/// exist, even if a file is already on disk at location.
|
||||||
|
/// ```no_run
|
||||||
|
/// # use candle_hub::{api::ApiBuilder, Repo};
|
||||||
|
/// # tokio_test::block_on(async {
|
||||||
|
/// let api = ApiBuilder::new().build().unwrap();
|
||||||
|
/// let repo = Repo::model("gpt2".to_string());
|
||||||
|
/// let local_filename = api.download(&repo, "model.safetensors").await.unwrap();
|
||||||
|
/// # })
|
||||||
|
/// ```
|
||||||
|
pub async fn download(&self, repo: &Repo, filename: &str) -> Result<PathBuf, ApiError> {
|
||||||
|
let url = self.url(repo, filename);
|
||||||
|
let metadata = self.metadata(&url).await?;
|
||||||
|
|
||||||
|
let blob_path = self.cache.blob_path(repo, &metadata.etag);
|
||||||
|
std::fs::create_dir_all(blob_path.parent().unwrap())?;
|
||||||
|
|
||||||
|
let file1 = std::fs::OpenOptions::new()
|
||||||
|
.read(true)
|
||||||
|
.write(true)
|
||||||
|
.create(true)
|
||||||
|
.open(&blob_path)?;
|
||||||
|
file1.lock_exclusive()?;
|
||||||
|
|
||||||
|
let progressbar = if self.progress {
|
||||||
|
let progress = ProgressBar::new(metadata.size as u64);
|
||||||
|
progress.set_style(
|
||||||
|
ProgressStyle::with_template(
|
||||||
|
"{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec} ({eta})",
|
||||||
|
)
|
||||||
|
.unwrap(), // .progress_chars("━ "),
|
||||||
|
);
|
||||||
|
let maxlength = 30;
|
||||||
|
let message = if filename.len() > maxlength {
|
||||||
|
format!("..{}", &filename[filename.len() - maxlength..])
|
||||||
|
} else {
|
||||||
|
filename.to_string()
|
||||||
|
};
|
||||||
|
progress.set_message(message);
|
||||||
|
Some(progress)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
let tmp_filename = self
|
||||||
|
.download_tempfile(&url, metadata.size, progressbar)
|
||||||
|
.await?;
|
||||||
|
std::fs::copy(tmp_filename, &blob_path)?;
|
||||||
|
|
||||||
|
let mut pointer_path = self.cache.pointer_path(repo, &metadata.commit_hash);
|
||||||
|
pointer_path.push(filename);
|
||||||
|
std::fs::create_dir_all(pointer_path.parent().unwrap()).ok();
|
||||||
|
|
||||||
|
symlink_or_rename(&blob_path, &pointer_path)?;
|
||||||
|
self.cache.create_ref(repo, &metadata.commit_hash)?;
|
||||||
|
|
||||||
|
Ok(pointer_path)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get information about the Repo
|
||||||
|
/// ```
|
||||||
|
/// # use candle_hub::{api::Api, Repo};
|
||||||
|
/// # tokio_test::block_on(async {
|
||||||
|
/// let api = Api::new().unwrap();
|
||||||
|
/// let repo = Repo::model("gpt2".to_string());
|
||||||
|
/// api.info(&repo);
|
||||||
|
/// # })
|
||||||
|
/// ```
|
||||||
|
pub async fn info(&self, repo: &Repo) -> Result<ModelInfo, ApiError> {
|
||||||
|
let url = format!("{}/api/{}", self.endpoint, repo.api_url());
|
||||||
|
let response = self.client.get(url).send().await?;
|
||||||
|
let response = response.error_for_status()?;
|
||||||
|
|
||||||
|
let model_info = response.json().await?;
|
||||||
|
|
||||||
|
Ok(model_info)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::RepoType;
|
||||||
|
use rand::{distributions::Alphanumeric, Rng};
|
||||||
|
use sha256::try_digest;
|
||||||
|
|
||||||
|
struct TempDir {
|
||||||
|
path: PathBuf,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TempDir {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
let s: String = rand::thread_rng()
|
||||||
|
.sample_iter(&Alphanumeric)
|
||||||
|
.take(7)
|
||||||
|
.map(char::from)
|
||||||
|
.collect();
|
||||||
|
let mut path = std::env::temp_dir();
|
||||||
|
path.push(s);
|
||||||
|
std::fs::create_dir(&path).unwrap();
|
||||||
|
Self { path }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for TempDir {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
std::fs::remove_dir_all(&self.path).unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn simple() {
|
||||||
|
let tmp = TempDir::new();
|
||||||
|
let api = ApiBuilder::new()
|
||||||
|
.with_progress(false)
|
||||||
|
.with_cache_dir(tmp.path.clone())
|
||||||
|
.build()
|
||||||
|
.unwrap();
|
||||||
|
let repo = Repo::new("julien-c/dummy-unknown".to_string(), RepoType::Model);
|
||||||
|
let downloaded_path = api.download(&repo, "config.json").await.unwrap();
|
||||||
|
assert!(downloaded_path.exists());
|
||||||
|
let val = try_digest(&*downloaded_path).unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
val,
|
||||||
|
"b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Make sure the file is now seeable without connection
|
||||||
|
let cache_path = api.cache.get(&repo, "config.json").unwrap();
|
||||||
|
assert_eq!(cache_path, downloaded_path);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn dataset() {
|
||||||
|
let tmp = TempDir::new();
|
||||||
|
let api = ApiBuilder::new()
|
||||||
|
.with_progress(false)
|
||||||
|
.with_cache_dir(tmp.path.clone())
|
||||||
|
.build()
|
||||||
|
.unwrap();
|
||||||
|
let repo = Repo::with_revision(
|
||||||
|
"wikitext".to_string(),
|
||||||
|
RepoType::Dataset,
|
||||||
|
"refs/convert/parquet".to_string(),
|
||||||
|
);
|
||||||
|
let downloaded_path = api
|
||||||
|
.download(&repo, "wikitext-103-v1/wikitext-test.parquet")
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert!(downloaded_path.exists());
|
||||||
|
let val = try_digest(&*downloaded_path).unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
val,
|
||||||
|
"59ce09415ad8aa45a9e34f88cec2548aeb9de9a73fcda9f6b33a86a065f32b90"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn info() {
|
||||||
|
let tmp = TempDir::new();
|
||||||
|
let api = ApiBuilder::new()
|
||||||
|
.with_progress(false)
|
||||||
|
.with_cache_dir(tmp.path.clone())
|
||||||
|
.build()
|
||||||
|
.unwrap();
|
||||||
|
let repo = Repo::with_revision(
|
||||||
|
"wikitext".to_string(),
|
||||||
|
RepoType::Dataset,
|
||||||
|
"refs/convert/parquet".to_string(),
|
||||||
|
);
|
||||||
|
let model_info = api.info(&repo).await.unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
model_info,
|
||||||
|
ModelInfo {
|
||||||
|
siblings: vec![
|
||||||
|
Siblings {
|
||||||
|
rfilename: ".gitattributes".to_string(),
|
||||||
|
},
|
||||||
|
Siblings {
|
||||||
|
rfilename: "wikitext-103-raw-v1/wikitext-test.parquet".to_string(),
|
||||||
|
},
|
||||||
|
Siblings {
|
||||||
|
rfilename: "wikitext-103-raw-v1/wikitext-train-00000-of-00002.parquet"
|
||||||
|
.to_string(),
|
||||||
|
},
|
||||||
|
Siblings {
|
||||||
|
rfilename: "wikitext-103-raw-v1/wikitext-train-00001-of-00002.parquet"
|
||||||
|
.to_string(),
|
||||||
|
},
|
||||||
|
Siblings {
|
||||||
|
rfilename: "wikitext-103-raw-v1/wikitext-validation.parquet".to_string(),
|
||||||
|
},
|
||||||
|
Siblings {
|
||||||
|
rfilename: "wikitext-103-v1/wikitext-test.parquet".to_string(),
|
||||||
|
},
|
||||||
|
Siblings {
|
||||||
|
rfilename: "wikitext-103-v1/wikitext-train-00000-of-00002.parquet"
|
||||||
|
.to_string(),
|
||||||
|
},
|
||||||
|
Siblings {
|
||||||
|
rfilename: "wikitext-103-v1/wikitext-train-00001-of-00002.parquet"
|
||||||
|
.to_string(),
|
||||||
|
},
|
||||||
|
Siblings {
|
||||||
|
rfilename: "wikitext-103-v1/wikitext-validation.parquet".to_string(),
|
||||||
|
},
|
||||||
|
Siblings {
|
||||||
|
rfilename: "wikitext-2-raw-v1/wikitext-test.parquet".to_string(),
|
||||||
|
},
|
||||||
|
Siblings {
|
||||||
|
rfilename: "wikitext-2-raw-v1/wikitext-train.parquet".to_string(),
|
||||||
|
},
|
||||||
|
Siblings {
|
||||||
|
rfilename: "wikitext-2-raw-v1/wikitext-validation.parquet".to_string(),
|
||||||
|
},
|
||||||
|
Siblings {
|
||||||
|
rfilename: "wikitext-2-v1/wikitext-test.parquet".to_string(),
|
||||||
|
},
|
||||||
|
Siblings {
|
||||||
|
rfilename: "wikitext-2-v1/wikitext-train.parquet".to_string(),
|
||||||
|
},
|
||||||
|
Siblings {
|
||||||
|
rfilename: "wikitext-2-v1/wikitext-validation.parquet".to_string(),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
187
candle-hub/src/lib.rs
Normal file
187
candle-hub/src/lib.rs
Normal file
@ -0,0 +1,187 @@
|
|||||||
|
#![deny(missing_docs)]
|
||||||
|
//! This crates aims to emulate and be compatible with the
|
||||||
|
//! [huggingface_hub](https://github.com/huggingface/huggingface_hub/) python package.
|
||||||
|
//!
|
||||||
|
//! compatible means the Api should reuse the same files skipping downloads if
|
||||||
|
//! they are already present and whenever this crate downloads or modifies this cache
|
||||||
|
//! it should be consistent with [huggingface_hub](https://github.com/huggingface/huggingface_hub/)
|
||||||
|
//!
|
||||||
|
//! At this time only a limited subset of the functionality is present, the goal is to add new
|
||||||
|
//! features over time
|
||||||
|
use std::io::Write;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
/// The actual Api to interact with the hub.
|
||||||
|
#[cfg(feature = "online")]
|
||||||
|
pub mod api;
|
||||||
|
|
||||||
|
/// The type of repo to interact with
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub enum RepoType {
|
||||||
|
/// This is a model, usually it consists of weight files and some configuration
|
||||||
|
/// files
|
||||||
|
Model,
|
||||||
|
/// This is a dataset, usually contains data within parquet files
|
||||||
|
Dataset,
|
||||||
|
/// This is a space, usually a demo showcashing a given model or dataset
|
||||||
|
Space,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A local struct used to fetch information from the cache folder.
|
||||||
|
pub struct Cache {
|
||||||
|
path: PathBuf,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Cache {
|
||||||
|
/// Creates a new cache object location
|
||||||
|
pub fn new(path: PathBuf) -> Self {
|
||||||
|
Self { path }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a new cache object location
|
||||||
|
pub fn path(&self) -> &PathBuf {
|
||||||
|
&self.path
|
||||||
|
}
|
||||||
|
|
||||||
|
/// This will get the location of the file within the cache for the remote
|
||||||
|
/// `filename`. Will return `None` if file is not already present in cache.
|
||||||
|
pub fn get(&self, repo: &Repo, filename: &str) -> Option<PathBuf> {
|
||||||
|
let mut commit_path = self.path.clone();
|
||||||
|
commit_path.push(repo.folder_name());
|
||||||
|
commit_path.push("refs");
|
||||||
|
commit_path.push(repo.revision());
|
||||||
|
let commit_hash = std::fs::read_to_string(commit_path).ok()?;
|
||||||
|
let mut pointer_path = self.pointer_path(repo, &commit_hash);
|
||||||
|
pointer_path.push(filename);
|
||||||
|
Some(pointer_path)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a reference in the cache directory that points branches to the correct
|
||||||
|
/// commits within the blobs.
|
||||||
|
pub fn create_ref(&self, repo: &Repo, commit_hash: &str) -> Result<(), std::io::Error> {
|
||||||
|
let mut ref_path = self.path.clone();
|
||||||
|
ref_path.push(repo.folder_name());
|
||||||
|
ref_path.push("refs");
|
||||||
|
ref_path.push(repo.revision());
|
||||||
|
// Needs to be done like this because revision might contain `/` creating subfolders here.
|
||||||
|
std::fs::create_dir_all(ref_path.parent().unwrap())?;
|
||||||
|
let mut file1 = std::fs::OpenOptions::new()
|
||||||
|
.write(true)
|
||||||
|
.create(true)
|
||||||
|
.open(&ref_path)?;
|
||||||
|
file1.write_all(commit_hash.trim().as_bytes())?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "online")]
|
||||||
|
pub(crate) fn blob_path(&self, repo: &Repo, etag: &str) -> PathBuf {
|
||||||
|
let mut blob_path = self.path.clone();
|
||||||
|
blob_path.push(repo.folder_name());
|
||||||
|
blob_path.push("blobs");
|
||||||
|
blob_path.push(etag);
|
||||||
|
blob_path
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn pointer_path(&self, repo: &Repo, commit_hash: &str) -> PathBuf {
|
||||||
|
let mut pointer_path = self.path.clone();
|
||||||
|
pointer_path.push(repo.folder_name());
|
||||||
|
pointer_path.push("snapshots");
|
||||||
|
pointer_path.push(commit_hash);
|
||||||
|
pointer_path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for Cache {
|
||||||
|
fn default() -> Self {
|
||||||
|
let path = match std::env::var("HF_HOME") {
|
||||||
|
Ok(home) => home.into(),
|
||||||
|
Err(_) => {
|
||||||
|
let mut cache = dirs::home_dir().expect("Cache directory cannot be found");
|
||||||
|
cache.push(".cache");
|
||||||
|
cache.push("huggingface");
|
||||||
|
cache
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Self::new(path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The representation of a repo on the hub.
|
||||||
|
#[allow(dead_code)] // Repo type unused in offline mode
|
||||||
|
pub struct Repo {
|
||||||
|
repo_id: String,
|
||||||
|
repo_type: RepoType,
|
||||||
|
revision: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Repo {
|
||||||
|
/// Repo with the default branch ("main").
|
||||||
|
pub fn new(repo_id: String, repo_type: RepoType) -> Self {
|
||||||
|
Self::with_revision(repo_id, repo_type, "main".to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// fully qualified Repo
|
||||||
|
pub fn with_revision(repo_id: String, repo_type: RepoType, revision: String) -> Self {
|
||||||
|
Self {
|
||||||
|
repo_id,
|
||||||
|
repo_type,
|
||||||
|
revision,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Shortcut for [`Repo::new`] with [`RepoType::Model`]
|
||||||
|
pub fn model(repo_id: String) -> Self {
|
||||||
|
Self::new(repo_id, RepoType::Model)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Shortcut for [`Repo::new`] with [`RepoType::Dataset`]
|
||||||
|
pub fn dataset(repo_id: String) -> Self {
|
||||||
|
Self::new(repo_id, RepoType::Dataset)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Shortcut for [`Repo::new`] with [`RepoType::Space`]
|
||||||
|
pub fn space(repo_id: String) -> Self {
|
||||||
|
Self::new(repo_id, RepoType::Space)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The normalized folder nameof the repo within the cache directory
|
||||||
|
pub fn folder_name(&self) -> String {
|
||||||
|
self.repo_id.replace('/', "--")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The revision
|
||||||
|
pub fn revision(&self) -> &str {
|
||||||
|
&self.revision
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The actual URL part of the repo
|
||||||
|
#[cfg(feature = "online")]
|
||||||
|
pub fn url(&self) -> String {
|
||||||
|
match self.repo_type {
|
||||||
|
RepoType::Model => self.repo_id.to_string(),
|
||||||
|
RepoType::Dataset => {
|
||||||
|
format!("datasets/{}", self.repo_id)
|
||||||
|
}
|
||||||
|
RepoType::Space => {
|
||||||
|
format!("spaces/{}", self.repo_id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Revision needs to be url escaped before being used in a URL
|
||||||
|
#[cfg(feature = "online")]
|
||||||
|
pub fn url_revision(&self) -> String {
|
||||||
|
self.revision.replace('/', "%2F")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Used to compute the repo's url part when accessing the metadata of the repo
|
||||||
|
#[cfg(feature = "online")]
|
||||||
|
pub fn api_url(&self) -> String {
|
||||||
|
let prefix = match self.repo_type {
|
||||||
|
RepoType::Model => "models",
|
||||||
|
RepoType::Dataset => "datasets",
|
||||||
|
RepoType::Space => "spaces",
|
||||||
|
};
|
||||||
|
format!("{prefix}/{}/revision/{}", self.repo_id, self.url_revision())
|
||||||
|
}
|
||||||
|
}
|
Reference in New Issue
Block a user