diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index 98cad54f..77441374 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -23,7 +23,6 @@ candle-hub = { path = "../candle-hub" } clap = { version = "4.2.4", features = ["derive"] } rand = "0.8.5" tokenizers = { version = "0.13.3", default-features=false, features=["onig"] } -tokio = { version = "1.28.2", features = ["macros", "rt-multi-thread"] } wav = "1.0.0" [features] diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs index f9f83c4a..8b292f92 100644 --- a/candle-examples/examples/bert/main.rs +++ b/candle-examples/examples/bert/main.rs @@ -5,7 +5,7 @@ extern crate intel_mkl_src; use anyhow::{anyhow, Error as E, Result}; use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor}; -use candle_hub::{api::Api, Cache, Repo, RepoType}; +use candle_hub::{api::sync::Api, Cache, Repo, RepoType}; use clap::Parser; use serde::Deserialize; use std::collections::HashMap; @@ -656,7 +656,7 @@ struct Args { } impl Args { - async fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> { + fn build_model_and_tokenizer(&self) -> Result<(BertModel, Tokenizer)> { let device = if self.cpu { Device::Cpu } else { @@ -688,9 +688,9 @@ impl Args { } else { let api = Api::new()?; ( - api.get(&repo, "config.json").await?, - api.get(&repo, "tokenizer.json").await?, - api.get(&repo, "model.safetensors").await?, + api.get(&repo, "config.json")?, + api.get(&repo, "tokenizer.json")?, + api.get(&repo, "model.safetensors")?, ) }; let config = std::fs::read_to_string(config_filename)?; @@ -705,12 +705,11 @@ impl Args { } } -#[tokio::main] -async fn main() -> Result<()> { +fn main() -> Result<()> { let start = std::time::Instant::now(); let args = Args::parse(); - let (model, mut tokenizer) = args.build_model_and_tokenizer().await?; + let (model, mut tokenizer) = args.build_model_and_tokenizer()?; let device = &model.device; if let 
Some(prompt) = args.prompt { diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs index 1fba7bbd..6b0cafb9 100644 --- a/candle-examples/examples/llama/main.rs +++ b/candle-examples/examples/llama/main.rs @@ -19,7 +19,7 @@ use clap::Parser; use rand::{distributions::Distribution, SeedableRng}; use candle::{DType, Device, Tensor, D}; -use candle_hub::{api::Api, Repo, RepoType}; +use candle_hub::{api::sync::Api, Repo, RepoType}; use std::collections::HashMap; use std::sync::{Arc, Mutex}; @@ -465,8 +465,7 @@ struct Args { prompt: Option, } -#[tokio::main] -async fn main() -> Result<()> { +fn main() -> Result<()> { use tokenizers::Tokenizer; let args = Args::parse(); @@ -489,13 +488,13 @@ async fn main() -> Result<()> { let api = Api::new()?; let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model); println!("building the model"); - let tokenizer_filename = api.get(&repo, "tokenizer.json").await?; + let tokenizer_filename = api.get(&repo, "tokenizer.json")?; let mut filenames = vec![]; for rfilename in [ "model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors", ] { - let filename = api.get(&repo, rfilename).await?; + let filename = api.get(&repo, rfilename)?; filenames.push(filename); } diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index 0c9d0893..949bded1 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -11,7 +11,7 @@ extern crate intel_mkl_src; use anyhow::{Error as E, Result}; use candle::{DType, Device, Tensor}; -use candle_hub::{api::Api, Repo, RepoType}; +use candle_hub::{api::sync::Api, Repo, RepoType}; use clap::Parser; use rand::{distributions::Distribution, SeedableRng}; use tokenizers::Tokenizer; @@ -253,8 +253,7 @@ struct Args { filters: String, } -#[tokio::main] -async fn main() -> Result<()> { +fn main() -> Result<()> { let args = Args::parse(); let device = if args.cpu { Device::Cpu @@ 
-276,7 +275,7 @@ async fn main() -> Result<()> { config_filename.push("config.json"); let mut tokenizer_filename = path.clone(); tokenizer_filename.push("tokenizer.json"); - let mut model_filename = path.clone(); + let mut model_filename = path; model_filename.push("model.safetensors"); ( config_filename, @@ -288,9 +287,9 @@ async fn main() -> Result<()> { let repo = Repo::with_revision(model_id, RepoType::Model, revision); let api = Api::new()?; ( - api.get(&repo, "config.json").await?, - api.get(&repo, "tokenizer.json").await?, - api.get(&repo, "model.safetensors").await?, + api.get(&repo, "config.json")?, + api.get(&repo, "tokenizer.json")?, + api.get(&repo, "model.safetensors")?, if let Some(input) = args.input { std::path::PathBuf::from(input) } else { @@ -298,8 +297,7 @@ async fn main() -> Result<()> { api.get( &Repo::new("Narsil/candle-examples".to_string(), RepoType::Dataset), "samples_jfk.wav", - ) - .await? + )? }, ) }; diff --git a/candle-hub/Cargo.toml b/candle-hub/Cargo.toml index 31fbe12f..18e5d6c2 100644 --- a/candle-hub/Cargo.toml +++ b/candle-hub/Cargo.toml @@ -15,6 +15,7 @@ tokio = { version = "1.28.2", features = ["fs"], optional = true } serde = { version = "1.0.164", features = ["derive"], optional = true } serde_json = { version = "1.0.97", optional = true } indicatif = { version = "0.17.5", optional = true } +num_cpus = { version = "1.16.0", optional = true } [dev-dependencies] rand = "0.8.5" @@ -24,6 +25,7 @@ tokio-test = "0.4.2" [features] default = ["online"] -online = ["reqwest", "tokio", "futures", "serde", "serde_json", "indicatif"] +online = ["reqwest/blocking", "dep:serde", "dep:serde_json", "dep:indicatif", "dep:num_cpus"] +tokio = ["online", "dep:tokio", "dep:futures"] diff --git a/candle-hub/src/api/mod.rs b/candle-hub/src/api/mod.rs new file mode 100644 index 00000000..779dc4f9 --- /dev/null +++ b/candle-hub/src/api/mod.rs @@ -0,0 +1,6 @@ +/// The asynchronous version of the API +#[cfg(feature = "tokio")] +pub mod tokio; + +/// 
The synchronous version of the API +pub mod sync; diff --git a/candle-hub/src/api/sync.rs b/candle-hub/src/api/sync.rs new file mode 100644 index 00000000..6efdbff8 --- /dev/null +++ b/candle-hub/src/api/sync.rs @@ -0,0 +1,686 @@ +use crate::{Cache, Repo}; +use indicatif::{ProgressBar, ProgressStyle}; +use rand::{distributions::Alphanumeric, thread_rng, Rng}; +use reqwest::{ + blocking::Client, + header::{ + HeaderMap, HeaderName, HeaderValue, InvalidHeaderValue, ToStrError, AUTHORIZATION, + CONTENT_RANGE, LOCATION, RANGE, USER_AGENT, + }, + redirect::Policy, + Error as ReqwestError, +}; +use serde::Deserialize; +use std::io::{Seek, SeekFrom, Write}; +use std::num::ParseIntError; +use std::path::{Component, Path, PathBuf}; +use thiserror::Error; + +/// Current version (used in user-agent) +const VERSION: &str = env!("CARGO_PKG_VERSION"); +/// Current name (used in user-agent) +const NAME: &str = env!("CARGO_PKG_NAME"); + +#[derive(Debug, Error)] +/// All errors the API can throw +pub enum ApiError { + /// Api expects certain header to be present in the results to derive some information + #[error("Header {0} is missing")] + MissingHeader(HeaderName), + + /// The header exists, but the value does not conform to what the Api expects. 
+ #[error("Header {0} is invalid")] + InvalidHeader(HeaderName), + + /// The value cannot be used as a header during request header construction + #[error("Invalid header value {0}")] + InvalidHeaderValue(#[from] InvalidHeaderValue), + + /// The header value is not valid utf-8 + #[error("header value is not a string")] + ToStr(#[from] ToStrError), + + /// Error in the request + #[error("request error: {0}")] + RequestError(#[from] ReqwestError), + + /// Error parsing some range value + #[error("Cannot parse int")] + ParseIntError(#[from] ParseIntError), + + /// I/O Error + #[error("I/O error {0}")] + IoError(#[from] std::io::Error), + + /// We tried to download chunk too many times + #[error("Too many retries: {0}")] + TooManyRetries(Box), +} + +/// Siblings are simplified file descriptions of remote files on the hub +#[derive(Debug, Clone, Deserialize, PartialEq)] +pub struct Siblings { + /// The path within the repo. + pub rfilename: String, +} + +/// The description of the repo given by the hub +#[derive(Debug, Clone, Deserialize, PartialEq)] +pub struct ModelInfo { + /// See [`Siblings`] + pub siblings: Vec, +} + +/// Helper to create [`Api`] with all the options. 
+pub struct ApiBuilder { + endpoint: String, + cache: Cache, + url_template: String, + token: Option, + chunk_size: usize, + parallel_failures: usize, + max_retries: usize, + progress: bool, +} + +impl Default for ApiBuilder { + fn default() -> Self { + Self::new() + } +} + +impl ApiBuilder { + /// Default api builder + /// ``` + /// use candle_hub::api::sync::ApiBuilder; + /// let api = ApiBuilder::new().build().unwrap(); + /// ``` + pub fn new() -> Self { + let cache = Cache::default(); + let mut token_filename = cache.path().clone(); + token_filename.push(".token"); + let token = match std::fs::read_to_string(token_filename) { + Ok(token_content) => { + let token_content = token_content.trim(); + if !token_content.is_empty() { + Some(token_content.to_string()) + } else { + None + } + } + Err(_) => None, + }; + + let progress = true; + + Self { + endpoint: "https://huggingface.co".to_string(), + url_template: "{endpoint}/{repo_id}/resolve/{revision}/{filename}".to_string(), + cache, + token, + chunk_size: 10_000_000, + parallel_failures: 0, + max_retries: 0, + progress, + } + } + + /// Whether to show a progress bar + pub fn with_progress(mut self, progress: bool) -> Self { + self.progress = progress; + self + } + + /// Changes the location of the cache directory. Defaults to `~/.cache/huggingface/hub/`. 
+ pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self { + self.cache = Cache::new(cache_dir); + self + } + + fn build_headers(&self) -> Result { + let mut headers = HeaderMap::new(); + let user_agent = format!("unkown/None; {NAME}/{VERSION}; rust/unknown"); + headers.insert(USER_AGENT, HeaderValue::from_str(&user_agent)?); + if let Some(token) = &self.token { + headers.insert( + AUTHORIZATION, + HeaderValue::from_str(&format!("Bearer {token}"))?, + ); + } + Ok(headers) + } + + /// Consumes the builder and builds the final [`Api`] + pub fn build(self) -> Result { + let headers = self.build_headers()?; + let client = Client::builder().default_headers(headers.clone()).build()?; + let no_redirect_client = Client::builder() + .redirect(Policy::none()) + .default_headers(headers) + .build()?; + Ok(Api { + endpoint: self.endpoint, + url_template: self.url_template, + cache: self.cache, + client, + + no_redirect_client, + chunk_size: self.chunk_size, + parallel_failures: self.parallel_failures, + max_retries: self.max_retries, + progress: self.progress, + }) + } +} + +#[derive(Debug)] +struct Metadata { + commit_hash: String, + etag: String, + size: usize, +} + +/// The actual Api used to interact with the hub. 
+/// You can inspect repos with [`Api::info`] +/// or download files with [`Api::download`] +pub struct Api { + endpoint: String, + url_template: String, + cache: Cache, + client: Client, + no_redirect_client: Client, + chunk_size: usize, + parallel_failures: usize, + max_retries: usize, + progress: bool, +} + +fn temp_filename() -> PathBuf { + let s: String = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(7) + .map(char::from) + .collect(); + let mut path = std::env::temp_dir(); + path.push(s); + path +} + +fn make_relative(src: &Path, dst: &Path) -> PathBuf { + let path = src; + let base = dst; + + if path.is_absolute() != base.is_absolute() { + panic!("This function is made to look at absolute paths only"); + } + let mut ita = path.components(); + let mut itb = base.components(); + + loop { + match (ita.next(), itb.next()) { + (Some(a), Some(b)) if a == b => (), + (some_a, _) => { + // Ignoring b, because 1 component is the filename + // for which we don't need to go back up for relative + // filename to work. 
+ let mut new_path = PathBuf::new(); + for _ in itb { + new_path.push(Component::ParentDir); + } + if let Some(a) = some_a { + new_path.push(a); + for comp in ita { + new_path.push(comp); + } + } + return new_path; + } + } + } +} + +fn symlink_or_rename(src: &Path, dst: &Path) -> Result<(), std::io::Error> { + if dst.exists() { + return Ok(()); + } + + let src = make_relative(src, dst); + #[cfg(target_os = "windows")] + std::os::windows::fs::symlink_file(src, dst)?; + + #[cfg(target_family = "unix")] + std::os::unix::fs::symlink(src, dst)?; + + #[cfg(not(any(target_family = "unix", target_os = "windows")))] + std::fs::rename(src, dst)?; + + Ok(()) +} + +fn jitter() -> usize { + thread_rng().gen_range(0..=500) +} + +fn exponential_backoff(base_wait_time: usize, n: usize, max: usize) -> usize { + (base_wait_time + n.pow(2) + jitter()).min(max) +} + +impl Api { + /// Creates a default Api, for Api options See [`ApiBuilder`] + pub fn new() -> Result { + ApiBuilder::new().build() + } + + /// Get the fully qualified URL of the remote filename + /// ``` + /// # use candle_hub::{api::sync::Api, Repo}; + /// let api = Api::new().unwrap(); + /// let repo = Repo::model("gpt2".to_string()); + /// let url = api.url(&repo, "model.safetensors"); + /// assert_eq!(url, "https://huggingface.co/gpt2/resolve/main/model.safetensors"); + /// ``` + pub fn url(&self, repo: &Repo, filename: &str) -> String { + let endpoint = &self.endpoint; + let revision = &repo.url_revision(); + self.url_template + .replace("{endpoint}", endpoint) + .replace("{repo_id}", &repo.url()) + .replace("{revision}", revision) + .replace("{filename}", filename) + } + + /// Get the underlying api client + /// Allows for lower level access + pub fn client(&self) -> &Client { + &self.client + } + + fn metadata(&self, url: &str) -> Result { + let response = self + .no_redirect_client + .get(url) + .header(RANGE, "bytes=0-0") + .send()?; + let response = response.error_for_status()?; + let headers = 
response.headers(); + let header_commit = HeaderName::from_static("x-repo-commit"); + let header_linked_etag = HeaderName::from_static("x-linked-etag"); + let header_etag = HeaderName::from_static("etag"); + + let etag = match headers.get(&header_linked_etag) { + Some(etag) => etag, + None => headers + .get(&header_etag) + .ok_or(ApiError::MissingHeader(header_etag))?, + }; + // Cleaning extra quotes + let etag = etag.to_str()?.to_string().replace('"', ""); + let commit_hash = headers + .get(&header_commit) + .ok_or(ApiError::MissingHeader(header_commit))? + .to_str()? + .to_string(); + + // The response was redirected to S3 most likely which will + // know about the size of the file + let response = if response.status().is_redirection() { + self.client + .get(headers.get(LOCATION).unwrap().to_str()?.to_string()) + .header(RANGE, "bytes=0-0") + .send()? + } else { + response + }; + let headers = response.headers(); + let content_range = headers + .get(CONTENT_RANGE) + .ok_or(ApiError::MissingHeader(CONTENT_RANGE))? + .to_str()?; + + let size = content_range + .split('/') + .last() + .ok_or(ApiError::InvalidHeader(CONTENT_RANGE))? 
+ .parse()?; + Ok(Metadata { + commit_hash, + etag, + size, + }) + } + + fn download_tempfile( + &self, + url: &str, + length: usize, + progressbar: Option, + ) -> Result { + let filename = temp_filename(); + + // Create the file and set everything properly + std::fs::File::create(&filename)?.set_len(length as u64)?; + + let chunk_size = self.chunk_size; + + let n_chunks = (length + chunk_size - 1) / chunk_size; + let n_threads = num_cpus::get(); + let chunks_per_thread = (n_chunks + n_threads - 1) / n_threads; + let handles = (0..n_threads).map(|thread_id| { + let url = url.to_string(); + let filename = filename.clone(); + let client = self.client.clone(); + let parallel_failures = self.parallel_failures; + let max_retries = self.max_retries; + let progress = progressbar.clone(); + std::thread::spawn(move || { + for chunk_id in chunks_per_thread * thread_id + ..std::cmp::min(chunks_per_thread * (thread_id + 1), n_chunks) + { + let start = chunk_id * chunk_size; + let stop = std::cmp::min(start + chunk_size - 1, length); + let mut chunk = Self::download_chunk(&client, &url, &filename, start, stop); + let mut i = 0; + if parallel_failures > 0 { + while let Err(dlerr) = chunk { + let wait_time = exponential_backoff(300, i, 10_000); + std::thread::sleep(std::time::Duration::from_millis(wait_time as u64)); + + chunk = Self::download_chunk(&client, &url, &filename, start, stop); + i += 1; + if i > max_retries { + return Err(ApiError::TooManyRetries(dlerr.into())); + } + } + } + if let Some(p) = &progress { + p.inc((stop - start) as u64); + } + chunk? + } + Ok(()) + }) + }); + + let results: Result, ApiError> = + handles.into_iter().flat_map(|h| h.join()).collect(); + + results?; + if let Some(p) = progressbar { + p.finish() + } + Ok(filename) + } + + fn download_chunk( + client: &Client, + url: &str, + filename: &PathBuf, + start: usize, + stop: usize, + ) -> Result<(), ApiError> { + // Process each socket concurrently. 
+ let range = format!("bytes={start}-{stop}"); + let mut file = std::fs::OpenOptions::new().write(true).open(filename)?; + file.seek(SeekFrom::Start(start as u64))?; + let response = client + .get(url) + .header(RANGE, range) + .send()? + .error_for_status()?; + let content = response.bytes()?; + file.write_all(&content)?; + Ok(()) + } + + /// This will attempt to fetch the file locally first, then [`Api::download`] + /// if the file is not present. + /// ```no_run + /// use candle_hub::{api::sync::ApiBuilder, Repo}; + /// let api = ApiBuilder::new().build().unwrap(); + /// let repo = Repo::model("gpt2".to_string()); + /// let local_filename = api.get(&repo, "model.safetensors").unwrap(); + pub fn get(&self, repo: &Repo, filename: &str) -> Result { + if let Some(path) = self.cache.get(repo, filename) { + Ok(path) + } else { + self.download(repo, filename) + } + } + + /// Downloads a remote file (if not already present) into the cache directory + /// to be used locally. + /// This function requires internet access to verify if new versions of the file + /// exist, even if a file is already on disk at location. 
+ /// ```no_run + /// # use candle_hub::{api::sync::ApiBuilder, Repo}; + /// let api = ApiBuilder::new().build().unwrap(); + /// let repo = Repo::model("gpt2".to_string()); + /// let local_filename = api.download(&repo, "model.safetensors").unwrap(); + /// ``` + pub fn download(&self, repo: &Repo, filename: &str) -> Result { + let url = self.url(repo, filename); + let metadata = self.metadata(&url)?; + + let blob_path = self.cache.blob_path(repo, &metadata.etag); + std::fs::create_dir_all(blob_path.parent().unwrap())?; + + let progressbar = if self.progress { + let progress = ProgressBar::new(metadata.size as u64); + progress.set_style( + ProgressStyle::with_template( + "{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec} ({eta})", + ) + .unwrap(), // .progress_chars("━ "), + ); + let maxlength = 30; + let message = if filename.len() > maxlength { + format!("..{}", &filename[filename.len() - maxlength..]) + } else { + filename.to_string() + }; + progress.set_message(message); + Some(progress) + } else { + None + }; + + let tmp_filename = self.download_tempfile(&url, metadata.size, progressbar)?; + + if std::fs::rename(&tmp_filename, &blob_path).is_err() { + // Renaming may fail if locations are different mount points + std::fs::File::create(&blob_path)?; + std::fs::copy(tmp_filename, &blob_path)?; + } + + let mut pointer_path = self.cache.pointer_path(repo, &metadata.commit_hash); + pointer_path.push(filename); + std::fs::create_dir_all(pointer_path.parent().unwrap()).ok(); + + symlink_or_rename(&blob_path, &pointer_path)?; + self.cache.create_ref(repo, &metadata.commit_hash)?; + + Ok(pointer_path) + } + + /// Get information about the Repo + /// ``` + /// use candle_hub::{api::sync::Api, Repo}; + /// let api = Api::new().unwrap(); + /// let repo = Repo::model("gpt2".to_string()); + /// api.info(&repo); + /// ``` + pub fn info(&self, repo: &Repo) -> Result { + let url = format!("{}/api/{}", self.endpoint, repo.api_url()); + let response = 
self.client.get(url).send()?; + let response = response.error_for_status()?; + + let model_info = response.json()?; + + Ok(model_info) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::RepoType; + use rand::{distributions::Alphanumeric, Rng}; + use sha256::try_digest; + + struct TempDir { + path: PathBuf, + } + + impl TempDir { + pub fn new() -> Self { + let s: String = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(7) + .map(char::from) + .collect(); + let mut path = std::env::temp_dir(); + path.push(s); + std::fs::create_dir(&path).unwrap(); + Self { path } + } + } + + impl Drop for TempDir { + fn drop(&mut self) { + std::fs::remove_dir_all(&self.path).unwrap() + } + } + + #[test] + fn simple() { + let tmp = TempDir::new(); + let api = ApiBuilder::new() + .with_progress(false) + .with_cache_dir(tmp.path.clone()) + .build() + .unwrap(); + let repo = Repo::new("julien-c/dummy-unknown".to_string(), RepoType::Model); + let downloaded_path = api.download(&repo, "config.json").unwrap(); + assert!(downloaded_path.exists()); + let val = try_digest(&*downloaded_path).unwrap(); + assert_eq!( + val, + "b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32" + ); + + // Make sure the file is now seeable without connection + let cache_path = api.cache.get(&repo, "config.json").unwrap(); + assert_eq!(cache_path, downloaded_path); + } + + #[test] + fn dataset() { + let tmp = TempDir::new(); + let api = ApiBuilder::new() + .with_progress(false) + .with_cache_dir(tmp.path.clone()) + .build() + .unwrap(); + let repo = Repo::with_revision( + "wikitext".to_string(), + RepoType::Dataset, + "refs/convert/parquet".to_string(), + ); + let downloaded_path = api + .download(&repo, "wikitext-103-v1/wikitext-test.parquet") + .unwrap(); + assert!(downloaded_path.exists()); + let val = try_digest(&*downloaded_path).unwrap(); + assert_eq!( + val, + "59ce09415ad8aa45a9e34f88cec2548aeb9de9a73fcda9f6b33a86a065f32b90" + ) + } + + #[test] + fn info() { + let tmp = 
TempDir::new(); + let api = ApiBuilder::new() + .with_progress(false) + .with_cache_dir(tmp.path.clone()) + .build() + .unwrap(); + let repo = Repo::with_revision( + "wikitext".to_string(), + RepoType::Dataset, + "refs/convert/parquet".to_string(), + ); + let model_info = api.info(&repo).unwrap(); + assert_eq!( + model_info, + ModelInfo { + siblings: vec![ + Siblings { + rfilename: ".gitattributes".to_string() + }, + Siblings { + rfilename: "wikitext-103-raw-v1/wikitext-test.parquet".to_string() + }, + Siblings { + rfilename: "wikitext-103-raw-v1/wikitext-train-00000-of-00002.parquet" + .to_string() + }, + Siblings { + rfilename: "wikitext-103-raw-v1/wikitext-train-00001-of-00002.parquet" + .to_string() + }, + Siblings { + rfilename: "wikitext-103-raw-v1/wikitext-validation.parquet".to_string() + }, + Siblings { + rfilename: "wikitext-103-v1/test/index.duckdb".to_string() + }, + Siblings { + rfilename: "wikitext-103-v1/validation/index.duckdb".to_string() + }, + Siblings { + rfilename: "wikitext-103-v1/wikitext-test.parquet".to_string() + }, + Siblings { + rfilename: "wikitext-103-v1/wikitext-train-00000-of-00002.parquet" + .to_string() + }, + Siblings { + rfilename: "wikitext-103-v1/wikitext-train-00001-of-00002.parquet" + .to_string() + }, + Siblings { + rfilename: "wikitext-103-v1/wikitext-validation.parquet".to_string() + }, + Siblings { + rfilename: "wikitext-2-raw-v1/test/index.duckdb".to_string() + }, + Siblings { + rfilename: "wikitext-2-raw-v1/train/index.duckdb".to_string() + }, + Siblings { + rfilename: "wikitext-2-raw-v1/validation/index.duckdb".to_string() + }, + Siblings { + rfilename: "wikitext-2-raw-v1/wikitext-test.parquet".to_string() + }, + Siblings { + rfilename: "wikitext-2-raw-v1/wikitext-train.parquet".to_string() + }, + Siblings { + rfilename: "wikitext-2-raw-v1/wikitext-validation.parquet".to_string() + }, + Siblings { + rfilename: "wikitext-2-v1/wikitext-test.parquet".to_string() + }, + Siblings { + rfilename: 
"wikitext-2-v1/wikitext-train.parquet".to_string() + }, + Siblings { + rfilename: "wikitext-2-v1/wikitext-validation.parquet".to_string() + } + ], + } + ) + } +} diff --git a/candle-hub/src/api.rs b/candle-hub/src/api/tokio.rs similarity index 98% rename from candle-hub/src/api.rs rename to candle-hub/src/api/tokio.rs index e578cfd9..bf0c2c41 100644 --- a/candle-hub/src/api.rs +++ b/candle-hub/src/api/tokio.rs @@ -105,7 +105,7 @@ impl Default for ApiBuilder { impl ApiBuilder { /// Default api builder /// ``` - /// use candle_hub::api::ApiBuilder; + /// use candle_hub::api::tokio::ApiBuilder; /// let api = ApiBuilder::new().build().unwrap(); /// ``` pub fn new() -> Self { @@ -289,7 +289,7 @@ impl Api { /// Get the fully qualified URL of the remote filename /// ``` - /// # use candle_hub::{api::Api, Repo}; + /// # use candle_hub::{api::tokio::Api, Repo}; /// let api = Api::new().unwrap(); /// let repo = Repo::model("gpt2".to_string()); /// let url = api.url(&repo, "model.safetensors"); @@ -463,7 +463,7 @@ impl Api { /// This will attempt the fetch the file locally first, then [`Api.download`] /// if the file is not present. /// ```no_run - /// # use candle_hub::{api::ApiBuilder, Repo}; + /// # use candle_hub::{api::tokio::ApiBuilder, Repo}; /// # tokio_test::block_on(async { /// let api = ApiBuilder::new().build().unwrap(); /// let repo = Repo::model("gpt2".to_string()); @@ -482,7 +482,7 @@ impl Api { /// This functions require internet access to verify if new versions of the file /// exist, even if a file is already on disk at location. 
/// ```no_run - /// # use candle_hub::{api::ApiBuilder, Repo}; + /// # use candle_hub::{api::tokio::ApiBuilder, Repo}; /// # tokio_test::block_on(async { /// let api = ApiBuilder::new().build().unwrap(); /// let repo = Repo::model("gpt2".to_string()); @@ -538,7 +538,7 @@ impl Api { /// Get information about the Repo /// ``` - /// # use candle_hub::{api::Api, Repo}; + /// # use candle_hub::{api::tokio::Api, Repo}; /// # tokio_test::block_on(async { /// let api = Api::new().unwrap(); /// let repo = Repo::model("gpt2".to_string()); diff --git a/candle-hub/src/lib.rs b/candle-hub/src/lib.rs index 5d186039..0de2006a 100644 --- a/candle-hub/src/lib.rs +++ b/candle-hub/src/lib.rs @@ -103,6 +103,7 @@ impl Default for Cache { let mut cache = dirs::home_dir().expect("Cache directory cannot be found"); cache.push(".cache"); cache.push("huggingface"); + cache.push("hub"); cache } };