diff --git a/.gitignore b/.gitignore index df9a6132..9ff37524 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ # Generated by Cargo # will have compiled files and executables debug/ +data/ dist/ target/ diff --git a/Cargo.toml b/Cargo.toml index 9c8b5682..6f435ba8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,6 @@ members = [ "candle-core", "candle-examples", - "candle-hub", "candle-nn", "candle-pyo3", "candle-transformers", @@ -19,29 +18,24 @@ clap = { version = "4.2.4", features = ["derive"] } # Re-enable this once 0.9.13 as been released as it would include the cublas-f16 changes # cudarc = { version = "0.9.13", optional = true, features = ["f16"] } cudarc = { git = "https://github.com/LaurentMazare/cudarc.git", branch = "cublas-bf16", features = ["f16"] } -futures = "0.3.28" # TODO: Switch back to the official gemm implementation once the following are available. # https://github.com/sarah-ek/gemm/pull/8. # https://github.com/sarah-ek/gemm/pull/9. gemm = { git = "https://github.com/LaurentMazare/gemm.git", branch = "f16-vec-plus-wasm-simd" } +hf-hub = "0.1.0" half = { version = "2.3.1", features = ["num-traits"] } -indicatif = "0.17.5" -intel-mkl-src = { version = "0.8.1", features = ["mkl-dynamic-lp64-iomp"] } +intel-mkl-src = { version = "0.8.1", features = ["mkl-static-ilp64-iomp"] } libc = { version = "0.2.147" } log = "0.4" memmap2 = "0.7.1" num_cpus = "1.15.0" num-traits = "0.2.15" rand = "0.8.5" -reqwest = "0.11.18" safetensors = "0.3.1" -serde = { version = "1.0.166", features = ["derive"] } +serde = { version = "1.0.171", features = ["derive"] } serde_json = "1.0.99" -sha256 = "=1.1.4" thiserror = "1" tokenizers = { version = "0.13.3", default-features = false } -tokio = "1.28.2" -tokio-test = "0.4.2" tracing = "0.1.37" tracing-chrome = "0.7.1" tracing-subscriber = "0.3.7" diff --git a/candle-core/src/backend.rs b/candle-core/src/backend.rs index c897510e..018279b3 100644 --- a/candle-core/src/backend.rs +++ b/candle-core/src/backend.rs @@ -16,7 +16,7 @@ pub(crate) trait BackendStorage: Sized { fn elu(&self, _: &Layout, _: f64) -> Result; - fn sum(&self, _: &Layout, _: &[usize]) -> Result; + fn reduce_op(&self, _: crate::op::ReduceOp, _: &Layout, _: &[usize]) -> Result; fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()>; diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index c72f603f..3de11d35 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -67,6 +67,8 @@ impl Tensor { Op::Reshape(node) | Op::Broadcast(node) | Op::Sum(node, _) + | Op::Max(node, _) + | Op::Min(node, _) | Op::ToDType(node) | Op::ToDevice(node) | Op::Transpose(node, _, _) @@ -203,6 +205,12 @@ impl Tensor { let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.broadcast_add(&grad)? } + Op::Max(_args, _sum_dims) => { + return Err(Error::BackwardNotSupported { op: "max" }) + } + Op::Min(_args, _sum_dims) => { + return Err(Error::BackwardNotSupported { op: "min" }) + } Op::ToDType(arg) => { let sum_grad = grads.or_insert(arg)?; *sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)? 
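Two changes land together here: the per-op `sum` entry point on `BackendStorage` becomes a single `reduce_op` parameterized by the new `ReduceOp` enum, and the autograd graph gains `Op::Max`/`Op::Min` nodes whose gradients are left as `BackwardNotSupported` for now. A minimal standalone sketch of the dispatch shape (a toy fold over a flat buffer, not the layout-aware kernels below):

```rust
// Mirrors the new ReduceOp enum; the fold seed and combinator are both
// chosen from the op, exactly as the CPU backend does per destination slot.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ReduceOp {
    Sum,
    Min,
    Max,
}

fn reduce_all(op: ReduceOp, src: &[f32]) -> Option<f32> {
    let (seed, f): (f32, fn(f32, f32) -> f32) = match op {
        ReduceOp::Sum => (0.0, |x, y| x + y),
        // Min/max are seeded with the first element, so empty input has no answer.
        ReduceOp::Min => (*src.first()?, |x, y| if x < y { x } else { y }),
        ReduceOp::Max => (*src.first()?, |x, y| if x > y { x } else { y }),
    };
    Some(src.iter().copied().fold(seed, f))
}

fn main() {
    let v = [3.0f32, -1.0, 7.0, 2.0];
    assert_eq!(reduce_all(ReduceOp::Sum, &v), Some(11.0));
    assert_eq!(reduce_all(ReduceOp::Min, &v), Some(-1.0));
    assert_eq!(reduce_all(ReduceOp::Max, &v), Some(7.0));
    assert_eq!(reduce_all(ReduceOp::Max, &[]), None);
}
```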
diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 6458b452..925ca112 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -1,5 +1,5 @@ use crate::backend::{BackendDevice, BackendStorage}; -use crate::op::{BinaryOp, UnaryOp}; +use crate::op::{BinaryOp, ReduceOp, UnaryOp}; use crate::{DType, Error, Layout, Result, Shape, WithDType}; use half::{bf16, f16}; @@ -93,64 +93,69 @@ impl<'a> Map2 for WCond<'a> { } } -struct Sum<'a> { +struct Reduce<'a> { dst_shape: &'a Shape, - sum_dims: &'a [usize], - sum_dims_and_stride: Vec<(usize, usize)>, + reduce_dims: &'a [usize], + reduce_dims_and_stride: Vec<(usize, usize)>, + op: ReduceOp, } -impl<'a> Map1 for Sum<'a> { +impl<'a> Reduce<'a> { #[inline(always)] - fn f(&self, src: &[T], src_l: &Layout) -> Result> { - let mut dst = vec![T::zero(); self.dst_shape.elem_count()]; + fn fold_impl(&self, src: &[T], src_l: &Layout, start_elt: T, f: F) -> Result> + where + T: Clone + Copy, + F: Fn(T, T) -> T, + { + let mut dst = vec![start_elt; self.dst_shape.elem_count()]; match src_l.contiguous_offsets() { Some((o1, o2)) => { let src = &src[o1..o2]; - // Handle the case where we sum over the last dimensions separately as it is + // Handle the case where we reduce over the last dimensions separately as it is // fairly common and easy to optimize. This rely on the layout being contiguous! - // sum_dims is sorted, check if it is ranging from a to n-1. - let sum_over_last_dims = self - .sum_dims + // reduce_dims is sorted, check if it is ranging from a to n-1. + let reduce_over_last_dims = self + .reduce_dims .iter() .rev() .enumerate() .all(|(i, &v)| v == src_l.shape().rank() - 1 - i); - if sum_over_last_dims { - let sum_sz = self - .sum_dims_and_stride + if reduce_over_last_dims { + let reduce_sz = self + .reduce_dims_and_stride .iter() .map(|(u, _)| u) .product::(); let mut src_i = 0; for dst_v in dst.iter_mut() { - for &s in src[src_i..src_i + sum_sz].iter() { - *dst_v += s + for &s in src[src_i..src_i + reduce_sz].iter() { + *dst_v = f(*dst_v, s) } - src_i += sum_sz + src_i += reduce_sz } return Ok(dst); }; for (unstr_index, &src) in src.iter().enumerate() { let mut dst_index = unstr_index; - // Set the sum_dims indexes to 0. - for &(dim, stride) in self.sum_dims_and_stride.iter() { + // Set the reduce_dims indexes to 0. + for &(dim, stride) in self.reduce_dims_and_stride.iter() { // The compiler is able to optimize the following in a single divmod op. let (pre, post) = (dst_index / stride, dst_index % stride); dst_index = (pre / dim) * stride + post; } - dst[dst_index] += src; + dst[dst_index] = f(dst[dst_index], src); } } None => { for (unstr_index, src_index) in src_l.strided_index().enumerate() { let mut dst_index = unstr_index; - // Set the sum_dims indexes to 0. - for &(dim, stride) in self.sum_dims_and_stride.iter() { + // Set the reduce_dims indexes to 0. + for &(dim, stride) in self.reduce_dims_and_stride.iter() { // The compiler is able to optimize the following in a single divmod op. 
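The strided fallback in the hunk above never materializes multi-dimensional coordinates: each reduced dimension's coordinate is cleared out of the linear index with a pair of divmods, using the row-major stride (the product of the trailing dimension sizes). A standalone check of that index rewrite, assuming a plain row-major `[2, 3]` layout:

```rust
// Row-major stride of dimension `d`: the product of all dimension sizes to
// its right. This is what `reduce_dims_and_stride` records per reduced dim.
fn stride_of(dims: &[usize], d: usize) -> usize {
    dims[d + 1..].iter().product()
}

// The divmod rewrite: clear dimension's coordinate inside a linear index,
// compacting the index into the smaller destination space.
fn zero_out_dim(idx: usize, dim: usize, stride: usize) -> usize {
    let (pre, post) = (idx / stride, idx % stride);
    (pre / dim) * stride + post
}

fn main() {
    let dims = [2, 3];
    // Reducing dim 0 of a 2x3 tensor: source indices fold onto a [1, 3] output.
    let s0 = stride_of(&dims, 0); // 3
    let dst: Vec<_> = (0..6).map(|i| zero_out_dim(i, dims[0], s0)).collect();
    assert_eq!(dst, [0, 1, 2, 0, 1, 2]);
    // Reducing dim 1 instead: rows collapse onto a [2, 1] output.
    let s1 = stride_of(&dims, 1); // 1
    let dst: Vec<_> = (0..6).map(|i| zero_out_dim(i, dims[1], s1)).collect();
    assert_eq!(dst, [0, 0, 0, 1, 1, 1]);
}
```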
let (pre, post) = (dst_index / stride, dst_index % stride); dst_index = (pre / dim) * stride + post; } - dst[dst_index] += src[src_index]; + dst[dst_index] = f(dst[dst_index], src[src_index]); } } } @@ -158,6 +163,31 @@ impl<'a> Map1 for Sum<'a> { } } +impl<'a> Map1 for Reduce<'a> { + #[inline(always)] + fn f(&self, src: &[T], src_l: &Layout) -> Result> { + match self.op { + ReduceOp::Min => { + let s = if src_l.shape().elem_count() != 0 { + src[src_l.start_offset()] + } else { + Err(Error::EmptyTensor { op: "min" }.bt())? + }; + self.fold_impl(src, src_l, s, |x, y| if x < y { x } else { y }) + } + ReduceOp::Max => { + let s = if src_l.shape().elem_count() != 0 { + src[src_l.start_offset()] + } else { + Err(Error::EmptyTensor { op: "max" }.bt())? + }; + self.fold_impl(src, src_l, s, |x, y| if x > y { x } else { y }) + } + ReduceOp::Sum => self.fold_impl(src, src_l, T::zero(), |x, y| x + y), + } + } +} + fn unary_map U>(vs: &[T], layout: &Layout, mut f: F) -> Vec { match layout.strided_blocks() { crate::StridedBlocks::SingleBlock { start_offset, len } => vs @@ -340,7 +370,7 @@ fn binary_map_vec T, FV: FnMut(&[T], &[T], &mut [T])> } (Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() { Some(ob) if ob.right_broadcast == 1 => { - let rhs = &rhs[ob.start..]; + let rhs = &rhs[ob.start..ob.start + ob.len]; let mut ys: Vec = Vec::with_capacity(el_count); let ys_to_set = ys.spare_capacity_mut(); let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) }; @@ -358,7 +388,7 @@ fn binary_map_vec T, FV: FnMut(&[T], &[T], &mut [T])> ys } Some(ob) => { - let rhs = &rhs[ob.start..]; + let rhs = &rhs[ob.start..ob.start + ob.len]; let mut ys = lhs[o_l1..o_l2].to_vec(); for idx_l in 0..ob.left_broadcast { let start = idx_l * ob.len * ob.right_broadcast; @@ -379,7 +409,7 @@ fn binary_map_vec T, FV: FnMut(&[T], &[T], &mut [T])> }, (None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() { Some(ob) if ob.right_broadcast == 1 => { - let lhs = &lhs[ob.start..]; + let lhs = &lhs[ob.start..ob.start + ob.len]; let mut ys: Vec = Vec::with_capacity(el_count); let ys_to_set = ys.spare_capacity_mut(); let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) }; @@ -397,7 +427,7 @@ fn binary_map_vec T, FV: FnMut(&[T], &[T], &mut [T])> ys } Some(ob) => { - let lhs = &lhs[ob.start..]; + let lhs = &lhs[ob.start..ob.start + ob.len]; let mut ys = rhs[o_r1..o_r2].to_vec(); for idx_l in 0..ob.left_broadcast { let start = idx_l * ob.len * ob.right_broadcast; @@ -1010,25 +1040,26 @@ impl BackendStorage for CpuStorage { } } - fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result { + fn reduce_op(&self, op: ReduceOp, layout: &Layout, reduce_dims: &[usize]) -> Result { let src_dims = layout.dims(); let mut dst_dims = src_dims.to_vec(); - for &sum_dim in sum_dims.iter() { - dst_dims[sum_dim] = 1; + for &dim in reduce_dims.iter() { + dst_dims[dim] = 1; } let dst_shape = Shape::from(dst_dims); - let mut sum_dims = sum_dims.to_vec(); - // Sort the sum_dims as they have to be processed from left to right when converting the + let mut reduce_dims = reduce_dims.to_vec(); + // Sort the reduce_dims as they have to be processed from left to right when converting the // indexes. 
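Note the seeding choice in the `Map1` impl above: sum starts from `T::zero()`, but min and max are seeded with the first source element, which avoids per-dtype sentinels like `T::MAX` at the cost of making the empty tensor an error (the new `EmptyTensor` variant); on CUDA, only `ReduceOp::Sum` is wired up so far, as the hunk further below notes. The same file also tightens the broadcast fast paths to slice exactly `ob.len` elements instead of the open-ended `&rhs[ob.start..]`. A tiny standalone version of the seeded fold (a hypothetical helper, not the crate's API):

```rust
fn min_fold<T: Copy + PartialOrd>(src: &[T]) -> Result<T, &'static str> {
    // Seed with the first element; an empty source has no valid seed.
    let seed = *src.first().ok_or("empty tensor for min")?;
    Ok(src.iter().copied().fold(seed, |x, y| if x < y { x } else { y }))
}

fn main() {
    // Works for integer dtypes, which have no infinity sentinel.
    assert_eq!(min_fold(&[3u8, 1, 2]), Ok(1));
    assert_eq!(min_fold::<f32>(&[]), Err("empty tensor for min"));
}
```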
- sum_dims.sort(); - let sum_dims_and_stride: Vec<_> = sum_dims + reduce_dims.sort(); + let reduce_dims_and_stride: Vec<_> = reduce_dims .iter() .map(|&d| (src_dims[d], src_dims[d + 1..].iter().product::())) .collect(); - Sum { + Reduce { dst_shape: &dst_shape, - sum_dims: &sum_dims, - sum_dims_and_stride, + reduce_dims: &reduce_dims, + reduce_dims_and_stride, + op, } .map(self, layout) } diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 74a3cf30..07d354b6 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -955,10 +955,21 @@ impl BackendStorage for CudaStorage { Ok(Self { slice, device }) } - fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result { - let device = self.device().clone(); - let slice = FastSum(sum_dims).map(&self.slice, &device, layout)?; - Ok(Self { slice, device }) + fn reduce_op( + &self, + op: crate::op::ReduceOp, + layout: &Layout, + sum_dims: &[usize], + ) -> Result { + match op { + crate::op::ReduceOp::Sum => { + let device = self.device().clone(); + let slice = FastSum(sum_dims).map(&self.slice, &device, layout)?; + Ok(Self { slice, device }) + } + crate::op::ReduceOp::Min => Err(CudaError::InternalError("TODO: implement min").into()), + crate::op::ReduceOp::Max => Err(CudaError::InternalError("TODO: implement max").into()), + } } fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> { diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs index a9c11bf6..f7cf8ab8 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -40,7 +40,7 @@ impl crate::backend::BackendStorage for CudaStorage { Err(Error::NotCompiledWithCudaSupport) } - fn sum(&self, _: &Layout, _: &[usize]) -> Result { + fn reduce_op(&self, _: crate::op::ReduceOp, _: &Layout, _: &[usize]) -> Result { Err(Error::NotCompiledWithCudaSupport) } diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs index e354b239..4ec639db 100644 --- a/candle-core/src/error.rs +++ b/candle-core/src/error.rs @@ -79,6 +79,9 @@ pub enum Error { nth_shape: Shape, }, + #[error("empty tensor for {op}")] + EmptyTensor { op: &'static str }, + // === Device Errors === #[error("device mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")] DeviceMismatchBinaryOp { diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 07ee7670..c5ff8179 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -29,6 +29,8 @@ pub(crate) enum Op { add: f64, }, Sum(Tensor, Vec), + Max(Tensor, Vec), + Min(Tensor, Vec), ToDType(Tensor), Broadcast(Tensor), Exp(Tensor), @@ -354,3 +356,10 @@ impl UnaryOp for Relu { v } } + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ReduceOp { + Sum, + Min, + Max, +} diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index 1531b212..e689905e 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -80,14 +80,19 @@ impl Storage { } } - pub(crate) fn sum(&self, layout: &Layout, s: &[usize]) -> Result { + pub(crate) fn reduce_op( + &self, + op: crate::op::ReduceOp, + layout: &Layout, + s: &[usize], + ) -> Result { match self { Storage::Cpu(storage) => { - let storage = storage.sum(layout, s)?; + let storage = storage.reduce_op(op, layout, s)?; Ok(Self::Cpu(storage)) } Self::Cuda(storage) => { - let storage = storage.sum(layout, s)?; + let storage = storage.reduce_op(op, layout, s)?; Ok(Self::Cuda(storage)) } } diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 
a93514fc..276a522e 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1,6 +1,7 @@
 use crate::backend::{BackendDevice, BackendStorage};
+use crate::op::{Op, ReduceOp};
 use crate::shape::{Dim, Dims};
-use crate::{op::Op, storage::Storage, DType, Device, Error, Layout, Result, Shape};
+use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
 use std::sync::{Arc, RwLock};
 
 /// Unique identifier for tensors.
@@ -154,8 +155,14 @@ impl Tensor {
         device: &Device,
         is_variable: bool,
     ) -> Result<Self> {
-        let storage = device.ones(&crate::shape::SCALAR, dtype)?;
-        from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
+        if is_variable {
+            let shape = shape.into();
+            let storage = device.ones(&shape, dtype)?;
+            Ok(from_storage(storage, shape, None, is_variable))
+        } else {
+            let storage = device.ones(&crate::shape::SCALAR, dtype)?;
+            from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
+        }
     }
 
     /// Creates a new tensor filled with ones.
@@ -192,8 +199,14 @@ impl Tensor {
         device: &Device,
         is_variable: bool,
     ) -> Result<Self> {
-        let storage = device.zeros(&crate::shape::SCALAR, dtype)?;
-        from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
+        if is_variable {
+            let shape = shape.into();
+            let storage = device.zeros(&shape, dtype)?;
+            Ok(from_storage(storage, shape, None, is_variable))
+        } else {
+            let storage = device.zeros(&crate::shape::SCALAR, dtype)?;
+            from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
+        }
     }
 
     /// Creates a new tensor filled with zeros.
@@ -593,9 +606,77 @@ impl Tensor {
         }
     }
 
-    pub fn sum_impl<D: Dims>(&self, sum_dims: D, keepdim: bool) -> Result<Self> {
+    fn squeeze_dims(self, dims: &[usize]) -> Result<Self> {
+        match dims {
+            [] => Ok(self),
+            [i] => self.squeeze(*i),
+            dims => {
+                let dims = self
+                    .dims()
+                    .iter()
+                    .enumerate()
+                    .filter_map(|(dim_idx, &v)| {
+                        if dims.contains(&dim_idx) {
+                            None
+                        } else {
+                            Some(v)
+                        }
+                    })
+                    .collect::<Vec<_>>();
+                self.reshape(dims)
+            }
+        }
+    }
+
+    fn max_impl<D: Dims>(&self, max_dims: D, keepdim: bool) -> Result<Self> {
+        let max_dims = max_dims.to_indexes(self.shape(), "max")?;
+        let storage = self
+            .storage()
+            .reduce_op(ReduceOp::Max, self.layout(), &max_dims)?;
+        let op = if self.track_op() {
+            Some(Op::Max(self.clone(), max_dims.to_vec()))
+        } else {
+            None
+        };
+        let mut dims = self.dims().to_vec();
+        for &max_dim in max_dims.iter() {
+            dims[max_dim] = 1
+        }
+        let max = from_storage(storage, dims, op, false);
+        if keepdim {
+            Ok(max)
+        } else {
+            max.squeeze_dims(&max_dims)
+        }
+    }
+
+    fn min_impl<D: Dims>(&self, min_dims: D, keepdim: bool) -> Result<Self> {
+        let min_dims = min_dims.to_indexes(self.shape(), "min")?;
+        let storage = self
+            .storage()
+            .reduce_op(ReduceOp::Min, self.layout(), &min_dims)?;
+        let op = if self.track_op() {
+            Some(Op::Min(self.clone(), min_dims.to_vec()))
+        } else {
+            None
+        };
+        let mut dims = self.dims().to_vec();
+        for &min_dim in min_dims.iter() {
+            dims[min_dim] = 1
+        }
+        let min = from_storage(storage, dims, op, false);
+        if keepdim {
+            Ok(min)
+        } else {
+            min.squeeze_dims(&min_dims)
+        }
+    }
+
+    fn sum_impl<D: Dims>(&self, sum_dims: D, keepdim: bool) -> Result<Self> {
         let sum_dims = sum_dims.to_indexes(self.shape(), "sum")?;
-        let storage = self.storage().sum(self.layout(), &sum_dims)?;
+        let storage = self
+            .storage()
+            .reduce_op(ReduceOp::Sum, self.layout(), &sum_dims)?;
         let op = if self.track_op() {
             Some(Op::Sum(self.clone(), sum_dims.to_vec()))
         } else {
@@ -609,25 +690,7 @@ impl Tensor {
         if keepdim {
             Ok(sum)
         } else {
-            match sum_dims.as_slice() {
-                [] => Ok(sum),
-                [i] => sum.squeeze(*i),
-                sum_dims => {
-                    let dims = sum
-                        .dims()
-                        .iter()
-                        .enumerate()
-                        .filter_map(|(dim_idx, &v)| {
-                            if sum_dims.contains(&dim_idx) {
-                                None
-                            } else {
-                                Some(v)
-                            }
-                        })
-                        .collect::<Vec<_>>();
-                    sum.reshape(dims)
-                }
-            }
+            sum.squeeze_dims(&sum_dims)
         }
     }
 
@@ -659,6 +722,32 @@ impl Tensor {
         self.sum_impl(sum_dims, false)
     }
 
+    pub fn max_keepdim<D: Dims>(&self, max_dims: D) -> Result<Self> {
+        self.max_impl(max_dims, true)
+    }
+
+    pub fn max<D: Dims>(&self, max_dims: D) -> Result<Self> {
+        self.max_impl(max_dims, false)
+    }
+
+    pub fn max_all(&self) -> Result<Self> {
+        let dims: Vec<_> = (0..self.rank()).collect();
+        self.max(dims)
+    }
+
+    pub fn min_keepdim<D: Dims>(&self, min_dims: D) -> Result<Self> {
+        self.min_impl(min_dims, true)
+    }
+
+    pub fn min<D: Dims>(&self, min_dims: D) -> Result<Self> {
+        self.min_impl(min_dims, false)
+    }
+
+    pub fn min_all(&self) -> Result<Self> {
+        let dims: Vec<_> = (0..self.rank()).collect();
+        self.min(dims)
+    }
+
     /// Applies a 1D convolution over the input tensor.
     pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
         let (c_out, c_in_k, k_size) = kernel.shape().r3()?;
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index 114997b9..24435e81 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -21,7 +21,7 @@ intel-mkl-src = { workspace = true, optional = true }
 
 [dev-dependencies]
 anyhow = { workspace = true }
-candle-hub = { path = "../candle-hub" }
+hf-hub = { workspace = true }
 clap = { workspace = true }
 rand = { workspace = true }
 tokenizers = { workspace = true, features = ["onig"] }
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index 8ef8b5ce..33f0a1fe 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -4,9 +4,9 @@ mod model;
 
 use anyhow::{anyhow, Error as E, Result};
 use candle::Tensor;
-use candle_hub::{api::sync::Api, Cache, Repo, RepoType};
 use candle_nn::VarBuilder;
 use clap::Parser;
+use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
 use model::{BertModel, Config, DTYPE};
 use tokenizers::{PaddingParams, Tokenizer};
 
diff --git a/candle-examples/examples/falcon/main.rs b/candle-examples/examples/falcon/main.rs
index 7d5eaa52..3a284c86 100644
--- a/candle-examples/examples/falcon/main.rs
+++ b/candle-examples/examples/falcon/main.rs
@@ -5,10 +5,10 @@ extern crate intel_mkl_src;
 
 use anyhow::{Error as E, Result};
 use candle::{DType, Device, Tensor};
-use candle_hub::{api::sync::Api, Repo, RepoType};
 use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
 use clap::Parser;
+use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::Tokenizer;
 
 mod model;
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index aa02299d..40f1af06 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -16,9 +16,9 @@ use anyhow::{Error as E, Result};
 use clap::Parser;
 
 use candle::{DType, Device, Tensor, D};
-use candle_hub::{api::sync::Api, Repo, RepoType};
 use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
+use hf_hub::{api::sync::Api, Repo, RepoType};
 
 mod model;
 use model::{Config, Llama};
diff --git a/candle-examples/examples/simple-training/main.rs b/candle-examples/examples/simple-training/main.rs
new file mode 100644
index 00000000..df67f741
--- /dev/null
+++ b/candle-examples/examples/simple-training/main.rs
@@ -0,0 +1,44 @@
+// This should reach 91.5% accuracy.
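Hypothetical usage of the reductions added above, assuming `Vec<usize>` satisfies the `Dims` bound exactly as `max_all` itself relies on. The plain variants squeeze the reduced axes via `squeeze_dims`, the `_keepdim` variants keep them at size 1, and backprop through min/max currently returns `BackwardNotSupported`, so this is forward-only:

```rust
use candle::{Device, Tensor};

fn main() -> candle::Result<()> {
    let t = Tensor::new(&[[3f32, 1., 4.], [1., 5., 9.]], &Device::Cpu)?;
    let row_max = t.max(vec![1])?; // reduced axis squeezed away: shape [2]
    let row_max_keep = t.max_keepdim(vec![1])?; // kept at size 1: shape [2, 1]
    let global_min = t.min_all()?; // reduces over every dimension
    println!("{row_max:?} {row_max_keep:?} {global_min:?}");
    Ok(())
}
```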
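The training example starting here (continued below) relies on variables owning dense storage, which is why `zeros_impl` above now materializes the full shape when `is_variable` is set: SGD updates each weight in place, and a broadcast scalar shares a single storage element. Its test accuracy is left as a TODO pending `argmax`/`eq` tensor ops; a host-side sketch of the intended computation, on plain vectors rather than tensors:

```rust
// Argmax each row of logits, compare with the label, average the hits.
fn accuracy(logits: &[Vec<f32>], labels: &[u32]) -> f64 {
    let correct = logits
        .iter()
        .zip(labels.iter())
        .filter(|(row, label)| {
            let argmax = row
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                .map(|(i, _)| i as u32);
            argmax == Some(**label)
        })
        .count();
    correct as f64 / labels.len() as f64
}

fn main() {
    let logits = vec![vec![0.1, 0.9], vec![0.8, 0.2], vec![0.3, 0.7]];
    let labels = [1u32, 0, 0];
    // Rows 0 and 1 are classified correctly, row 2 is not: 2/3.
    assert!((accuracy(&logits, &labels) - 2. / 3.).abs() < 1e-9);
}
```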
+#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +use anyhow::Result; +use candle::{DType, Var, D}; + +const IMAGE_DIM: usize = 784; +const LABELS: usize = 10; + +pub fn main() -> Result<()> { + let dev = candle::Device::cuda_if_available(0)?; + let m = candle_nn::vision::mnist::load_dir("data")?; + println!("train-images: {:?}", m.train_images.shape()); + println!("train-labels: {:?}", m.train_labels.shape()); + println!("test-images: {:?}", m.test_images.shape()); + println!("test-labels: {:?}", m.test_labels.shape()); + let ws = Var::zeros((IMAGE_DIM, LABELS), DType::F32, &dev)?; + let bs = Var::zeros(LABELS, DType::F32, &dev)?; + let sgd = candle_nn::SGD::new(&[&ws, &bs], 0.1); + for epoch in 1..200 { + let logits = m.train_images.matmul(&ws)?.broadcast_add(&bs)?; + let loss = logits.softmax(D::Minus1)?; + // TODO: log_softmax + let loss = loss.nll_loss(&m.train_labels); + sgd.backward_step(&loss)?; + + let _test_logits = m.test_images.matmul(&ws)?.broadcast_add(&bs)?; + /* TODO + let test_accuracy = test_logits + .argmax(Some(-1), false) + .eq_tensor(&m.test_labels) + .to_kind(Kind::Float) + .mean(Kind::Float) + .double_value(&[]); + */ + let test_accuracy = 0.; + println!( + "{epoch:4} train loss: {:8.5} test acc: {:5.2}%", + loss.to_scalar::()?, + 100. * test_accuracy + ) + } + Ok(()) +} diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index c8e42c72..d7b303cf 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -11,9 +11,9 @@ extern crate intel_mkl_src; use anyhow::{Error as E, Result}; use candle::{safetensors::Load, DType, Device, Tensor}; -use candle_hub::{api::sync::Api, Repo, RepoType}; use candle_nn::VarBuilder; use clap::Parser; +use hf_hub::{api::sync::Api, Repo, RepoType}; use rand::{distributions::Distribution, SeedableRng}; use tokenizers::Tokenizer; diff --git a/candle-hub/.gitignore b/candle-hub/.gitignore deleted file mode 100644 index 4fffb2f8..00000000 --- a/candle-hub/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -/target -/Cargo.lock diff --git a/candle-hub/Cargo.toml b/candle-hub/Cargo.toml deleted file mode 100644 index 2b091642..00000000 --- a/candle-hub/Cargo.toml +++ /dev/null @@ -1,29 +0,0 @@ -[package] -name = "candle-hub" -version = "0.1.0" -edition = "2021" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -dirs = "5.0.1" -rand = { workspace = true } -thiserror = { workspace = true } -futures = { workspace = true, optional = true } -reqwest = { workspace = true, optional = true, features = ["json"] } -tokio = { workspace = true, features = ["fs"], optional = true } -serde = { workspace = true, optional = true } -serde_json = { workspace = true, optional = true } -indicatif = { workspace = true, optional = true } -num_cpus = { workspace = true, optional = true } - -[dev-dependencies] -rand = { workspace = true } -sha256 = { workspace = true } -tokio = { workspace = true, features = ["macros"] } -tokio-test = { workspace = true } - -[features] -default = ["online"] -online = ["reqwest/blocking", "dep:serde", "dep:serde_json", "dep:indicatif", "dep:num_cpus"] -tokio = ["online", "dep:tokio", "dep:futures"] diff --git a/candle-hub/src/api/mod.rs b/candle-hub/src/api/mod.rs deleted file mode 100644 index 779dc4f9..00000000 --- a/candle-hub/src/api/mod.rs +++ /dev/null @@ -1,6 +0,0 @@ -/// The asynchronous version of the API -#[cfg(feature = "tokio")] -pub mod tokio; - -/// The synchronous 
version of the API -pub mod sync; diff --git a/candle-hub/src/api/sync.rs b/candle-hub/src/api/sync.rs deleted file mode 100644 index 6efdbff8..00000000 --- a/candle-hub/src/api/sync.rs +++ /dev/null @@ -1,686 +0,0 @@ -use crate::{Cache, Repo}; -use indicatif::{ProgressBar, ProgressStyle}; -use rand::{distributions::Alphanumeric, thread_rng, Rng}; -use reqwest::{ - blocking::Client, - header::{ - HeaderMap, HeaderName, HeaderValue, InvalidHeaderValue, ToStrError, AUTHORIZATION, - CONTENT_RANGE, LOCATION, RANGE, USER_AGENT, - }, - redirect::Policy, - Error as ReqwestError, -}; -use serde::Deserialize; -use std::io::{Seek, SeekFrom, Write}; -use std::num::ParseIntError; -use std::path::{Component, Path, PathBuf}; -use thiserror::Error; - -/// Current version (used in user-agent) -const VERSION: &str = env!("CARGO_PKG_VERSION"); -/// Current name (used in user-agent) -const NAME: &str = env!("CARGO_PKG_NAME"); - -#[derive(Debug, Error)] -/// All errors the API can throw -pub enum ApiError { - /// Api expects certain header to be present in the results to derive some information - #[error("Header {0} is missing")] - MissingHeader(HeaderName), - - /// The header exists, but the value is not conform to what the Api expects. - #[error("Header {0} is invalid")] - InvalidHeader(HeaderName), - - /// The value cannot be used as a header during request header construction - #[error("Invalid header value {0}")] - InvalidHeaderValue(#[from] InvalidHeaderValue), - - /// The header value is not valid utf-8 - #[error("header value is not a string")] - ToStr(#[from] ToStrError), - - /// Error in the request - #[error("request error: {0}")] - RequestError(#[from] ReqwestError), - - /// Error parsing some range value - #[error("Cannot parse int")] - ParseIntError(#[from] ParseIntError), - - /// I/O Error - #[error("I/O error {0}")] - IoError(#[from] std::io::Error), - - /// We tried to download chunk too many times - #[error("Too many retries: {0}")] - TooManyRetries(Box), -} - -/// Siblings are simplified file descriptions of remote files on the hub -#[derive(Debug, Clone, Deserialize, PartialEq)] -pub struct Siblings { - /// The path within the repo. - pub rfilename: String, -} - -/// The description of the repo given by the hub -#[derive(Debug, Clone, Deserialize, PartialEq)] -pub struct ModelInfo { - /// See [`Siblings`] - pub siblings: Vec, -} - -/// Helper to create [`Api`] with all the options. 
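The `ApiError` enum deleted above follows the usual `thiserror` pattern: one `#[error(...)]` display string per variant, and `#[from]` conversions so `?` lifts underlying errors automatically. A minimal standalone reduction of it (two toy variants only, not the full set):

```rust
use thiserror::Error;

#[derive(Debug, Error)]
enum ApiError {
    // The display string interpolates the variant's field.
    #[error("Header {0} is missing")]
    MissingHeader(String),

    // #[from] generates From<std::io::Error>, so `?` converts automatically.
    #[error("I/O error {0}")]
    IoError(#[from] std::io::Error),
}

fn read_token() -> Result<String, ApiError> {
    // The io::Error from a missing file converts into ApiError via `?`.
    let token = std::fs::read_to_string("/nonexistent/.token")?;
    Ok(token)
}

fn main() {
    let missing = ApiError::MissingHeader("etag".to_string());
    println!("{missing}");
    match read_token() {
        Ok(token) => println!("token: {token}"),
        Err(err) => println!("failed as expected: {err}"),
    }
}
```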
-pub struct ApiBuilder { - endpoint: String, - cache: Cache, - url_template: String, - token: Option, - chunk_size: usize, - parallel_failures: usize, - max_retries: usize, - progress: bool, -} - -impl Default for ApiBuilder { - fn default() -> Self { - Self::new() - } -} - -impl ApiBuilder { - /// Default api builder - /// ``` - /// use candle_hub::api::sync::ApiBuilder; - /// let api = ApiBuilder::new().build().unwrap(); - /// ``` - pub fn new() -> Self { - let cache = Cache::default(); - let mut token_filename = cache.path().clone(); - token_filename.push(".token"); - let token = match std::fs::read_to_string(token_filename) { - Ok(token_content) => { - let token_content = token_content.trim(); - if !token_content.is_empty() { - Some(token_content.to_string()) - } else { - None - } - } - Err(_) => None, - }; - - let progress = true; - - Self { - endpoint: "https://huggingface.co".to_string(), - url_template: "{endpoint}/{repo_id}/resolve/{revision}/{filename}".to_string(), - cache, - token, - chunk_size: 10_000_000, - parallel_failures: 0, - max_retries: 0, - progress, - } - } - - /// Wether to show a progressbar - pub fn with_progress(mut self, progress: bool) -> Self { - self.progress = progress; - self - } - - /// Changes the location of the cache directory. Defaults is `~/.cache/huggingface/`. - pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self { - self.cache = Cache::new(cache_dir); - self - } - - fn build_headers(&self) -> Result { - let mut headers = HeaderMap::new(); - let user_agent = format!("unkown/None; {NAME}/{VERSION}; rust/unknown"); - headers.insert(USER_AGENT, HeaderValue::from_str(&user_agent)?); - if let Some(token) = &self.token { - headers.insert( - AUTHORIZATION, - HeaderValue::from_str(&format!("Bearer {token}"))?, - ); - } - Ok(headers) - } - - /// Consumes the builder and buids the final [`Api`] - pub fn build(self) -> Result { - let headers = self.build_headers()?; - let client = Client::builder().default_headers(headers.clone()).build()?; - let no_redirect_client = Client::builder() - .redirect(Policy::none()) - .default_headers(headers) - .build()?; - Ok(Api { - endpoint: self.endpoint, - url_template: self.url_template, - cache: self.cache, - client, - - no_redirect_client, - chunk_size: self.chunk_size, - parallel_failures: self.parallel_failures, - max_retries: self.max_retries, - progress: self.progress, - }) - } -} - -#[derive(Debug)] -struct Metadata { - commit_hash: String, - etag: String, - size: usize, -} - -/// The actual Api used to interacto with the hub. 
-/// You can inspect repos with [`Api::info`] -/// or download files with [`Api::download`] -pub struct Api { - endpoint: String, - url_template: String, - cache: Cache, - client: Client, - no_redirect_client: Client, - chunk_size: usize, - parallel_failures: usize, - max_retries: usize, - progress: bool, -} - -fn temp_filename() -> PathBuf { - let s: String = rand::thread_rng() - .sample_iter(&Alphanumeric) - .take(7) - .map(char::from) - .collect(); - let mut path = std::env::temp_dir(); - path.push(s); - path -} - -fn make_relative(src: &Path, dst: &Path) -> PathBuf { - let path = src; - let base = dst; - - if path.is_absolute() != base.is_absolute() { - panic!("This function is made to look at absolute paths only"); - } - let mut ita = path.components(); - let mut itb = base.components(); - - loop { - match (ita.next(), itb.next()) { - (Some(a), Some(b)) if a == b => (), - (some_a, _) => { - // Ignoring b, because 1 component is the filename - // for which we don't need to go back up for relative - // filename to work. - let mut new_path = PathBuf::new(); - for _ in itb { - new_path.push(Component::ParentDir); - } - if let Some(a) = some_a { - new_path.push(a); - for comp in ita { - new_path.push(comp); - } - } - return new_path; - } - } - } -} - -fn symlink_or_rename(src: &Path, dst: &Path) -> Result<(), std::io::Error> { - if dst.exists() { - return Ok(()); - } - - let src = make_relative(src, dst); - #[cfg(target_os = "windows")] - std::os::windows::fs::symlink_file(src, dst)?; - - #[cfg(target_family = "unix")] - std::os::unix::fs::symlink(src, dst)?; - - #[cfg(not(any(target_family = "unix", target_os = "windows")))] - std::fs::rename(src, dst)?; - - Ok(()) -} - -fn jitter() -> usize { - thread_rng().gen_range(0..=500) -} - -fn exponential_backoff(base_wait_time: usize, n: usize, max: usize) -> usize { - (base_wait_time + n.pow(2) + jitter()).min(max) -} - -impl Api { - /// Creates a default Api, for Api options See [`ApiBuilder`] - pub fn new() -> Result { - ApiBuilder::new().build() - } - - /// Get the fully qualified URL of the remote filename - /// ``` - /// # use candle_hub::{api::sync::Api, Repo}; - /// let api = Api::new().unwrap(); - /// let repo = Repo::model("gpt2".to_string()); - /// let url = api.url(&repo, "model.safetensors"); - /// assert_eq!(url, "https://huggingface.co/gpt2/resolve/main/model.safetensors"); - /// ``` - pub fn url(&self, repo: &Repo, filename: &str) -> String { - let endpoint = &self.endpoint; - let revision = &repo.url_revision(); - self.url_template - .replace("{endpoint}", endpoint) - .replace("{repo_id}", &repo.url()) - .replace("{revision}", revision) - .replace("{filename}", filename) - } - - /// Get the underlying api client - /// Allows for lower level access - pub fn client(&self) -> &Client { - &self.client - } - - fn metadata(&self, url: &str) -> Result { - let response = self - .no_redirect_client - .get(url) - .header(RANGE, "bytes=0-0") - .send()?; - let response = response.error_for_status()?; - let headers = response.headers(); - let header_commit = HeaderName::from_static("x-repo-commit"); - let header_linked_etag = HeaderName::from_static("x-linked-etag"); - let header_etag = HeaderName::from_static("etag"); - - let etag = match headers.get(&header_linked_etag) { - Some(etag) => etag, - None => headers - .get(&header_etag) - .ok_or(ApiError::MissingHeader(header_etag))?, - }; - // Cleaning extra quotes - let etag = etag.to_str()?.to_string().replace('"', ""); - let commit_hash = headers - .get(&header_commit) - 
.ok_or(ApiError::MissingHeader(header_commit))? - .to_str()? - .to_string(); - - // The response was redirected o S3 most likely which will - // know about the size of the file - let response = if response.status().is_redirection() { - self.client - .get(headers.get(LOCATION).unwrap().to_str()?.to_string()) - .header(RANGE, "bytes=0-0") - .send()? - } else { - response - }; - let headers = response.headers(); - let content_range = headers - .get(CONTENT_RANGE) - .ok_or(ApiError::MissingHeader(CONTENT_RANGE))? - .to_str()?; - - let size = content_range - .split('/') - .last() - .ok_or(ApiError::InvalidHeader(CONTENT_RANGE))? - .parse()?; - Ok(Metadata { - commit_hash, - etag, - size, - }) - } - - fn download_tempfile( - &self, - url: &str, - length: usize, - progressbar: Option, - ) -> Result { - let filename = temp_filename(); - - // Create the file and set everything properly - std::fs::File::create(&filename)?.set_len(length as u64)?; - - let chunk_size = self.chunk_size; - - let n_chunks = (length + chunk_size - 1) / chunk_size; - let n_threads = num_cpus::get(); - let chunks_per_thread = (n_chunks + n_threads - 1) / n_threads; - let handles = (0..n_threads).map(|thread_id| { - let url = url.to_string(); - let filename = filename.clone(); - let client = self.client.clone(); - let parallel_failures = self.parallel_failures; - let max_retries = self.max_retries; - let progress = progressbar.clone(); - std::thread::spawn(move || { - for chunk_id in chunks_per_thread * thread_id - ..std::cmp::min(chunks_per_thread * (thread_id + 1), n_chunks) - { - let start = chunk_id * chunk_size; - let stop = std::cmp::min(start + chunk_size - 1, length); - let mut chunk = Self::download_chunk(&client, &url, &filename, start, stop); - let mut i = 0; - if parallel_failures > 0 { - while let Err(dlerr) = chunk { - let wait_time = exponential_backoff(300, i, 10_000); - std::thread::sleep(std::time::Duration::from_millis(wait_time as u64)); - - chunk = Self::download_chunk(&client, &url, &filename, start, stop); - i += 1; - if i > max_retries { - return Err(ApiError::TooManyRetries(dlerr.into())); - } - } - } - if let Some(p) = &progress { - p.inc((stop - start) as u64); - } - chunk? - } - Ok(()) - }) - }); - - let results: Result, ApiError> = - handles.into_iter().flat_map(|h| h.join()).collect(); - - results?; - if let Some(p) = progressbar { - p.finish() - } - Ok(filename) - } - - fn download_chunk( - client: &Client, - url: &str, - filename: &PathBuf, - start: usize, - stop: usize, - ) -> Result<(), ApiError> { - // Process each socket concurrently. - let range = format!("bytes={start}-{stop}"); - let mut file = std::fs::OpenOptions::new().write(true).open(filename)?; - file.seek(SeekFrom::Start(start as u64))?; - let response = client - .get(url) - .header(RANGE, range) - .send()? - .error_for_status()?; - let content = response.bytes()?; - file.write_all(&content)?; - Ok(()) - } - - /// This will attempt the fetch the file locally first, then [`Api.download`] - /// if the file is not present. 
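Worth noting how `metadata` above discovers the file size without downloading anything: it requests a single byte (`Range: bytes=0-0`), follows the redirect, and parses the total that follows the `/` in the `Content-Range` header, e.g. `bytes 0-0/498373`. A standalone version of that parse (the sample value is made up):

```rust
use std::num::ParseIntError;

fn size_from_content_range(content_range: &str) -> Result<usize, ParseIntError> {
    // Everything after the final '/' is the total size in bytes.
    content_range.split('/').last().unwrap_or("").parse()
}

fn main() {
    assert_eq!(size_from_content_range("bytes 0-0/498373"), Ok(498373));
    assert!(size_from_content_range("not-a-range").is_err());
}
```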
- /// ```no_run - /// use candle_hub::{api::sync::ApiBuilder, Repo}; - /// let api = ApiBuilder::new().build().unwrap(); - /// let repo = Repo::model("gpt2".to_string()); - /// let local_filename = api.get(&repo, "model.safetensors").unwrap(); - pub fn get(&self, repo: &Repo, filename: &str) -> Result { - if let Some(path) = self.cache.get(repo, filename) { - Ok(path) - } else { - self.download(repo, filename) - } - } - - /// Downloads a remote file (if not already present) into the cache directory - /// to be used locally. - /// This functions require internet access to verify if new versions of the file - /// exist, even if a file is already on disk at location. - /// ```no_run - /// # use candle_hub::{api::sync::ApiBuilder, Repo}; - /// let api = ApiBuilder::new().build().unwrap(); - /// let repo = Repo::model("gpt2".to_string()); - /// let local_filename = api.download(&repo, "model.safetensors").unwrap(); - /// ``` - pub fn download(&self, repo: &Repo, filename: &str) -> Result { - let url = self.url(repo, filename); - let metadata = self.metadata(&url)?; - - let blob_path = self.cache.blob_path(repo, &metadata.etag); - std::fs::create_dir_all(blob_path.parent().unwrap())?; - - let progressbar = if self.progress { - let progress = ProgressBar::new(metadata.size as u64); - progress.set_style( - ProgressStyle::with_template( - "{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec} ({eta})", - ) - .unwrap(), // .progress_chars("━ "), - ); - let maxlength = 30; - let message = if filename.len() > maxlength { - format!("..{}", &filename[filename.len() - maxlength..]) - } else { - filename.to_string() - }; - progress.set_message(message); - Some(progress) - } else { - None - }; - - let tmp_filename = self.download_tempfile(&url, metadata.size, progressbar)?; - - if std::fs::rename(&tmp_filename, &blob_path).is_err() { - // Renaming may fail if locations are different mount points - std::fs::File::create(&blob_path)?; - std::fs::copy(tmp_filename, &blob_path)?; - } - - let mut pointer_path = self.cache.pointer_path(repo, &metadata.commit_hash); - pointer_path.push(filename); - std::fs::create_dir_all(pointer_path.parent().unwrap()).ok(); - - symlink_or_rename(&blob_path, &pointer_path)?; - self.cache.create_ref(repo, &metadata.commit_hash)?; - - Ok(pointer_path) - } - - /// Get information about the Repo - /// ``` - /// use candle_hub::{api::sync::Api, Repo}; - /// let api = Api::new().unwrap(); - /// let repo = Repo::model("gpt2".to_string()); - /// api.info(&repo); - /// ``` - pub fn info(&self, repo: &Repo) -> Result { - let url = format!("{}/api/{}", self.endpoint, repo.api_url()); - let response = self.client.get(url).send()?; - let response = response.error_for_status()?; - - let model_info = response.json()?; - - Ok(model_info) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::RepoType; - use rand::{distributions::Alphanumeric, Rng}; - use sha256::try_digest; - - struct TempDir { - path: PathBuf, - } - - impl TempDir { - pub fn new() -> Self { - let s: String = rand::thread_rng() - .sample_iter(&Alphanumeric) - .take(7) - .map(char::from) - .collect(); - let mut path = std::env::temp_dir(); - path.push(s); - std::fs::create_dir(&path).unwrap(); - Self { path } - } - } - - impl Drop for TempDir { - fn drop(&mut self) { - std::fs::remove_dir_all(&self.path).unwrap() - } - } - - #[test] - fn simple() { - let tmp = TempDir::new(); - let api = ApiBuilder::new() - .with_progress(false) - .with_cache_dir(tmp.path.clone()) - .build() - .unwrap(); - let 
repo = Repo::new("julien-c/dummy-unknown".to_string(), RepoType::Model); - let downloaded_path = api.download(&repo, "config.json").unwrap(); - assert!(downloaded_path.exists()); - let val = try_digest(&*downloaded_path).unwrap(); - assert_eq!( - val, - "b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32" - ); - - // Make sure the file is now seeable without connection - let cache_path = api.cache.get(&repo, "config.json").unwrap(); - assert_eq!(cache_path, downloaded_path); - } - - #[test] - fn dataset() { - let tmp = TempDir::new(); - let api = ApiBuilder::new() - .with_progress(false) - .with_cache_dir(tmp.path.clone()) - .build() - .unwrap(); - let repo = Repo::with_revision( - "wikitext".to_string(), - RepoType::Dataset, - "refs/convert/parquet".to_string(), - ); - let downloaded_path = api - .download(&repo, "wikitext-103-v1/wikitext-test.parquet") - .unwrap(); - assert!(downloaded_path.exists()); - let val = try_digest(&*downloaded_path).unwrap(); - assert_eq!( - val, - "59ce09415ad8aa45a9e34f88cec2548aeb9de9a73fcda9f6b33a86a065f32b90" - ) - } - - #[test] - fn info() { - let tmp = TempDir::new(); - let api = ApiBuilder::new() - .with_progress(false) - .with_cache_dir(tmp.path.clone()) - .build() - .unwrap(); - let repo = Repo::with_revision( - "wikitext".to_string(), - RepoType::Dataset, - "refs/convert/parquet".to_string(), - ); - let model_info = api.info(&repo).unwrap(); - assert_eq!( - model_info, - ModelInfo { - siblings: vec![ - Siblings { - rfilename: ".gitattributes".to_string() - }, - Siblings { - rfilename: "wikitext-103-raw-v1/wikitext-test.parquet".to_string() - }, - Siblings { - rfilename: "wikitext-103-raw-v1/wikitext-train-00000-of-00002.parquet" - .to_string() - }, - Siblings { - rfilename: "wikitext-103-raw-v1/wikitext-train-00001-of-00002.parquet" - .to_string() - }, - Siblings { - rfilename: "wikitext-103-raw-v1/wikitext-validation.parquet".to_string() - }, - Siblings { - rfilename: "wikitext-103-v1/test/index.duckdb".to_string() - }, - Siblings { - rfilename: "wikitext-103-v1/validation/index.duckdb".to_string() - }, - Siblings { - rfilename: "wikitext-103-v1/wikitext-test.parquet".to_string() - }, - Siblings { - rfilename: "wikitext-103-v1/wikitext-train-00000-of-00002.parquet" - .to_string() - }, - Siblings { - rfilename: "wikitext-103-v1/wikitext-train-00001-of-00002.parquet" - .to_string() - }, - Siblings { - rfilename: "wikitext-103-v1/wikitext-validation.parquet".to_string() - }, - Siblings { - rfilename: "wikitext-2-raw-v1/test/index.duckdb".to_string() - }, - Siblings { - rfilename: "wikitext-2-raw-v1/train/index.duckdb".to_string() - }, - Siblings { - rfilename: "wikitext-2-raw-v1/validation/index.duckdb".to_string() - }, - Siblings { - rfilename: "wikitext-2-raw-v1/wikitext-test.parquet".to_string() - }, - Siblings { - rfilename: "wikitext-2-raw-v1/wikitext-train.parquet".to_string() - }, - Siblings { - rfilename: "wikitext-2-raw-v1/wikitext-validation.parquet".to_string() - }, - Siblings { - rfilename: "wikitext-2-v1/wikitext-test.parquet".to_string() - }, - Siblings { - rfilename: "wikitext-2-v1/wikitext-train.parquet".to_string() - }, - Siblings { - rfilename: "wikitext-2-v1/wikitext-validation.parquet".to_string() - } - ], - } - ) - } -} diff --git a/candle-hub/src/api/tokio.rs b/candle-hub/src/api/tokio.rs deleted file mode 100644 index dc8f682e..00000000 --- a/candle-hub/src/api/tokio.rs +++ /dev/null @@ -1,723 +0,0 @@ -use crate::{Cache, Repo}; -use indicatif::{ProgressBar, ProgressStyle}; -use rand::{distributions::Alphanumeric, 
thread_rng, Rng}; -use reqwest::{ - header::{ - HeaderMap, HeaderName, HeaderValue, InvalidHeaderValue, ToStrError, AUTHORIZATION, - CONTENT_RANGE, LOCATION, RANGE, USER_AGENT, - }, - redirect::Policy, - Client, Error as ReqwestError, -}; -use serde::Deserialize; -use std::num::ParseIntError; -use std::path::{Component, Path, PathBuf}; -use std::sync::Arc; -use thiserror::Error; -use tokio::io::{AsyncSeekExt, AsyncWriteExt, SeekFrom}; -use tokio::sync::{AcquireError, Semaphore, TryAcquireError}; - -/// Current version (used in user-agent) -const VERSION: &str = env!("CARGO_PKG_VERSION"); -/// Current name (used in user-agent) -const NAME: &str = env!("CARGO_PKG_NAME"); - -#[derive(Debug, Error)] -/// All errors the API can throw -pub enum ApiError { - /// Api expects certain header to be present in the results to derive some information - #[error("Header {0} is missing")] - MissingHeader(HeaderName), - - /// The header exists, but the value is not conform to what the Api expects. - #[error("Header {0} is invalid")] - InvalidHeader(HeaderName), - - /// The value cannot be used as a header during request header construction - #[error("Invalid header value {0}")] - InvalidHeaderValue(#[from] InvalidHeaderValue), - - /// The header value is not valid utf-8 - #[error("header value is not a string")] - ToStr(#[from] ToStrError), - - /// Error in the request - #[error("request error: {0}")] - RequestError(#[from] ReqwestError), - - /// Error parsing some range value - #[error("Cannot parse int")] - ParseIntError(#[from] ParseIntError), - - /// I/O Error - #[error("I/O error {0}")] - IoError(#[from] std::io::Error), - - /// We tried to download chunk too many times - #[error("Too many retries: {0}")] - TooManyRetries(Box), - - /// Semaphore cannot be acquired - #[error("Try acquire: {0}")] - TryAcquireError(#[from] TryAcquireError), - - /// Semaphore cannot be acquired - #[error("Acquire: {0}")] - AcquireError(#[from] AcquireError), - // /// Semaphore cannot be acquired - // #[error("Invalid Response: {0:?}")] - // InvalidResponse(Response), -} - -/// Siblings are simplified file descriptions of remote files on the hub -#[derive(Debug, Clone, Deserialize, PartialEq)] -pub struct Siblings { - /// The path within the repo. - pub rfilename: String, -} - -/// The description of the repo given by the hub -#[derive(Debug, Clone, Deserialize, PartialEq)] -pub struct ModelInfo { - /// See [`Siblings`] - pub siblings: Vec, -} - -/// Helper to create [`Api`] with all the options. 
-pub struct ApiBuilder { - endpoint: String, - cache: Cache, - url_template: String, - token: Option, - max_files: usize, - chunk_size: usize, - parallel_failures: usize, - max_retries: usize, - progress: bool, -} - -impl Default for ApiBuilder { - fn default() -> Self { - Self::new() - } -} - -impl ApiBuilder { - /// Default api builder - /// ``` - /// use candle_hub::api::tokio::ApiBuilder; - /// let api = ApiBuilder::new().build().unwrap(); - /// ``` - pub fn new() -> Self { - let cache = Cache::default(); - let mut token_filename = cache.path().clone(); - token_filename.push(".token"); - let token = match std::fs::read_to_string(token_filename) { - Ok(token_content) => { - let token_content = token_content.trim(); - if !token_content.is_empty() { - Some(token_content.to_string()) - } else { - None - } - } - Err(_) => None, - }; - - let progress = true; - - Self { - endpoint: "https://huggingface.co".to_string(), - url_template: "{endpoint}/{repo_id}/resolve/{revision}/{filename}".to_string(), - cache, - token, - max_files: num_cpus::get(), - chunk_size: 10_000_000, - parallel_failures: 0, - max_retries: 0, - progress, - } - } - - /// Wether to show a progressbar - pub fn with_progress(mut self, progress: bool) -> Self { - self.progress = progress; - self - } - - /// Changes the location of the cache directory. Defaults is `~/.cache/huggingface/`. - pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self { - self.cache = Cache::new(cache_dir); - self - } - - fn build_headers(&self) -> Result { - let mut headers = HeaderMap::new(); - let user_agent = format!("unkown/None; {NAME}/{VERSION}; rust/unknown"); - headers.insert(USER_AGENT, HeaderValue::from_str(&user_agent)?); - if let Some(token) = &self.token { - headers.insert( - AUTHORIZATION, - HeaderValue::from_str(&format!("Bearer {token}"))?, - ); - } - Ok(headers) - } - - /// Consumes the builder and buids the final [`Api`] - pub fn build(self) -> Result { - let headers = self.build_headers()?; - let client = Client::builder().default_headers(headers.clone()).build()?; - let no_redirect_client = Client::builder() - .redirect(Policy::none()) - .default_headers(headers) - .build()?; - Ok(Api { - endpoint: self.endpoint, - url_template: self.url_template, - cache: self.cache, - client, - - no_redirect_client, - max_files: self.max_files, - chunk_size: self.chunk_size, - parallel_failures: self.parallel_failures, - max_retries: self.max_retries, - progress: self.progress, - }) - } -} - -#[derive(Debug)] -struct Metadata { - commit_hash: String, - etag: String, - size: usize, -} - -/// The actual Api used to interacto with the hub. 
-/// You can inspect repos with [`Api::info`] -/// or download files with [`Api::download`] -pub struct Api { - endpoint: String, - url_template: String, - cache: Cache, - client: Client, - no_redirect_client: Client, - max_files: usize, - chunk_size: usize, - parallel_failures: usize, - max_retries: usize, - progress: bool, -} - -fn temp_filename() -> PathBuf { - let s: String = rand::thread_rng() - .sample_iter(&Alphanumeric) - .take(7) - .map(char::from) - .collect(); - let mut path = std::env::temp_dir(); - path.push(s); - path -} - -fn make_relative(src: &Path, dst: &Path) -> PathBuf { - let path = src; - let base = dst; - - if path.is_absolute() != base.is_absolute() { - panic!("This function is made to look at absolute paths only"); - } - let mut ita = path.components(); - let mut itb = base.components(); - - loop { - match (ita.next(), itb.next()) { - (Some(a), Some(b)) if a == b => (), - (some_a, _) => { - // Ignoring b, because 1 component is the filename - // for which we don't need to go back up for relative - // filename to work. - let mut new_path = PathBuf::new(); - for _ in itb { - new_path.push(Component::ParentDir); - } - if let Some(a) = some_a { - new_path.push(a); - for comp in ita { - new_path.push(comp); - } - } - return new_path; - } - } - } -} - -fn symlink_or_rename(src: &Path, dst: &Path) -> Result<(), std::io::Error> { - if dst.exists() { - return Ok(()); - } - - let src = make_relative(src, dst); - #[cfg(target_os = "windows")] - std::os::windows::fs::symlink_file(src, dst)?; - - #[cfg(target_family = "unix")] - std::os::unix::fs::symlink(src, dst)?; - - #[cfg(not(any(target_family = "unix", target_os = "windows")))] - std::fs::rename(src, dst)?; - - Ok(()) -} - -fn jitter() -> usize { - thread_rng().gen_range(0..=500) -} - -fn exponential_backoff(base_wait_time: usize, n: usize, max: usize) -> usize { - (base_wait_time + n.pow(2) + jitter()).min(max) -} - -impl Api { - /// Creates a default Api, for Api options See [`ApiBuilder`] - pub fn new() -> Result { - ApiBuilder::new().build() - } - - /// Get the fully qualified URL of the remote filename - /// ``` - /// # use candle_hub::{api::tokio::Api, Repo}; - /// let api = Api::new().unwrap(); - /// let repo = Repo::model("gpt2".to_string()); - /// let url = api.url(&repo, "model.safetensors"); - /// assert_eq!(url, "https://huggingface.co/gpt2/resolve/main/model.safetensors"); - /// ``` - pub fn url(&self, repo: &Repo, filename: &str) -> String { - let endpoint = &self.endpoint; - let revision = &repo.url_revision(); - self.url_template - .replace("{endpoint}", endpoint) - .replace("{repo_id}", &repo.url()) - .replace("{revision}", revision) - .replace("{filename}", filename) - } - - /// Get the underlying api client - /// Allows for lower level access - pub fn client(&self) -> &Client { - &self.client - } - - async fn metadata(&self, url: &str) -> Result { - let response = self - .no_redirect_client - .get(url) - .header(RANGE, "bytes=0-0") - .send() - .await?; - let response = response.error_for_status()?; - let headers = response.headers(); - let header_commit = HeaderName::from_static("x-repo-commit"); - let header_linked_etag = HeaderName::from_static("x-linked-etag"); - let header_etag = HeaderName::from_static("etag"); - - let etag = match headers.get(&header_linked_etag) { - Some(etag) => etag, - None => headers - .get(&header_etag) - .ok_or(ApiError::MissingHeader(header_etag))?, - }; - // Cleaning extra quotes - let etag = etag.to_str()?.to_string().replace('"', ""); - let commit_hash = headers - 
.get(&header_commit) - .ok_or(ApiError::MissingHeader(header_commit))? - .to_str()? - .to_string(); - - // The response was redirected o S3 most likely which will - // know about the size of the file - let response = if response.status().is_redirection() { - self.client - .get(headers.get(LOCATION).unwrap().to_str()?.to_string()) - .header(RANGE, "bytes=0-0") - .send() - .await? - } else { - response - }; - let headers = response.headers(); - let content_range = headers - .get(CONTENT_RANGE) - .ok_or(ApiError::MissingHeader(CONTENT_RANGE))? - .to_str()?; - - let size = content_range - .split('/') - .last() - .ok_or(ApiError::InvalidHeader(CONTENT_RANGE))? - .parse()?; - Ok(Metadata { - commit_hash, - etag, - size, - }) - } - - async fn download_tempfile( - &self, - url: &str, - length: usize, - progressbar: Option, - ) -> Result { - let mut handles = vec![]; - let semaphore = Arc::new(Semaphore::new(self.max_files)); - let parallel_failures_semaphore = Arc::new(Semaphore::new(self.parallel_failures)); - let filename = temp_filename(); - - // Create the file and set everything properly - tokio::fs::File::create(&filename) - .await? - .set_len(length as u64) - .await?; - - let chunk_size = self.chunk_size; - for start in (0..length).step_by(chunk_size) { - let url = url.to_string(); - let filename = filename.clone(); - let client = self.client.clone(); - - let stop = std::cmp::min(start + chunk_size - 1, length); - let permit = semaphore.clone().acquire_owned().await?; - let parallel_failures = self.parallel_failures; - let max_retries = self.max_retries; - let parallel_failures_semaphore = parallel_failures_semaphore.clone(); - let progress = progressbar.clone(); - handles.push(tokio::spawn(async move { - let mut chunk = Self::download_chunk(&client, &url, &filename, start, stop).await; - let mut i = 0; - if parallel_failures > 0 { - while let Err(dlerr) = chunk { - let parallel_failure_permit = - parallel_failures_semaphore.clone().try_acquire_owned()?; - - let wait_time = exponential_backoff(300, i, 10_000); - tokio::time::sleep(tokio::time::Duration::from_millis(wait_time as u64)) - .await; - - chunk = Self::download_chunk(&client, &url, &filename, start, stop).await; - i += 1; - if i > max_retries { - return Err(ApiError::TooManyRetries(dlerr.into())); - } - drop(parallel_failure_permit); - } - } - drop(permit); - if let Some(p) = progress { - p.inc((stop - start) as u64); - } - chunk - })); - } - - // Output the chained result - let results: Vec, tokio::task::JoinError>> = - futures::future::join_all(handles).await; - let results: Result<(), ApiError> = results.into_iter().flatten().collect(); - results?; - if let Some(p) = progressbar { - p.finish() - } - Ok(filename) - } - - async fn download_chunk( - client: &reqwest::Client, - url: &str, - filename: &PathBuf, - start: usize, - stop: usize, - ) -> Result<(), ApiError> { - // Process each socket concurrently. - let range = format!("bytes={start}-{stop}"); - let mut file = tokio::fs::OpenOptions::new() - .write(true) - .open(filename) - .await?; - file.seek(SeekFrom::Start(start as u64)).await?; - let response = client - .get(url) - .header(RANGE, range) - .send() - .await? - .error_for_status()?; - let content = response.bytes().await?; - file.write_all(&content).await?; - Ok(()) - } - - /// This will attempt the fetch the file locally first, then [`Api.download`] - /// if the file is not present. 
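The async downloader above bounds concurrency with a `tokio` semaphore: one permit per in-flight chunk, acquired before the task is spawned and dropped when the chunk finishes, so at most `max_files` range requests run at once; failures retry on the `exponential_backoff` schedule, which despite its name grows quadratically (`n.pow(2)`). A standalone sketch of the permit pattern, with sleeps standing in for range requests (assumes the `tokio` crate with its `full` feature set):

```rust
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Semaphore;

#[tokio::main]
async fn main() {
    let max_files = 4; // cap on in-flight chunks
    let semaphore = Arc::new(Semaphore::new(max_files));
    let mut handles = Vec::new();
    for chunk_id in 0..16u32 {
        // Waits here once `max_files` permits are out.
        let permit = semaphore.clone().acquire_owned().await.unwrap();
        handles.push(tokio::spawn(async move {
            // Stand-in for download_chunk(client, url, file, start, stop).
            tokio::time::sleep(Duration::from_millis(10)).await;
            drop(permit); // frees a slot for the next chunk
            chunk_id
        }));
    }
    for h in handles {
        h.await.unwrap();
    }
}
```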
-    /// ```no_run
-    /// # use candle_hub::{api::tokio::ApiBuilder, Repo};
-    /// # tokio_test::block_on(async {
-    /// let api = ApiBuilder::new().build().unwrap();
-    /// let repo = Repo::model("gpt2".to_string());
-    /// let local_filename = api.get(&repo, "model.safetensors").await.unwrap();
-    /// # })
-    /// ```
-    pub async fn get(&self, repo: &Repo, filename: &str) -> Result {
-        if let Some(path) = self.cache.get(repo, filename) {
-            Ok(path)
-        } else {
-            self.download(repo, filename).await
-        }
-    }
-
-    /// Downloads a remote file (if not already present) into the cache directory
-    /// to be used locally.
-    /// This function requires internet access to verify whether a newer version of
-    /// the file exists, even if a copy is already on disk at that location.
-    /// ```no_run
-    /// # use candle_hub::{api::tokio::ApiBuilder, Repo};
-    /// # tokio_test::block_on(async {
-    /// let api = ApiBuilder::new().build().unwrap();
-    /// let repo = Repo::model("gpt2".to_string());
-    /// let local_filename = api.download(&repo, "model.safetensors").await.unwrap();
-    /// # })
-    /// ```
-    pub async fn download(&self, repo: &Repo, filename: &str) -> Result {
-        let url = self.url(repo, filename);
-        let metadata = self.metadata(&url).await?;
-
-        let blob_path = self.cache.blob_path(repo, &metadata.etag);
-        std::fs::create_dir_all(blob_path.parent().unwrap())?;
-
-        let progressbar = if self.progress {
-            let progress = ProgressBar::new(metadata.size as u64);
-            progress.set_style(
-                ProgressStyle::with_template(
-                    "{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec} ({eta})",
-                )
-                .unwrap(), // .progress_chars("━ "),
-            );
-            let maxlength = 30;
-            let message = if filename.len() > maxlength {
-                format!("..{}", &filename[filename.len() - maxlength..])
-            } else {
-                filename.to_string()
-            };
-            progress.set_message(message);
-            Some(progress)
-        } else {
-            None
-        };
-
-        let tmp_filename = self
-            .download_tempfile(&url, metadata.size, progressbar)
-            .await?;
-
-        if tokio::fs::rename(&tmp_filename, &blob_path).await.is_err() {
-            // Renaming may fail if the two locations are on different mount points.
-            std::fs::File::create(&blob_path)?;
-            tokio::fs::copy(tmp_filename, &blob_path).await?;
-        }
-
-        let mut pointer_path = self.cache.pointer_path(repo, &metadata.commit_hash);
-        pointer_path.push(filename);
-        std::fs::create_dir_all(pointer_path.parent().unwrap()).ok();
-
-        symlink_or_rename(&blob_path, &pointer_path)?;
-        self.cache.create_ref(repo, &metadata.commit_hash)?;
-
-        Ok(pointer_path)
-    }
-
-    /// Get information about the Repo.
-    /// ```
-    /// # use candle_hub::{api::tokio::Api, Repo};
-    /// # tokio_test::block_on(async {
-    /// let api = Api::new().unwrap();
-    /// let repo = Repo::model("gpt2".to_string());
-    /// api.info(&repo);
-    /// # })
-    /// ```
-    pub async fn info(&self, repo: &Repo) -> Result {
-        let url = format!("{}/api/{}", self.endpoint, repo.api_url());
-        let response = self.client.get(url).send().await?;
-        let response = response.error_for_status()?;
-
-        let model_info = response.json().await?;
-
-        Ok(model_info)
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use crate::RepoType;
-    use rand::{distributions::Alphanumeric, Rng};
-    use sha256::try_digest;
-
-    struct TempDir {
-        path: PathBuf,
-    }
-
-    impl TempDir {
-        pub fn new() -> Self {
-            let s: String = rand::thread_rng()
-                .sample_iter(&Alphanumeric)
-                .take(7)
-                .map(char::from)
-                .collect();
-            let mut path = std::env::temp_dir();
-            path.push(s);
-            std::fs::create_dir(&path).unwrap();
-            Self { path }
-        }
-    }
-
-    impl Drop for TempDir {
-        fn drop(&mut self) {
-            std::fs::remove_dir_all(&self.path).unwrap()
-        }
-    }
-
-    #[tokio::test]
-    async fn simple() {
-        let tmp = TempDir::new();
-        let api = ApiBuilder::new()
-            .with_progress(false)
-            .with_cache_dir(tmp.path.clone())
-            .build()
-            .unwrap();
-        let repo = Repo::new("julien-c/dummy-unknown".to_string(), RepoType::Model);
-        let downloaded_path = api.download(&repo, "config.json").await.unwrap();
-        assert!(downloaded_path.exists());
-        let val = try_digest(&*downloaded_path).unwrap();
-        assert_eq!(
-            val,
-            "b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32"
-        );
-
-        // Make sure the file is now visible without a network connection.
-        let cache_path = api.cache.get(&repo, "config.json").unwrap();
-        assert_eq!(cache_path, downloaded_path);
-    }
-
-    #[tokio::test]
-    async fn dataset() {
-        let tmp = TempDir::new();
-        let api = ApiBuilder::new()
-            .with_progress(false)
-            .with_cache_dir(tmp.path.clone())
-            .build()
-            .unwrap();
-        let repo = Repo::with_revision(
-            "wikitext".to_string(),
-            RepoType::Dataset,
-            "refs/convert/parquet".to_string(),
-        );
-        let downloaded_path = api
-            .download(&repo, "wikitext-103-v1/wikitext-test.parquet")
-            .await
-            .unwrap();
-        assert!(downloaded_path.exists());
-        let val = try_digest(&*downloaded_path).unwrap();
-        assert_eq!(
-            val,
-            "59ce09415ad8aa45a9e34f88cec2548aeb9de9a73fcda9f6b33a86a065f32b90"
-        )
-    }
-
-    #[tokio::test]
-    async fn info() {
-        let tmp = TempDir::new();
-        let api = ApiBuilder::new()
-            .with_progress(false)
-            .with_cache_dir(tmp.path.clone())
-            .build()
-            .unwrap();
-        let repo = Repo::with_revision(
-            "wikitext".to_string(),
-            RepoType::Dataset,
-            "refs/convert/parquet".to_string(),
-        );
-        let model_info = api.info(&repo).await.unwrap();
-        assert_eq!(
-            model_info,
-            ModelInfo {
-                siblings: vec![
-                    Siblings {
-                        rfilename: ".gitattributes".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-103-raw-v1/wikitext-test.parquet".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-103-raw-v1/wikitext-train-00000-of-00002.parquet"
-                            .to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-103-raw-v1/wikitext-train-00001-of-00002.parquet"
-                            .to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-103-raw-v1/wikitext-validation.parquet".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-103-v1/test/index.duckdb".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-103-v1/validation/index.duckdb".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-103-v1/wikitext-test.parquet".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-103-v1/wikitext-train-00000-of-00002.parquet"
-                            .to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-103-v1/wikitext-train-00001-of-00002.parquet"
-                            .to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-103-v1/wikitext-validation.parquet".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-2-raw-v1/test/index.duckdb".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-2-raw-v1/train/index.duckdb".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-2-raw-v1/validation/index.duckdb".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-2-raw-v1/wikitext-test.parquet".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-2-raw-v1/wikitext-train.parquet".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-2-raw-v1/wikitext-validation.parquet".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-2-v1/wikitext-test.parquet".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-2-v1/wikitext-train.parquet".to_string()
-                    },
-                    Siblings {
-                        rfilename: "wikitext-2-v1/wikitext-validation.parquet".to_string()
-                    }
-                ],
-            }
-        )
-    }
-}
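The deleted `download` above ends with a rename-or-copy dance that is easy to miss: `tokio::fs::rename` fails when the temporary file and the blob directory sit on different mount points, so the code falls back to copying. A minimal synchronous sketch of the same pattern; `move_file` is a hypothetical helper belonging to neither crate:

```rust
use std::io;
use std::path::Path;

// Hypothetical helper, not part of candle-hub or hf-hub: rename returns
// EXDEV when `src` and `dst` live on different filesystems, so fall back
// to copying the bytes and removing the original.
fn move_file(src: &Path, dst: &Path) -> io::Result<()> {
    if std::fs::rename(src, dst).is_err() {
        std::fs::copy(src, dst)?;
        std::fs::remove_file(src)?;
    }
    Ok(())
}

fn main() -> io::Result<()> {
    std::fs::write("blob.tmp", b"payload")?;
    move_file(Path::new("blob.tmp"), Path::new("blob"))
}
```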
diff --git a/candle-hub/src/lib.rs b/candle-hub/src/lib.rs
deleted file mode 100644
index 0de2006a..00000000
--- a/candle-hub/src/lib.rs
+++ /dev/null
@@ -1,197 +0,0 @@
-#![deny(missing_docs)]
-//! This crate aims to emulate and be compatible with the
-//! [huggingface_hub](https://github.com/huggingface/huggingface_hub/) python package.
-//!
-//! Compatible means the Api should reuse the same files, skipping downloads if
-//! they are already present, and whenever this crate downloads or modifies the cache
-//! it should stay consistent with [huggingface_hub](https://github.com/huggingface/huggingface_hub/).
-//!
-//! At this time only a limited subset of the functionality is present; the goal is to add new
-//! features over time.
-use std::io::Write;
-use std::path::PathBuf;
-
-/// The actual Api to interact with the hub.
-#[cfg(feature = "online")]
-pub mod api;
-
-/// The type of repo to interact with
-#[derive(Debug, Clone, Copy)]
-pub enum RepoType {
-    /// This is a model, usually it consists of weight files and some configuration
-    /// files
-    Model,
-    /// This is a dataset, usually contains data within parquet files
-    Dataset,
-    /// This is a space, usually a demo showcasing a given model or dataset
-    Space,
-}
-
-/// A local struct used to fetch information from the cache folder.
-pub struct Cache {
-    path: PathBuf,
-}
-
-impl Cache {
-    /// Creates a new cache object location
-    pub fn new(path: PathBuf) -> Self {
-        Self { path }
-    }
-
-    /// The location of this cache
-    pub fn path(&self) -> &PathBuf {
-        &self.path
-    }
-
-    /// This will get the location of the file within the cache for the remote
-    /// `filename`. Returns `None` if the file is not already present in the cache.
-    pub fn get(&self, repo: &Repo, filename: &str) -> Option {
-        let mut commit_path = self.path.clone();
-        commit_path.push(repo.folder_name());
-        commit_path.push("refs");
-        commit_path.push(repo.revision());
-        let commit_hash = std::fs::read_to_string(commit_path).ok()?;
-        let mut pointer_path = self.pointer_path(repo, &commit_hash);
-        pointer_path.push(filename);
-        if pointer_path.exists() {
-            Some(pointer_path)
-        } else {
-            None
-        }
-    }
-
-    /// Creates a reference in the cache directory that points branches to the correct
-    /// commits within the blobs.
-    pub fn create_ref(&self, repo: &Repo, commit_hash: &str) -> Result<(), std::io::Error> {
-        let mut ref_path = self.path.clone();
-        ref_path.push(repo.folder_name());
-        ref_path.push("refs");
-        ref_path.push(repo.revision());
-        // Needs to be done like this because the revision might contain `/`, creating subfolders here.
-        std::fs::create_dir_all(ref_path.parent().unwrap())?;
-        let mut file1 = std::fs::OpenOptions::new()
-            .write(true)
-            .create(true)
-            .open(&ref_path)?;
-        file1.write_all(commit_hash.trim().as_bytes())?;
-        Ok(())
-    }
-
-    #[cfg(feature = "online")]
-    pub(crate) fn blob_path(&self, repo: &Repo, etag: &str) -> PathBuf {
-        let mut blob_path = self.path.clone();
-        blob_path.push(repo.folder_name());
-        blob_path.push("blobs");
-        blob_path.push(etag);
-        blob_path
-    }
-
-    pub(crate) fn pointer_path(&self, repo: &Repo, commit_hash: &str) -> PathBuf {
-        let mut pointer_path = self.path.clone();
-        pointer_path.push(repo.folder_name());
-        pointer_path.push("snapshots");
-        pointer_path.push(commit_hash);
-        pointer_path
-    }
-}
-
-impl Default for Cache {
-    fn default() -> Self {
-        let path = match std::env::var("HF_HOME") {
-            Ok(home) => home.into(),
-            Err(_) => {
-                let mut cache = dirs::home_dir().expect("Cache directory cannot be found");
-                cache.push(".cache");
-                cache.push("huggingface");
-                cache.push("hub");
-                cache
-            }
-        };
-        Self::new(path)
-    }
-}
-
-/// The representation of a repo on the hub.
-#[allow(dead_code)] // Repo type unused in offline mode
-pub struct Repo {
-    repo_id: String,
-    repo_type: RepoType,
-    revision: String,
-}
-
-impl Repo {
-    /// Repo with the default branch ("main").
-    pub fn new(repo_id: String, repo_type: RepoType) -> Self {
-        Self::with_revision(repo_id, repo_type, "main".to_string())
-    }
-
-    /// Fully qualified Repo
-    pub fn with_revision(repo_id: String, repo_type: RepoType, revision: String) -> Self {
-        Self {
-            repo_id,
-            repo_type,
-            revision,
-        }
-    }
-
-    /// Shortcut for [`Repo::new`] with [`RepoType::Model`]
-    pub fn model(repo_id: String) -> Self {
-        Self::new(repo_id, RepoType::Model)
-    }
-
-    /// Shortcut for [`Repo::new`] with [`RepoType::Dataset`]
-    pub fn dataset(repo_id: String) -> Self {
-        Self::new(repo_id, RepoType::Dataset)
-    }
-
-    /// Shortcut for [`Repo::new`] with [`RepoType::Space`]
-    pub fn space(repo_id: String) -> Self {
-        Self::new(repo_id, RepoType::Space)
-    }
-
-    /// The normalized folder name of the repo within the cache directory
-    pub fn folder_name(&self) -> String {
-        let prefix = match self.repo_type {
-            RepoType::Model => "models",
-            RepoType::Dataset => "datasets",
-            RepoType::Space => "spaces",
-        };
-        format!("{prefix}--{}", self.repo_id).replace('/', "--")
-    }
-
-    /// The revision
-    pub fn revision(&self) -> &str {
-        &self.revision
-    }
-
-    /// The actual URL part of the repo
-    #[cfg(feature = "online")]
-    pub fn url(&self) -> String {
-        match self.repo_type {
-            RepoType::Model => self.repo_id.to_string(),
-            RepoType::Dataset => {
-                format!("datasets/{}", self.repo_id)
-            }
-            RepoType::Space => {
-                format!("spaces/{}", self.repo_id)
-            }
-        }
-    }
-
-    /// The revision needs to be URL-escaped before being used in a URL
-    #[cfg(feature = "online")]
-    pub fn url_revision(&self) -> String {
-        self.revision.replace('/', "%2F")
-    }
-
-    /// Used to compute the repo's URL part when accessing the metadata of the repo
-    #[cfg(feature = "online")]
-    pub fn api_url(&self) -> String {
-        let prefix = match self.repo_type {
-            RepoType::Model => "models",
-            RepoType::Dataset => "datasets",
-            RepoType::Space => "spaces",
-        };
-        format!("{prefix}/{}/revision/{}", self.repo_id, self.url_revision())
-    }
-}
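The deleted `Cache` mirrors the on-disk layout huggingface_hub uses: `refs/<revision>` stores a commit hash, and `snapshots/<commit_hash>/<filename>` is a symlink into `blobs/<etag>`. A condensed sketch of the lookup that `Cache::get` performed; `cached_file` and its parameters are illustrative, not an API of either crate:

```rust
use std::path::{Path, PathBuf};

// Illustrative restatement of `Cache::get`: `folder` is the normalized repo
// folder (e.g. "models--gpt2"), `revision` a branch name such as "main".
fn cached_file(root: &Path, folder: &str, revision: &str, filename: &str) -> Option<PathBuf> {
    // refs/<revision> holds the commit hash this revision currently points at.
    let commit = std::fs::read_to_string(root.join(folder).join("refs").join(revision)).ok()?;
    // snapshots/<commit>/<filename> is a symlink into blobs/<etag>.
    let path = root
        .join(folder)
        .join("snapshots")
        .join(commit.trim())
        .join(filename);
    path.exists().then_some(path)
}

fn main() {
    let root = Path::new("/tmp/hub");
    println!("{:?}", cached_file(root, "models--gpt2", "main", "config.json"));
}
```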
"../candle-core" } -candle-hub = { path = "../candle-hub" } +hf-hub = { workspace = true} candle-nn = { path = "../candle-nn" } intel-mkl-src = { workspace = true, optional = true, features = ["mkl-dynamic-lp64-iomp"]} tokenizers = { workspace = true, features = ["onig"] }