Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 10:38:54 +00:00)

Merge branch 'main' into remove_wrapper

.gitignore (vendored, 1 change)
@@ -1,6 +1,7 @@
# Generated by Cargo
# will have compiled files and executables
debug/
data/
dist/
target/
Cargo.toml (12 changes)
@@ -2,7 +2,6 @@
members = [
    "candle-core",
    "candle-examples",
    "candle-hub",
    "candle-nn",
    "candle-pyo3",
    "candle-transformers",

@@ -19,29 +18,24 @@ clap = { version = "4.2.4", features = ["derive"] }
# Re-enable this once 0.9.13 as been released as it would include the cublas-f16 changes
# cudarc = { version = "0.9.13", optional = true, features = ["f16"] }
cudarc = { git = "https://github.com/LaurentMazare/cudarc.git", branch = "cublas-bf16", features = ["f16"] }
futures = "0.3.28"
# TODO: Switch back to the official gemm implementation once the following are available.
# https://github.com/sarah-ek/gemm/pull/8.
# https://github.com/sarah-ek/gemm/pull/9.
gemm = { git = "https://github.com/LaurentMazare/gemm.git", branch = "f16-vec-plus-wasm-simd" }
hf-hub = "0.1.0"
half = { version = "2.3.1", features = ["num-traits"] }
indicatif = "0.17.5"
intel-mkl-src = { version = "0.8.1", features = ["mkl-dynamic-lp64-iomp"] }
intel-mkl-src = { version = "0.8.1", features = ["mkl-static-ilp64-iomp"] }
libc = { version = "0.2.147" }
log = "0.4"
memmap2 = "0.7.1"
num_cpus = "1.15.0"
num-traits = "0.2.15"
rand = "0.8.5"
reqwest = "0.11.18"
safetensors = "0.3.1"
serde = { version = "1.0.166", features = ["derive"] }
serde = { version = "1.0.171", features = ["derive"] }
serde_json = "1.0.99"
sha256 = "=1.1.4"
thiserror = "1"
tokenizers = { version = "0.13.3", default-features = false }
tokio = "1.28.2"
tokio-test = "0.4.2"
tracing = "0.1.37"
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"
@@ -16,7 +16,7 @@ pub(crate) trait BackendStorage: Sized {
    fn elu(&self, _: &Layout, _: f64) -> Result<Self>;

    fn sum(&self, _: &Layout, _: &[usize]) -> Result<Self>;
    fn reduce_op(&self, _: crate::op::ReduceOp, _: &Layout, _: &[usize]) -> Result<Self>;

    fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()>;
@@ -67,6 +67,8 @@ impl Tensor {
                Op::Reshape(node)
                | Op::Broadcast(node)
                | Op::Sum(node, _)
                | Op::Max(node, _)
                | Op::Min(node, _)
                | Op::ToDType(node)
                | Op::ToDevice(node)
                | Op::Transpose(node, _, _)

@@ -203,6 +205,12 @@ impl Tensor {
                    let sum_grad = grads.or_insert(arg)?;
                    *sum_grad = sum_grad.broadcast_add(&grad)?
                }
                Op::Max(_args, _sum_dims) => {
                    return Err(Error::BackwardNotSupported { op: "max" })
                }
                Op::Min(_args, _sum_dims) => {
                    return Err(Error::BackwardNotSupported { op: "min" })
                }
                Op::ToDType(arg) => {
                    let sum_grad = grads.or_insert(arg)?;
                    *sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
@@ -1,5 +1,5 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOp, UnaryOp};
use crate::op::{BinaryOp, ReduceOp, UnaryOp};
use crate::{DType, Error, Layout, Result, Shape, WithDType};
use half::{bf16, f16};

@@ -93,64 +93,69 @@ impl<'a> Map2 for WCond<'a> {
    }
}

struct Sum<'a> {
struct Reduce<'a> {
    dst_shape: &'a Shape,
    sum_dims: &'a [usize],
    sum_dims_and_stride: Vec<(usize, usize)>,
    reduce_dims: &'a [usize],
    reduce_dims_and_stride: Vec<(usize, usize)>,
    op: ReduceOp,
}

impl<'a> Map1 for Sum<'a> {
impl<'a> Reduce<'a> {
    #[inline(always)]
    fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
        let mut dst = vec![T::zero(); self.dst_shape.elem_count()];
    fn fold_impl<T, F>(&self, src: &[T], src_l: &Layout, start_elt: T, f: F) -> Result<Vec<T>>
    where
        T: Clone + Copy,
        F: Fn(T, T) -> T,
    {
        let mut dst = vec![start_elt; self.dst_shape.elem_count()];
        match src_l.contiguous_offsets() {
            Some((o1, o2)) => {
                let src = &src[o1..o2];
                // Handle the case where we sum over the last dimensions separately as it is
                // Handle the case where we reduce over the last dimensions separately as it is
                // fairly common and easy to optimize. This rely on the layout being contiguous!
                // sum_dims is sorted, check if it is ranging from a to n-1.
                let sum_over_last_dims = self
                    .sum_dims
                // reduce_dims is sorted, check if it is ranging from a to n-1.
                let reduce_over_last_dims = self
                    .reduce_dims
                    .iter()
                    .rev()
                    .enumerate()
                    .all(|(i, &v)| v == src_l.shape().rank() - 1 - i);
                if sum_over_last_dims {
                    let sum_sz = self
                        .sum_dims_and_stride
                if reduce_over_last_dims {
                    let reduce_sz = self
                        .reduce_dims_and_stride
                        .iter()
                        .map(|(u, _)| u)
                        .product::<usize>();
                    let mut src_i = 0;
                    for dst_v in dst.iter_mut() {
                        for &s in src[src_i..src_i + sum_sz].iter() {
                            *dst_v += s
                        for &s in src[src_i..src_i + reduce_sz].iter() {
                            *dst_v = f(*dst_v, s)
                        }
                        src_i += sum_sz
                        src_i += reduce_sz
                    }
                    return Ok(dst);
                };
                for (unstr_index, &src) in src.iter().enumerate() {
                    let mut dst_index = unstr_index;
                    // Set the sum_dims indexes to 0.
                    for &(dim, stride) in self.sum_dims_and_stride.iter() {
                    // Set the reduce_dims indexes to 0.
                    for &(dim, stride) in self.reduce_dims_and_stride.iter() {
                        // The compiler is able to optimize the following in a single divmod op.
                        let (pre, post) = (dst_index / stride, dst_index % stride);
                        dst_index = (pre / dim) * stride + post;
                    }
                    dst[dst_index] += src;
                    dst[dst_index] = f(dst[dst_index], src);
                }
            }
            None => {
                for (unstr_index, src_index) in src_l.strided_index().enumerate() {
                    let mut dst_index = unstr_index;
                    // Set the sum_dims indexes to 0.
                    for &(dim, stride) in self.sum_dims_and_stride.iter() {
                    // Set the reduce_dims indexes to 0.
                    for &(dim, stride) in self.reduce_dims_and_stride.iter() {
                        // The compiler is able to optimize the following in a single divmod op.
                        let (pre, post) = (dst_index / stride, dst_index % stride);
                        dst_index = (pre / dim) * stride + post;
                    }
                    dst[dst_index] += src[src_index];
                    dst[dst_index] = f(dst[dst_index], src[src_index]);
                }
            }
        }
@@ -158,6 +163,31 @@ impl<'a> Map1 for Sum<'a> {
    }
}

impl<'a> Map1 for Reduce<'a> {
    #[inline(always)]
    fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
        match self.op {
            ReduceOp::Min => {
                let s = if src_l.shape().elem_count() != 0 {
                    src[src_l.start_offset()]
                } else {
                    Err(Error::EmptyTensor { op: "min" }.bt())?
                };
                self.fold_impl(src, src_l, s, |x, y| if x < y { x } else { y })
            }
            ReduceOp::Max => {
                let s = if src_l.shape().elem_count() != 0 {
                    src[src_l.start_offset()]
                } else {
                    Err(Error::EmptyTensor { op: "max" }.bt())?
                };
                self.fold_impl(src, src_l, s, |x, y| if x > y { x } else { y })
            }
            ReduceOp::Sum => self.fold_impl(src, src_l, T::zero(), |x, y| x + y),
        }
    }
}

fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(vs: &[T], layout: &Layout, mut f: F) -> Vec<U> {
    match layout.strided_blocks() {
        crate::StridedBlocks::SingleBlock { start_offset, len } => vs
@@ -340,7 +370,7 @@ fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>
        }
        (Some((o_l1, o_l2)), None) => match rhs_l.offsets_b() {
            Some(ob) if ob.right_broadcast == 1 => {
                let rhs = &rhs[ob.start..];
                let rhs = &rhs[ob.start..ob.start + ob.len];
                let mut ys: Vec<T> = Vec::with_capacity(el_count);
                let ys_to_set = ys.spare_capacity_mut();
                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };

@@ -358,7 +388,7 @@ fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>
                ys
            }
            Some(ob) => {
                let rhs = &rhs[ob.start..];
                let rhs = &rhs[ob.start..ob.start + ob.len];
                let mut ys = lhs[o_l1..o_l2].to_vec();
                for idx_l in 0..ob.left_broadcast {
                    let start = idx_l * ob.len * ob.right_broadcast;

@@ -379,7 +409,7 @@ fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>
        },
        (None, Some((o_r1, o_r2))) => match lhs_l.offsets_b() {
            Some(ob) if ob.right_broadcast == 1 => {
                let lhs = &lhs[ob.start..];
                let lhs = &lhs[ob.start..ob.start + ob.len];
                let mut ys: Vec<T> = Vec::with_capacity(el_count);
                let ys_to_set = ys.spare_capacity_mut();
                let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) };

@@ -397,7 +427,7 @@ fn binary_map_vec<T: Copy, F: FnMut(T, T) -> T, FV: FnMut(&[T], &[T], &mut [T])>
                ys
            }
            Some(ob) => {
                let lhs = &lhs[ob.start..];
                let lhs = &lhs[ob.start..ob.start + ob.len];
                let mut ys = rhs[o_r1..o_r2].to_vec();
                for idx_l in 0..ob.left_broadcast {
                    let start = idx_l * ob.len * ob.right_broadcast;
@@ -1010,25 +1040,26 @@ impl BackendStorage for CpuStorage {
        }
    }

    fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
    fn reduce_op(&self, op: ReduceOp, layout: &Layout, reduce_dims: &[usize]) -> Result<Self> {
        let src_dims = layout.dims();
        let mut dst_dims = src_dims.to_vec();
        for &sum_dim in sum_dims.iter() {
            dst_dims[sum_dim] = 1;
        for &dim in reduce_dims.iter() {
            dst_dims[dim] = 1;
        }
        let dst_shape = Shape::from(dst_dims);
        let mut sum_dims = sum_dims.to_vec();
        // Sort the sum_dims as they have to be processed from left to right when converting the
        let mut reduce_dims = reduce_dims.to_vec();
        // Sort the reduce_dims as they have to be processed from left to right when converting the
        // indexes.
        sum_dims.sort();
        let sum_dims_and_stride: Vec<_> = sum_dims
        reduce_dims.sort();
        let reduce_dims_and_stride: Vec<_> = reduce_dims
            .iter()
            .map(|&d| (src_dims[d], src_dims[d + 1..].iter().product::<usize>()))
            .collect();
        Sum {
        Reduce {
            dst_shape: &dst_shape,
            sum_dims: &sum_dims,
            sum_dims_and_stride,
            reduce_dims: &reduce_dims,
            reduce_dims_and_stride,
            op,
        }
        .map(self, layout)
    }
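To make the divmod index remapping in `Reduce::fold_impl` concrete, here is a small standalone sketch (plain Rust, no candle types; the shape and indices are picked purely for illustration). For each reduced dimension the coordinate along that dimension is collapsed to 0, so source elements that differ only along the reduced dimensions fold into the same destination slot:

    // dims_and_stride holds (size of the reduced dim, product of the trailing dims),
    // matching reduce_dims_and_stride in the code above.
    fn remap(mut idx: usize, dims_and_stride: &[(usize, usize)]) -> usize {
        for &(dim, stride) in dims_and_stride {
            // A single div/mod pair splits the flat index around the reduced axis.
            let (pre, post) = (idx / stride, idx % stride);
            idx = (pre / dim) * stride + post;
        }
        idx
    }

    fn main() {
        // Shape (2, 3, 4), reducing over dim 1: size 3, stride 4.
        let dims_and_stride = [(3usize, 4usize)];
        // Flat indices 13, 17 and 21 are the elements (1, 0, 1), (1, 1, 1) and (1, 2, 1);
        // all three collapse onto destination index 5, i.e. (1, 0, 1) in shape (2, 1, 4).
        assert_eq!(remap(13, &dims_and_stride), 5);
        assert_eq!(remap(17, &dims_and_stride), 5);
        assert_eq!(remap(21, &dims_and_stride), 5);
        println!("all three source elements fold into destination slot 5");
    }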
@@ -955,11 +955,22 @@ impl BackendStorage for CudaStorage {
        Ok(Self { slice, device })
    }

    fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
    fn reduce_op(
        &self,
        op: crate::op::ReduceOp,
        layout: &Layout,
        sum_dims: &[usize],
    ) -> Result<Self> {
        match op {
            crate::op::ReduceOp::Sum => {
                let device = self.device().clone();
                let slice = FastSum(sum_dims).map(&self.slice, &device, layout)?;
                Ok(Self { slice, device })
            }
            crate::op::ReduceOp::Min => Err(CudaError::InternalError("TODO: implement min").into()),
            crate::op::ReduceOp::Max => Err(CudaError::InternalError("TODO: implement max").into()),
        }
    }

    fn divide_by_sum_over_dim(&mut self, _: &Shape, _: usize) -> Result<()> {
        Err(CudaError::InternalError("TODO: implement divide_by_sum_over_dim").into())
@@ -40,7 +40,7 @@ impl crate::backend::BackendStorage for CudaStorage {
        Err(Error::NotCompiledWithCudaSupport)
    }

    fn sum(&self, _: &Layout, _: &[usize]) -> Result<Self> {
    fn reduce_op(&self, _: crate::op::ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
        Err(Error::NotCompiledWithCudaSupport)
    }
@@ -79,6 +79,9 @@ pub enum Error {
        nth_shape: Shape,
    },

    #[error("empty tensor for {op}")]
    EmptyTensor { op: &'static str },

    // === Device Errors ===
    #[error("device mismatch in {op}, lhs: {lhs:?}, rhs: {rhs:?}")]
    DeviceMismatchBinaryOp {
@@ -29,6 +29,8 @@ pub(crate) enum Op {
        add: f64,
    },
    Sum(Tensor, Vec<usize>),
    Max(Tensor, Vec<usize>),
    Min(Tensor, Vec<usize>),
    ToDType(Tensor),
    Broadcast(Tensor),
    Exp(Tensor),

@@ -354,3 +356,10 @@ impl UnaryOp for Relu {
        v
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
    Sum,
    Min,
    Max,
}
@@ -80,14 +80,19 @@ impl Storage {
        }
    }

    pub(crate) fn sum(&self, layout: &Layout, s: &[usize]) -> Result<Self> {
    pub(crate) fn reduce_op(
        &self,
        op: crate::op::ReduceOp,
        layout: &Layout,
        s: &[usize],
    ) -> Result<Self> {
        match self {
            Storage::Cpu(storage) => {
                let storage = storage.sum(layout, s)?;
                let storage = storage.reduce_op(op, layout, s)?;
                Ok(Self::Cpu(storage))
            }
            Self::Cuda(storage) => {
                let storage = storage.sum(layout, s)?;
                let storage = storage.reduce_op(op, layout, s)?;
                Ok(Self::Cuda(storage))
            }
        }
@@ -1,6 +1,7 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{Op, ReduceOp};
use crate::shape::{Dim, Dims};
use crate::{op::Op, storage::Storage, DType, Device, Error, Layout, Result, Shape};
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::{Arc, RwLock};

/// Unique identifier for tensors.

@@ -154,9 +155,15 @@ impl Tensor {
        device: &Device,
        is_variable: bool,
    ) -> Result<Self> {
        if is_variable {
            let shape = shape.into();
            let storage = device.ones(&shape, dtype)?;
            Ok(from_storage(storage, shape, None, is_variable))
        } else {
            let storage = device.ones(&crate::shape::SCALAR, dtype)?;
            from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
        }
    }

    /// Creates a new tensor filled with ones.
    ///

@@ -192,9 +199,15 @@ impl Tensor {
        device: &Device,
        is_variable: bool,
    ) -> Result<Self> {
        if is_variable {
            let shape = shape.into();
            let storage = device.zeros(&shape, dtype)?;
            Ok(from_storage(storage, shape, None, is_variable))
        } else {
            let storage = device.zeros(&crate::shape::SCALAR, dtype)?;
            from_storage(storage, crate::shape::SCALAR, None, is_variable).broadcast_as(shape)
        }
    }

    /// Creates a new tensor filled with zeros.
    ///
@@ -593,9 +606,77 @@ impl Tensor {
        }
    }

    pub fn sum_impl<D: Dims>(&self, sum_dims: D, keepdim: bool) -> Result<Self> {
    fn squeeze_dims(self, dims: &[usize]) -> Result<Self> {
        match dims {
            [] => Ok(self),
            [i] => self.squeeze(*i),
            dims => {
                let dims = self
                    .dims()
                    .iter()
                    .enumerate()
                    .filter_map(|(dim_idx, &v)| {
                        if dims.contains(&dim_idx) {
                            None
                        } else {
                            Some(v)
                        }
                    })
                    .collect::<Vec<_>>();
                self.reshape(dims)
            }
        }
    }

    fn max_impl<D: Dims>(&self, max_dims: D, keepdim: bool) -> Result<Self> {
        let max_dims = max_dims.to_indexes(self.shape(), "max")?;
        let storage = self
            .storage()
            .reduce_op(ReduceOp::Max, self.layout(), &max_dims)?;
        let op = if self.track_op() {
            Some(Op::Max(self.clone(), max_dims.to_vec()))
        } else {
            None
        };
        let mut dims = self.dims().to_vec();
        for &max_dim in max_dims.iter() {
            dims[max_dim] = 1
        }
        let max = from_storage(storage, dims, op, false);
        if keepdim {
            Ok(max)
        } else {
            max.squeeze_dims(&max_dims)
        }
    }

    fn min_impl<D: Dims>(&self, min_dims: D, keepdim: bool) -> Result<Self> {
        let min_dims = min_dims.to_indexes(self.shape(), "min")?;
        let storage = self
            .storage()
            .reduce_op(ReduceOp::Min, self.layout(), &min_dims)?;
        let op = if self.track_op() {
            Some(Op::Min(self.clone(), min_dims.to_vec()))
        } else {
            None
        };
        let mut dims = self.dims().to_vec();
        for &min_dim in min_dims.iter() {
            dims[min_dim] = 1
        }
        let min = from_storage(storage, dims, op, false);
        if keepdim {
            Ok(min)
        } else {
            min.squeeze_dims(&min_dims)
        }
    }

    fn sum_impl<D: Dims>(&self, sum_dims: D, keepdim: bool) -> Result<Self> {
        let sum_dims = sum_dims.to_indexes(self.shape(), "sum")?;
        let storage = self.storage().sum(self.layout(), &sum_dims)?;
        let storage = self
            .storage()
            .reduce_op(ReduceOp::Sum, self.layout(), &sum_dims)?;
        let op = if self.track_op() {
            Some(Op::Sum(self.clone(), sum_dims.to_vec()))
        } else {
@@ -609,25 +690,7 @@ impl Tensor {
        if keepdim {
            Ok(sum)
        } else {
            match sum_dims.as_slice() {
                [] => Ok(sum),
                [i] => sum.squeeze(*i),
                sum_dims => {
                    let dims = sum
                        .dims()
                        .iter()
                        .enumerate()
                        .filter_map(|(dim_idx, &v)| {
                            if sum_dims.contains(&dim_idx) {
                                None
                            } else {
                                Some(v)
                            }
                        })
                        .collect::<Vec<_>>();
                    sum.reshape(dims)
                }
            }
            sum.squeeze_dims(&sum_dims)
        }
    }
@@ -659,6 +722,32 @@ impl Tensor {
        self.sum_impl(sum_dims, false)
    }

    pub fn max_keepdim<D: Dims>(&self, max_dims: D) -> Result<Self> {
        self.max_impl(max_dims, true)
    }

    pub fn max<D: Dims>(&self, max_dims: D) -> Result<Self> {
        self.max_impl(max_dims, false)
    }

    pub fn max_all(&self) -> Result<Tensor> {
        let dims: Vec<_> = (0..self.rank()).collect();
        self.max(dims)
    }

    pub fn min_keepdim<D: Dims>(&self, min_dims: D) -> Result<Self> {
        self.min_impl(min_dims, true)
    }

    pub fn min<D: Dims>(&self, min_dims: D) -> Result<Self> {
        self.min_impl(min_dims, false)
    }

    pub fn min_all(&self) -> Result<Tensor> {
        let dims: Vec<_> = (0..self.rank()).collect();
        self.min(dims)
    }

    /// Applies a 1D convolution over the input tensor.
    pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
        let (c_out, c_in_k, k_size) = kernel.shape().r3()?;
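A minimal usage sketch of the reduction API added above; `Tensor::new` on nested arrays and the single-`usize` form of the `Dims` argument are assumed from the surrounding crate rather than shown in this diff, so treat the snippet as illustrative:

    use candle::{Device, Tensor};

    fn main() -> candle::Result<()> {
        let t = Tensor::new(&[[1f32, 5., 3.], [4., 2., 6.]], &Device::Cpu)?;
        // keepdim keeps the reduced axis with size 1, the plain variant squeezes it away.
        let max_keep = t.max_keepdim(1)?; // shape (2, 1) -> [[5.], [6.]]
        let max = t.max(1)?;              // shape (2,)   -> [5., 6.]
        let min_all = t.min_all()?;       // rank-0 tensor holding the global minimum, 1.
        println!("{max_keep:?} {max:?} {min_all:?}");
        Ok(())
    }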
@@ -21,7 +21,7 @@ intel-mkl-src = { workspace = true, optional = true }

[dev-dependencies]
anyhow = { workspace = true }
candle-hub = { path = "../candle-hub" }
hf-hub = { workspace = true}
clap = { workspace = true }
rand = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }
@@ -4,9 +4,9 @@ mod model;

use anyhow::{anyhow, Error as E, Result};
use candle::Tensor;
use candle_hub::{api::sync::Api, Cache, Repo, RepoType};
use candle_nn::VarBuilder;
use clap::Parser;
use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
use model::{BertModel, Config, DTYPE};
use tokenizers::{PaddingParams, Tokenizer};

@@ -5,10 +5,10 @@ extern crate intel_mkl_src;

use anyhow::{Error as E, Result};
use candle::{DType, Device, Tensor};
use candle_hub::{api::sync::Api, Repo, RepoType};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

mod model;

@@ -16,9 +16,9 @@ use anyhow::{Error as E, Result};
use clap::Parser;

use candle::{DType, Device, Tensor, D};
use candle_hub::{api::sync::Api, Repo, RepoType};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};

mod model;
use model::{Config, Llama};
candle-examples/examples/simple-training/main.rs (new file, 44 lines)
@@ -0,0 +1,44 @@
// This should rearch 91.5% accuracy.
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

use anyhow::Result;
use candle::{DType, Var, D};

const IMAGE_DIM: usize = 784;
const LABELS: usize = 10;

pub fn main() -> Result<()> {
    let dev = candle::Device::cuda_if_available(0)?;
    let m = candle_nn::vision::mnist::load_dir("data")?;
    println!("train-images: {:?}", m.train_images.shape());
    println!("train-labels: {:?}", m.train_labels.shape());
    println!("test-images: {:?}", m.test_images.shape());
    println!("test-labels: {:?}", m.test_labels.shape());
    let ws = Var::zeros((IMAGE_DIM, LABELS), DType::F32, &dev)?;
    let bs = Var::zeros(LABELS, DType::F32, &dev)?;
    let sgd = candle_nn::SGD::new(&[&ws, &bs], 0.1);
    for epoch in 1..200 {
        let logits = m.train_images.matmul(&ws)?.broadcast_add(&bs)?;
        let loss = logits.softmax(D::Minus1)?;
        // TODO: log_softmax + let loss = loss.nll_loss(&m.train_labels);
        sgd.backward_step(&loss)?;

        let _test_logits = m.test_images.matmul(&ws)?.broadcast_add(&bs)?;
        /* TODO
        let test_accuracy = test_logits
            .argmax(Some(-1), false)
            .eq_tensor(&m.test_labels)
            .to_kind(Kind::Float)
            .mean(Kind::Float)
            .double_value(&[]);
        */
        let test_accuracy = 0.;
        println!(
            "{epoch:4} train loss: {:8.5} test acc: {:5.2}%",
            loss.to_scalar::<f32>()?,
            100. * test_accuracy
        )
    }
    Ok(())
}
@@ -11,9 +11,9 @@ extern crate intel_mkl_src;

use anyhow::{Error as E, Result};
use candle::{safetensors::Load, DType, Device, Tensor};
use candle_hub::{api::sync::Api, Repo, RepoType};
use candle_nn::VarBuilder;
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use rand::{distributions::Distribution, SeedableRng};
use tokenizers::Tokenizer;
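The examples above now import `hf_hub` in place of the removed `candle-hub` crate. A rough usage sketch follows, assuming hf-hub 0.1.0 keeps the same `Api`/`Repo` surface as the candle-hub code deleted further down in this diff; the repo id and filename are placeholders:

    use hf_hub::{api::sync::Api, Repo, RepoType};

    fn main() -> anyhow::Result<()> {
        let api = Api::new()?;
        let repo = Repo::new("gpt2".to_string(), RepoType::Model);
        // Serve the file from the local cache, downloading it on a miss.
        let weights = api.get(&repo, "model.safetensors")?;
        println!("weights at {}", weights.display());
        Ok(())
    }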
candle-hub/.gitignore (vendored, deleted)
@@ -1,2 +0,0 @@
/target
/Cargo.lock
@@ -1,29 +0,0 @@
[package]
name = "candle-hub"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
dirs = "5.0.1"
rand = { workspace = true }
thiserror = { workspace = true }
futures = { workspace = true, optional = true }
reqwest = { workspace = true, optional = true, features = ["json"] }
tokio = { workspace = true, features = ["fs"], optional = true }
serde = { workspace = true, optional = true }
serde_json = { workspace = true, optional = true }
indicatif = { workspace = true, optional = true }
num_cpus = { workspace = true, optional = true }

[dev-dependencies]
rand = { workspace = true }
sha256 = { workspace = true }
tokio = { workspace = true, features = ["macros"] }
tokio-test = { workspace = true }

[features]
default = ["online"]
online = ["reqwest/blocking", "dep:serde", "dep:serde_json", "dep:indicatif", "dep:num_cpus"]
tokio = ["online", "dep:tokio", "dep:futures"]
@@ -1,6 +0,0 @@
/// The asynchronous version of the API
#[cfg(feature = "tokio")]
pub mod tokio;

/// The synchronous version of the API
pub mod sync;
@@ -1,686 +0,0 @@
|
||||
use crate::{Cache, Repo};
|
||||
use indicatif::{ProgressBar, ProgressStyle};
|
||||
use rand::{distributions::Alphanumeric, thread_rng, Rng};
|
||||
use reqwest::{
|
||||
blocking::Client,
|
||||
header::{
|
||||
HeaderMap, HeaderName, HeaderValue, InvalidHeaderValue, ToStrError, AUTHORIZATION,
|
||||
CONTENT_RANGE, LOCATION, RANGE, USER_AGENT,
|
||||
},
|
||||
redirect::Policy,
|
||||
Error as ReqwestError,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use std::io::{Seek, SeekFrom, Write};
|
||||
use std::num::ParseIntError;
|
||||
use std::path::{Component, Path, PathBuf};
|
||||
use thiserror::Error;
|
||||
|
||||
/// Current version (used in user-agent)
|
||||
const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
/// Current name (used in user-agent)
|
||||
const NAME: &str = env!("CARGO_PKG_NAME");
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
/// All errors the API can throw
|
||||
pub enum ApiError {
|
||||
/// Api expects certain header to be present in the results to derive some information
|
||||
#[error("Header {0} is missing")]
|
||||
MissingHeader(HeaderName),
|
||||
|
||||
/// The header exists, but the value is not conform to what the Api expects.
|
||||
#[error("Header {0} is invalid")]
|
||||
InvalidHeader(HeaderName),
|
||||
|
||||
/// The value cannot be used as a header during request header construction
|
||||
#[error("Invalid header value {0}")]
|
||||
InvalidHeaderValue(#[from] InvalidHeaderValue),
|
||||
|
||||
/// The header value is not valid utf-8
|
||||
#[error("header value is not a string")]
|
||||
ToStr(#[from] ToStrError),
|
||||
|
||||
/// Error in the request
|
||||
#[error("request error: {0}")]
|
||||
RequestError(#[from] ReqwestError),
|
||||
|
||||
/// Error parsing some range value
|
||||
#[error("Cannot parse int")]
|
||||
ParseIntError(#[from] ParseIntError),
|
||||
|
||||
/// I/O Error
|
||||
#[error("I/O error {0}")]
|
||||
IoError(#[from] std::io::Error),
|
||||
|
||||
/// We tried to download chunk too many times
|
||||
#[error("Too many retries: {0}")]
|
||||
TooManyRetries(Box<ApiError>),
|
||||
}
|
||||
|
||||
/// Siblings are simplified file descriptions of remote files on the hub
|
||||
#[derive(Debug, Clone, Deserialize, PartialEq)]
|
||||
pub struct Siblings {
|
||||
/// The path within the repo.
|
||||
pub rfilename: String,
|
||||
}
|
||||
|
||||
/// The description of the repo given by the hub
|
||||
#[derive(Debug, Clone, Deserialize, PartialEq)]
|
||||
pub struct ModelInfo {
|
||||
/// See [`Siblings`]
|
||||
pub siblings: Vec<Siblings>,
|
||||
}
|
||||
|
||||
/// Helper to create [`Api`] with all the options.
|
||||
pub struct ApiBuilder {
|
||||
endpoint: String,
|
||||
cache: Cache,
|
||||
url_template: String,
|
||||
token: Option<String>,
|
||||
chunk_size: usize,
|
||||
parallel_failures: usize,
|
||||
max_retries: usize,
|
||||
progress: bool,
|
||||
}
|
||||
|
||||
impl Default for ApiBuilder {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ApiBuilder {
|
||||
/// Default api builder
|
||||
/// ```
|
||||
/// use candle_hub::api::sync::ApiBuilder;
|
||||
/// let api = ApiBuilder::new().build().unwrap();
|
||||
/// ```
|
||||
pub fn new() -> Self {
|
||||
let cache = Cache::default();
|
||||
let mut token_filename = cache.path().clone();
|
||||
token_filename.push(".token");
|
||||
let token = match std::fs::read_to_string(token_filename) {
|
||||
Ok(token_content) => {
|
||||
let token_content = token_content.trim();
|
||||
if !token_content.is_empty() {
|
||||
Some(token_content.to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
Err(_) => None,
|
||||
};
|
||||
|
||||
let progress = true;
|
||||
|
||||
Self {
|
||||
endpoint: "https://huggingface.co".to_string(),
|
||||
url_template: "{endpoint}/{repo_id}/resolve/{revision}/{filename}".to_string(),
|
||||
cache,
|
||||
token,
|
||||
chunk_size: 10_000_000,
|
||||
parallel_failures: 0,
|
||||
max_retries: 0,
|
||||
progress,
|
||||
}
|
||||
}
|
||||
|
||||
/// Wether to show a progressbar
|
||||
pub fn with_progress(mut self, progress: bool) -> Self {
|
||||
self.progress = progress;
|
||||
self
|
||||
}
|
||||
|
||||
/// Changes the location of the cache directory. Defaults is `~/.cache/huggingface/`.
|
||||
pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self {
|
||||
self.cache = Cache::new(cache_dir);
|
||||
self
|
||||
}
|
||||
|
||||
fn build_headers(&self) -> Result<HeaderMap, ApiError> {
|
||||
let mut headers = HeaderMap::new();
|
||||
let user_agent = format!("unkown/None; {NAME}/{VERSION}; rust/unknown");
|
||||
headers.insert(USER_AGENT, HeaderValue::from_str(&user_agent)?);
|
||||
if let Some(token) = &self.token {
|
||||
headers.insert(
|
||||
AUTHORIZATION,
|
||||
HeaderValue::from_str(&format!("Bearer {token}"))?,
|
||||
);
|
||||
}
|
||||
Ok(headers)
|
||||
}
|
||||
|
||||
/// Consumes the builder and buids the final [`Api`]
|
||||
pub fn build(self) -> Result<Api, ApiError> {
|
||||
let headers = self.build_headers()?;
|
||||
let client = Client::builder().default_headers(headers.clone()).build()?;
|
||||
let no_redirect_client = Client::builder()
|
||||
.redirect(Policy::none())
|
||||
.default_headers(headers)
|
||||
.build()?;
|
||||
Ok(Api {
|
||||
endpoint: self.endpoint,
|
||||
url_template: self.url_template,
|
||||
cache: self.cache,
|
||||
client,
|
||||
|
||||
no_redirect_client,
|
||||
chunk_size: self.chunk_size,
|
||||
parallel_failures: self.parallel_failures,
|
||||
max_retries: self.max_retries,
|
||||
progress: self.progress,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Metadata {
|
||||
commit_hash: String,
|
||||
etag: String,
|
||||
size: usize,
|
||||
}
|
||||
|
||||
/// The actual Api used to interacto with the hub.
|
||||
/// You can inspect repos with [`Api::info`]
|
||||
/// or download files with [`Api::download`]
|
||||
pub struct Api {
|
||||
endpoint: String,
|
||||
url_template: String,
|
||||
cache: Cache,
|
||||
client: Client,
|
||||
no_redirect_client: Client,
|
||||
chunk_size: usize,
|
||||
parallel_failures: usize,
|
||||
max_retries: usize,
|
||||
progress: bool,
|
||||
}
|
||||
|
||||
fn temp_filename() -> PathBuf {
|
||||
let s: String = rand::thread_rng()
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(7)
|
||||
.map(char::from)
|
||||
.collect();
|
||||
let mut path = std::env::temp_dir();
|
||||
path.push(s);
|
||||
path
|
||||
}
|
||||
|
||||
fn make_relative(src: &Path, dst: &Path) -> PathBuf {
|
||||
let path = src;
|
||||
let base = dst;
|
||||
|
||||
if path.is_absolute() != base.is_absolute() {
|
||||
panic!("This function is made to look at absolute paths only");
|
||||
}
|
||||
let mut ita = path.components();
|
||||
let mut itb = base.components();
|
||||
|
||||
loop {
|
||||
match (ita.next(), itb.next()) {
|
||||
(Some(a), Some(b)) if a == b => (),
|
||||
(some_a, _) => {
|
||||
// Ignoring b, because 1 component is the filename
|
||||
// for which we don't need to go back up for relative
|
||||
// filename to work.
|
||||
let mut new_path = PathBuf::new();
|
||||
for _ in itb {
|
||||
new_path.push(Component::ParentDir);
|
||||
}
|
||||
if let Some(a) = some_a {
|
||||
new_path.push(a);
|
||||
for comp in ita {
|
||||
new_path.push(comp);
|
||||
}
|
||||
}
|
||||
return new_path;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn symlink_or_rename(src: &Path, dst: &Path) -> Result<(), std::io::Error> {
|
||||
if dst.exists() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let src = make_relative(src, dst);
|
||||
#[cfg(target_os = "windows")]
|
||||
std::os::windows::fs::symlink_file(src, dst)?;
|
||||
|
||||
#[cfg(target_family = "unix")]
|
||||
std::os::unix::fs::symlink(src, dst)?;
|
||||
|
||||
#[cfg(not(any(target_family = "unix", target_os = "windows")))]
|
||||
std::fs::rename(src, dst)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn jitter() -> usize {
|
||||
thread_rng().gen_range(0..=500)
|
||||
}
|
||||
|
||||
fn exponential_backoff(base_wait_time: usize, n: usize, max: usize) -> usize {
|
||||
(base_wait_time + n.pow(2) + jitter()).min(max)
|
||||
}
|
||||
|
||||
impl Api {
|
||||
/// Creates a default Api, for Api options See [`ApiBuilder`]
|
||||
pub fn new() -> Result<Self, ApiError> {
|
||||
ApiBuilder::new().build()
|
||||
}
|
||||
|
||||
/// Get the fully qualified URL of the remote filename
|
||||
/// ```
|
||||
/// # use candle_hub::{api::sync::Api, Repo};
|
||||
/// let api = Api::new().unwrap();
|
||||
/// let repo = Repo::model("gpt2".to_string());
|
||||
/// let url = api.url(&repo, "model.safetensors");
|
||||
/// assert_eq!(url, "https://huggingface.co/gpt2/resolve/main/model.safetensors");
|
||||
/// ```
|
||||
pub fn url(&self, repo: &Repo, filename: &str) -> String {
|
||||
let endpoint = &self.endpoint;
|
||||
let revision = &repo.url_revision();
|
||||
self.url_template
|
||||
.replace("{endpoint}", endpoint)
|
||||
.replace("{repo_id}", &repo.url())
|
||||
.replace("{revision}", revision)
|
||||
.replace("{filename}", filename)
|
||||
}
|
||||
|
||||
/// Get the underlying api client
|
||||
/// Allows for lower level access
|
||||
pub fn client(&self) -> &Client {
|
||||
&self.client
|
||||
}
|
||||
|
||||
fn metadata(&self, url: &str) -> Result<Metadata, ApiError> {
|
||||
let response = self
|
||||
.no_redirect_client
|
||||
.get(url)
|
||||
.header(RANGE, "bytes=0-0")
|
||||
.send()?;
|
||||
let response = response.error_for_status()?;
|
||||
let headers = response.headers();
|
||||
let header_commit = HeaderName::from_static("x-repo-commit");
|
||||
let header_linked_etag = HeaderName::from_static("x-linked-etag");
|
||||
let header_etag = HeaderName::from_static("etag");
|
||||
|
||||
let etag = match headers.get(&header_linked_etag) {
|
||||
Some(etag) => etag,
|
||||
None => headers
|
||||
.get(&header_etag)
|
||||
.ok_or(ApiError::MissingHeader(header_etag))?,
|
||||
};
|
||||
// Cleaning extra quotes
|
||||
let etag = etag.to_str()?.to_string().replace('"', "");
|
||||
let commit_hash = headers
|
||||
.get(&header_commit)
|
||||
.ok_or(ApiError::MissingHeader(header_commit))?
|
||||
.to_str()?
|
||||
.to_string();
|
||||
|
||||
// The response was redirected o S3 most likely which will
|
||||
// know about the size of the file
|
||||
let response = if response.status().is_redirection() {
|
||||
self.client
|
||||
.get(headers.get(LOCATION).unwrap().to_str()?.to_string())
|
||||
.header(RANGE, "bytes=0-0")
|
||||
.send()?
|
||||
} else {
|
||||
response
|
||||
};
|
||||
let headers = response.headers();
|
||||
let content_range = headers
|
||||
.get(CONTENT_RANGE)
|
||||
.ok_or(ApiError::MissingHeader(CONTENT_RANGE))?
|
||||
.to_str()?;
|
||||
|
||||
let size = content_range
|
||||
.split('/')
|
||||
.last()
|
||||
.ok_or(ApiError::InvalidHeader(CONTENT_RANGE))?
|
||||
.parse()?;
|
||||
Ok(Metadata {
|
||||
commit_hash,
|
||||
etag,
|
||||
size,
|
||||
})
|
||||
}
|
||||
|
||||
fn download_tempfile(
|
||||
&self,
|
||||
url: &str,
|
||||
length: usize,
|
||||
progressbar: Option<ProgressBar>,
|
||||
) -> Result<PathBuf, ApiError> {
|
||||
let filename = temp_filename();
|
||||
|
||||
// Create the file and set everything properly
|
||||
std::fs::File::create(&filename)?.set_len(length as u64)?;
|
||||
|
||||
let chunk_size = self.chunk_size;
|
||||
|
||||
let n_chunks = (length + chunk_size - 1) / chunk_size;
|
||||
let n_threads = num_cpus::get();
|
||||
let chunks_per_thread = (n_chunks + n_threads - 1) / n_threads;
|
||||
let handles = (0..n_threads).map(|thread_id| {
|
||||
let url = url.to_string();
|
||||
let filename = filename.clone();
|
||||
let client = self.client.clone();
|
||||
let parallel_failures = self.parallel_failures;
|
||||
let max_retries = self.max_retries;
|
||||
let progress = progressbar.clone();
|
||||
std::thread::spawn(move || {
|
||||
for chunk_id in chunks_per_thread * thread_id
|
||||
..std::cmp::min(chunks_per_thread * (thread_id + 1), n_chunks)
|
||||
{
|
||||
let start = chunk_id * chunk_size;
|
||||
let stop = std::cmp::min(start + chunk_size - 1, length);
|
||||
let mut chunk = Self::download_chunk(&client, &url, &filename, start, stop);
|
||||
let mut i = 0;
|
||||
if parallel_failures > 0 {
|
||||
while let Err(dlerr) = chunk {
|
||||
let wait_time = exponential_backoff(300, i, 10_000);
|
||||
std::thread::sleep(std::time::Duration::from_millis(wait_time as u64));
|
||||
|
||||
chunk = Self::download_chunk(&client, &url, &filename, start, stop);
|
||||
i += 1;
|
||||
if i > max_retries {
|
||||
return Err(ApiError::TooManyRetries(dlerr.into()));
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(p) = &progress {
|
||||
p.inc((stop - start) as u64);
|
||||
}
|
||||
chunk?
|
||||
}
|
||||
Ok(())
|
||||
})
|
||||
});
|
||||
|
||||
let results: Result<Vec<()>, ApiError> =
|
||||
handles.into_iter().flat_map(|h| h.join()).collect();
|
||||
|
||||
results?;
|
||||
if let Some(p) = progressbar {
|
||||
p.finish()
|
||||
}
|
||||
Ok(filename)
|
||||
}
|
||||
|
||||
fn download_chunk(
|
||||
client: &Client,
|
||||
url: &str,
|
||||
filename: &PathBuf,
|
||||
start: usize,
|
||||
stop: usize,
|
||||
) -> Result<(), ApiError> {
|
||||
// Process each socket concurrently.
|
||||
let range = format!("bytes={start}-{stop}");
|
||||
let mut file = std::fs::OpenOptions::new().write(true).open(filename)?;
|
||||
file.seek(SeekFrom::Start(start as u64))?;
|
||||
let response = client
|
||||
.get(url)
|
||||
.header(RANGE, range)
|
||||
.send()?
|
||||
.error_for_status()?;
|
||||
let content = response.bytes()?;
|
||||
file.write_all(&content)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// This will attempt the fetch the file locally first, then [`Api.download`]
|
||||
/// if the file is not present.
|
||||
/// ```no_run
|
||||
/// use candle_hub::{api::sync::ApiBuilder, Repo};
|
||||
/// let api = ApiBuilder::new().build().unwrap();
|
||||
/// let repo = Repo::model("gpt2".to_string());
|
||||
/// let local_filename = api.get(&repo, "model.safetensors").unwrap();
|
||||
pub fn get(&self, repo: &Repo, filename: &str) -> Result<PathBuf, ApiError> {
|
||||
if let Some(path) = self.cache.get(repo, filename) {
|
||||
Ok(path)
|
||||
} else {
|
||||
self.download(repo, filename)
|
||||
}
|
||||
}
|
||||
|
||||
/// Downloads a remote file (if not already present) into the cache directory
|
||||
/// to be used locally.
|
||||
/// This functions require internet access to verify if new versions of the file
|
||||
/// exist, even if a file is already on disk at location.
|
||||
/// ```no_run
|
||||
/// # use candle_hub::{api::sync::ApiBuilder, Repo};
|
||||
/// let api = ApiBuilder::new().build().unwrap();
|
||||
/// let repo = Repo::model("gpt2".to_string());
|
||||
/// let local_filename = api.download(&repo, "model.safetensors").unwrap();
|
||||
/// ```
|
||||
pub fn download(&self, repo: &Repo, filename: &str) -> Result<PathBuf, ApiError> {
|
||||
let url = self.url(repo, filename);
|
||||
let metadata = self.metadata(&url)?;
|
||||
|
||||
let blob_path = self.cache.blob_path(repo, &metadata.etag);
|
||||
std::fs::create_dir_all(blob_path.parent().unwrap())?;
|
||||
|
||||
let progressbar = if self.progress {
|
||||
let progress = ProgressBar::new(metadata.size as u64);
|
||||
progress.set_style(
|
||||
ProgressStyle::with_template(
|
||||
"{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec} ({eta})",
|
||||
)
|
||||
.unwrap(), // .progress_chars("━ "),
|
||||
);
|
||||
let maxlength = 30;
|
||||
let message = if filename.len() > maxlength {
|
||||
format!("..{}", &filename[filename.len() - maxlength..])
|
||||
} else {
|
||||
filename.to_string()
|
||||
};
|
||||
progress.set_message(message);
|
||||
Some(progress)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let tmp_filename = self.download_tempfile(&url, metadata.size, progressbar)?;
|
||||
|
||||
if std::fs::rename(&tmp_filename, &blob_path).is_err() {
|
||||
// Renaming may fail if locations are different mount points
|
||||
std::fs::File::create(&blob_path)?;
|
||||
std::fs::copy(tmp_filename, &blob_path)?;
|
||||
}
|
||||
|
||||
let mut pointer_path = self.cache.pointer_path(repo, &metadata.commit_hash);
|
||||
pointer_path.push(filename);
|
||||
std::fs::create_dir_all(pointer_path.parent().unwrap()).ok();
|
||||
|
||||
symlink_or_rename(&blob_path, &pointer_path)?;
|
||||
self.cache.create_ref(repo, &metadata.commit_hash)?;
|
||||
|
||||
Ok(pointer_path)
|
||||
}
|
||||
|
||||
/// Get information about the Repo
|
||||
/// ```
|
||||
/// use candle_hub::{api::sync::Api, Repo};
|
||||
/// let api = Api::new().unwrap();
|
||||
/// let repo = Repo::model("gpt2".to_string());
|
||||
/// api.info(&repo);
|
||||
/// ```
|
||||
pub fn info(&self, repo: &Repo) -> Result<ModelInfo, ApiError> {
|
||||
let url = format!("{}/api/{}", self.endpoint, repo.api_url());
|
||||
let response = self.client.get(url).send()?;
|
||||
let response = response.error_for_status()?;
|
||||
|
||||
let model_info = response.json()?;
|
||||
|
||||
Ok(model_info)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::RepoType;
|
||||
use rand::{distributions::Alphanumeric, Rng};
|
||||
use sha256::try_digest;
|
||||
|
||||
struct TempDir {
|
||||
path: PathBuf,
|
||||
}
|
||||
|
||||
impl TempDir {
|
||||
pub fn new() -> Self {
|
||||
let s: String = rand::thread_rng()
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(7)
|
||||
.map(char::from)
|
||||
.collect();
|
||||
let mut path = std::env::temp_dir();
|
||||
path.push(s);
|
||||
std::fs::create_dir(&path).unwrap();
|
||||
Self { path }
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for TempDir {
|
||||
fn drop(&mut self) {
|
||||
std::fs::remove_dir_all(&self.path).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn simple() {
|
||||
let tmp = TempDir::new();
|
||||
let api = ApiBuilder::new()
|
||||
.with_progress(false)
|
||||
.with_cache_dir(tmp.path.clone())
|
||||
.build()
|
||||
.unwrap();
|
||||
let repo = Repo::new("julien-c/dummy-unknown".to_string(), RepoType::Model);
|
||||
let downloaded_path = api.download(&repo, "config.json").unwrap();
|
||||
assert!(downloaded_path.exists());
|
||||
let val = try_digest(&*downloaded_path).unwrap();
|
||||
assert_eq!(
|
||||
val,
|
||||
"b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32"
|
||||
);
|
||||
|
||||
// Make sure the file is now seeable without connection
|
||||
let cache_path = api.cache.get(&repo, "config.json").unwrap();
|
||||
assert_eq!(cache_path, downloaded_path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dataset() {
|
||||
let tmp = TempDir::new();
|
||||
let api = ApiBuilder::new()
|
||||
.with_progress(false)
|
||||
.with_cache_dir(tmp.path.clone())
|
||||
.build()
|
||||
.unwrap();
|
||||
let repo = Repo::with_revision(
|
||||
"wikitext".to_string(),
|
||||
RepoType::Dataset,
|
||||
"refs/convert/parquet".to_string(),
|
||||
);
|
||||
let downloaded_path = api
|
||||
.download(&repo, "wikitext-103-v1/wikitext-test.parquet")
|
||||
.unwrap();
|
||||
assert!(downloaded_path.exists());
|
||||
let val = try_digest(&*downloaded_path).unwrap();
|
||||
assert_eq!(
|
||||
val,
|
||||
"59ce09415ad8aa45a9e34f88cec2548aeb9de9a73fcda9f6b33a86a065f32b90"
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn info() {
|
||||
let tmp = TempDir::new();
|
||||
let api = ApiBuilder::new()
|
||||
.with_progress(false)
|
||||
.with_cache_dir(tmp.path.clone())
|
||||
.build()
|
||||
.unwrap();
|
||||
let repo = Repo::with_revision(
|
||||
"wikitext".to_string(),
|
||||
RepoType::Dataset,
|
||||
"refs/convert/parquet".to_string(),
|
||||
);
|
||||
let model_info = api.info(&repo).unwrap();
|
||||
assert_eq!(
|
||||
model_info,
|
||||
ModelInfo {
|
||||
siblings: vec![
|
||||
Siblings {
|
||||
rfilename: ".gitattributes".to_string()
|
||||
},
|
||||
Siblings {
|
||||
rfilename: "wikitext-103-raw-v1/wikitext-test.parquet".to_string()
|
||||
},
|
||||
Siblings {
|
||||
rfilename: "wikitext-103-raw-v1/wikitext-train-00000-of-00002.parquet"
|
||||
.to_string()
|
||||
},
|
||||
Siblings {
|
||||
rfilename: "wikitext-103-raw-v1/wikitext-train-00001-of-00002.parquet"
|
||||
.to_string()
|
||||
},
|
||||
Siblings {
|
||||
rfilename: "wikitext-103-raw-v1/wikitext-validation.parquet".to_string()
|
||||
},
|
||||
Siblings {
|
||||
rfilename: "wikitext-103-v1/test/index.duckdb".to_string()
|
||||
},
|
||||
Siblings {
|
||||
rfilename: "wikitext-103-v1/validation/index.duckdb".to_string()
|
||||
},
|
||||
Siblings {
|
||||
rfilename: "wikitext-103-v1/wikitext-test.parquet".to_string()
|
||||
},
|
||||
Siblings {
|
||||
rfilename: "wikitext-103-v1/wikitext-train-00000-of-00002.parquet"
|
||||
.to_string()
|
||||
},
|
||||
Siblings {
|
||||
rfilename: "wikitext-103-v1/wikitext-train-00001-of-00002.parquet"
|
||||
.to_string()
|
||||
},
|
||||
Siblings {
|
||||
rfilename: "wikitext-103-v1/wikitext-validation.parquet".to_string()
|
||||
},
|
||||
Siblings {
|
||||
rfilename: "wikitext-2-raw-v1/test/index.duckdb".to_string()
|
||||
},
|
||||
Siblings {
|
||||
rfilename: "wikitext-2-raw-v1/train/index.duckdb".to_string()
|
||||
},
|
||||
Siblings {
|
||||
rfilename: "wikitext-2-raw-v1/validation/index.duckdb".to_string()
|
||||
},
|
||||
Siblings {
|
||||
rfilename: "wikitext-2-raw-v1/wikitext-test.parquet".to_string()
|
||||
},
|
||||
Siblings {
|
||||
rfilename: "wikitext-2-raw-v1/wikitext-train.parquet".to_string()
|
||||
},
|
||||
Siblings {
|
||||
rfilename: "wikitext-2-raw-v1/wikitext-validation.parquet".to_string()
|
||||
},
|
||||
Siblings {
|
||||
rfilename: "wikitext-2-v1/wikitext-test.parquet".to_string()
|
||||
},
|
||||
Siblings {
|
||||
rfilename: "wikitext-2-v1/wikitext-train.parquet".to_string()
|
||||
},
|
||||
Siblings {
|
||||
rfilename: "wikitext-2-v1/wikitext-validation.parquet".to_string()
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
@@ -1,723 +0,0 @@
|
||||
use crate::{Cache, Repo};
|
||||
use indicatif::{ProgressBar, ProgressStyle};
|
||||
use rand::{distributions::Alphanumeric, thread_rng, Rng};
|
||||
use reqwest::{
|
||||
header::{
|
||||
HeaderMap, HeaderName, HeaderValue, InvalidHeaderValue, ToStrError, AUTHORIZATION,
|
||||
CONTENT_RANGE, LOCATION, RANGE, USER_AGENT,
|
||||
},
|
||||
redirect::Policy,
|
||||
Client, Error as ReqwestError,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use std::num::ParseIntError;
|
||||
use std::path::{Component, Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use thiserror::Error;
|
||||
use tokio::io::{AsyncSeekExt, AsyncWriteExt, SeekFrom};
|
||||
use tokio::sync::{AcquireError, Semaphore, TryAcquireError};
|
||||
|
||||
/// Current version (used in user-agent)
|
||||
const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
/// Current name (used in user-agent)
|
||||
const NAME: &str = env!("CARGO_PKG_NAME");
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
/// All errors the API can throw
|
||||
pub enum ApiError {
|
||||
/// Api expects certain header to be present in the results to derive some information
|
||||
#[error("Header {0} is missing")]
|
||||
MissingHeader(HeaderName),
|
||||
|
||||
/// The header exists, but the value is not conform to what the Api expects.
|
||||
#[error("Header {0} is invalid")]
|
||||
InvalidHeader(HeaderName),
|
||||
|
||||
/// The value cannot be used as a header during request header construction
|
||||
#[error("Invalid header value {0}")]
|
||||
InvalidHeaderValue(#[from] InvalidHeaderValue),
|
||||
|
||||
/// The header value is not valid utf-8
|
||||
#[error("header value is not a string")]
|
||||
ToStr(#[from] ToStrError),
|
||||
|
||||
/// Error in the request
|
||||
#[error("request error: {0}")]
|
||||
RequestError(#[from] ReqwestError),
|
||||
|
||||
/// Error parsing some range value
|
||||
#[error("Cannot parse int")]
|
||||
ParseIntError(#[from] ParseIntError),
|
||||
|
||||
/// I/O Error
|
||||
#[error("I/O error {0}")]
|
||||
IoError(#[from] std::io::Error),
|
||||
|
||||
/// We tried to download chunk too many times
|
||||
#[error("Too many retries: {0}")]
|
||||
TooManyRetries(Box<ApiError>),
|
||||
|
||||
/// Semaphore cannot be acquired
|
||||
#[error("Try acquire: {0}")]
|
||||
TryAcquireError(#[from] TryAcquireError),
|
||||
|
||||
/// Semaphore cannot be acquired
|
||||
#[error("Acquire: {0}")]
|
||||
AcquireError(#[from] AcquireError),
|
||||
// /// Semaphore cannot be acquired
|
||||
// #[error("Invalid Response: {0:?}")]
|
||||
// InvalidResponse(Response),
|
||||
}
|
||||
|
||||
/// Siblings are simplified file descriptions of remote files on the hub
|
||||
#[derive(Debug, Clone, Deserialize, PartialEq)]
|
||||
pub struct Siblings {
|
||||
/// The path within the repo.
|
||||
pub rfilename: String,
|
||||
}
|
||||
|
||||
/// The description of the repo given by the hub
|
||||
#[derive(Debug, Clone, Deserialize, PartialEq)]
|
||||
pub struct ModelInfo {
|
||||
/// See [`Siblings`]
|
||||
pub siblings: Vec<Siblings>,
|
||||
}
|
||||
|
||||
/// Helper to create [`Api`] with all the options.
|
||||
pub struct ApiBuilder {
|
||||
endpoint: String,
|
||||
cache: Cache,
|
||||
url_template: String,
|
||||
token: Option<String>,
|
||||
max_files: usize,
|
||||
chunk_size: usize,
|
||||
parallel_failures: usize,
|
||||
max_retries: usize,
|
||||
progress: bool,
|
||||
}
|
||||
|
||||
impl Default for ApiBuilder {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ApiBuilder {
|
||||
/// Default api builder
|
||||
/// ```
|
||||
/// use candle_hub::api::tokio::ApiBuilder;
|
||||
/// let api = ApiBuilder::new().build().unwrap();
|
||||
/// ```
|
||||
pub fn new() -> Self {
|
||||
let cache = Cache::default();
|
||||
let mut token_filename = cache.path().clone();
|
||||
token_filename.push(".token");
|
||||
let token = match std::fs::read_to_string(token_filename) {
|
||||
Ok(token_content) => {
|
||||
let token_content = token_content.trim();
|
||||
if !token_content.is_empty() {
|
||||
Some(token_content.to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
Err(_) => None,
|
||||
};
|
||||
|
||||
let progress = true;
|
||||
|
||||
Self {
|
||||
endpoint: "https://huggingface.co".to_string(),
|
||||
url_template: "{endpoint}/{repo_id}/resolve/{revision}/{filename}".to_string(),
|
||||
cache,
|
||||
token,
|
||||
max_files: num_cpus::get(),
|
||||
chunk_size: 10_000_000,
|
||||
parallel_failures: 0,
|
||||
max_retries: 0,
|
||||
progress,
|
||||
}
|
||||
}
|
||||
|
||||
/// Wether to show a progressbar
|
||||
pub fn with_progress(mut self, progress: bool) -> Self {
|
||||
self.progress = progress;
|
||||
self
|
||||
}
|
||||
|
||||
/// Changes the location of the cache directory. Defaults is `~/.cache/huggingface/`.
|
||||
pub fn with_cache_dir(mut self, cache_dir: PathBuf) -> Self {
|
||||
self.cache = Cache::new(cache_dir);
|
||||
self
|
||||
}
|
||||
|
||||
fn build_headers(&self) -> Result<HeaderMap, ApiError> {
|
||||
let mut headers = HeaderMap::new();
|
||||
let user_agent = format!("unkown/None; {NAME}/{VERSION}; rust/unknown");
|
||||
headers.insert(USER_AGENT, HeaderValue::from_str(&user_agent)?);
|
||||
if let Some(token) = &self.token {
|
||||
headers.insert(
|
||||
AUTHORIZATION,
|
||||
HeaderValue::from_str(&format!("Bearer {token}"))?,
|
||||
);
|
||||
}
|
||||
Ok(headers)
|
||||
}
|
||||
|
||||
/// Consumes the builder and buids the final [`Api`]
|
||||
pub fn build(self) -> Result<Api, ApiError> {
|
||||
let headers = self.build_headers()?;
|
||||
let client = Client::builder().default_headers(headers.clone()).build()?;
|
||||
let no_redirect_client = Client::builder()
|
||||
.redirect(Policy::none())
|
||||
.default_headers(headers)
|
||||
.build()?;
|
||||
Ok(Api {
|
||||
endpoint: self.endpoint,
|
||||
url_template: self.url_template,
|
||||
cache: self.cache,
|
||||
client,
|
||||
|
||||
no_redirect_client,
|
||||
max_files: self.max_files,
|
||||
chunk_size: self.chunk_size,
|
||||
parallel_failures: self.parallel_failures,
|
||||
max_retries: self.max_retries,
|
||||
progress: self.progress,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Metadata {
|
||||
commit_hash: String,
|
||||
etag: String,
|
||||
size: usize,
|
||||
}
|
||||
|
||||
/// The actual Api used to interacto with the hub.
|
||||
/// You can inspect repos with [`Api::info`]
|
||||
/// or download files with [`Api::download`]
|
||||
pub struct Api {
|
||||
endpoint: String,
|
||||
url_template: String,
|
||||
cache: Cache,
|
||||
client: Client,
|
||||
no_redirect_client: Client,
|
||||
max_files: usize,
|
||||
chunk_size: usize,
|
||||
parallel_failures: usize,
|
||||
max_retries: usize,
|
||||
progress: bool,
|
||||
}
|
||||
|
||||
fn temp_filename() -> PathBuf {
|
||||
let s: String = rand::thread_rng()
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(7)
|
||||
.map(char::from)
|
||||
.collect();
|
||||
let mut path = std::env::temp_dir();
|
||||
path.push(s);
|
||||
path
|
||||
}
|
||||
|
||||
fn make_relative(src: &Path, dst: &Path) -> PathBuf {
|
||||
let path = src;
|
||||
let base = dst;
|
||||
|
||||
if path.is_absolute() != base.is_absolute() {
|
||||
panic!("This function is made to look at absolute paths only");
|
||||
}
|
||||
let mut ita = path.components();
|
||||
let mut itb = base.components();
|
||||
|
||||
loop {
|
||||
match (ita.next(), itb.next()) {
|
||||
(Some(a), Some(b)) if a == b => (),
|
||||
(some_a, _) => {
|
||||
// Ignoring b, because 1 component is the filename
|
||||
// for which we don't need to go back up for relative
|
||||
// filename to work.
|
||||
let mut new_path = PathBuf::new();
|
||||
for _ in itb {
|
||||
new_path.push(Component::ParentDir);
|
||||
}
|
||||
if let Some(a) = some_a {
|
||||
new_path.push(a);
|
||||
for comp in ita {
|
||||
new_path.push(comp);
|
||||
}
|
||||
}
|
||||
return new_path;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn symlink_or_rename(src: &Path, dst: &Path) -> Result<(), std::io::Error> {
|
||||
if dst.exists() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let src = make_relative(src, dst);
|
||||
#[cfg(target_os = "windows")]
|
||||
std::os::windows::fs::symlink_file(src, dst)?;
|
||||
|
||||
#[cfg(target_family = "unix")]
|
||||
std::os::unix::fs::symlink(src, dst)?;
|
||||
|
||||
#[cfg(not(any(target_family = "unix", target_os = "windows")))]
|
||||
std::fs::rename(src, dst)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn jitter() -> usize {
|
||||
thread_rng().gen_range(0..=500)
|
||||
}
|
||||
|
||||
fn exponential_backoff(base_wait_time: usize, n: usize, max: usize) -> usize {
|
||||
(base_wait_time + n.pow(2) + jitter()).min(max)
|
||||
}

impl Api {
    /// Creates a default Api; for Api options see [`ApiBuilder`].
    pub fn new() -> Result<Self, ApiError> {
        ApiBuilder::new().build()
    }

    /// Get the fully qualified URL of the remote filename
    /// ```
    /// # use candle_hub::{api::tokio::Api, Repo};
    /// let api = Api::new().unwrap();
    /// let repo = Repo::model("gpt2".to_string());
    /// let url = api.url(&repo, "model.safetensors");
    /// assert_eq!(url, "https://huggingface.co/gpt2/resolve/main/model.safetensors");
    /// ```
    pub fn url(&self, repo: &Repo, filename: &str) -> String {
        let endpoint = &self.endpoint;
        let revision = &repo.url_revision();
        self.url_template
            .replace("{endpoint}", endpoint)
            .replace("{repo_id}", &repo.url())
            .replace("{revision}", revision)
            .replace("{filename}", filename)
    }

    /// Get the underlying api client
    /// Allows for lower level access
    pub fn client(&self) -> &Client {
        &self.client
    }

    async fn metadata(&self, url: &str) -> Result<Metadata, ApiError> {
        let response = self
            .no_redirect_client
            .get(url)
            .header(RANGE, "bytes=0-0")
            .send()
            .await?;
        let response = response.error_for_status()?;
        let headers = response.headers();
        let header_commit = HeaderName::from_static("x-repo-commit");
        let header_linked_etag = HeaderName::from_static("x-linked-etag");
        let header_etag = HeaderName::from_static("etag");

        let etag = match headers.get(&header_linked_etag) {
            Some(etag) => etag,
            None => headers
                .get(&header_etag)
                .ok_or(ApiError::MissingHeader(header_etag))?,
        };
        // Cleaning extra quotes
        let etag = etag.to_str()?.to_string().replace('"', "");
        let commit_hash = headers
            .get(&header_commit)
            .ok_or(ApiError::MissingHeader(header_commit))?
            .to_str()?
            .to_string();

        // The response was most likely redirected to S3, which will
        // know about the size of the file.
        let response = if response.status().is_redirection() {
            self.client
                .get(headers.get(LOCATION).unwrap().to_str()?.to_string())
                .header(RANGE, "bytes=0-0")
                .send()
                .await?
        } else {
            response
        };
        let headers = response.headers();
        let content_range = headers
            .get(CONTENT_RANGE)
            .ok_or(ApiError::MissingHeader(CONTENT_RANGE))?
            .to_str()?;

        let size = content_range
            .split('/')
            .last()
            .ok_or(ApiError::InvalidHeader(CONTENT_RANGE))?
            .parse()?;
        Ok(Metadata {
            commit_hash,
            etag,
            size,
        })
    }
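
    // Note: the single-byte range request above is a lightweight way to read the commit/etag
    // headers and the `Content-Range` header (e.g. "bytes 0-0/<total size>") without
    // downloading the file; the example header value is illustrative.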

    async fn download_tempfile(
        &self,
        url: &str,
        length: usize,
        progressbar: Option<ProgressBar>,
    ) -> Result<PathBuf, ApiError> {
        let mut handles = vec![];
        let semaphore = Arc::new(Semaphore::new(self.max_files));
        let parallel_failures_semaphore = Arc::new(Semaphore::new(self.parallel_failures));
        let filename = temp_filename();

        // Create the file and set everything properly
        tokio::fs::File::create(&filename)
            .await?
            .set_len(length as u64)
            .await?;

        let chunk_size = self.chunk_size;
        for start in (0..length).step_by(chunk_size) {
            let url = url.to_string();
            let filename = filename.clone();
            let client = self.client.clone();

            let stop = std::cmp::min(start + chunk_size - 1, length);
            let permit = semaphore.clone().acquire_owned().await?;
            let parallel_failures = self.parallel_failures;
            let max_retries = self.max_retries;
            let parallel_failures_semaphore = parallel_failures_semaphore.clone();
            let progress = progressbar.clone();
            handles.push(tokio::spawn(async move {
                let mut chunk = Self::download_chunk(&client, &url, &filename, start, stop).await;
                let mut i = 0;
                if parallel_failures > 0 {
                    while let Err(dlerr) = chunk {
                        let parallel_failure_permit =
                            parallel_failures_semaphore.clone().try_acquire_owned()?;

                        let wait_time = exponential_backoff(300, i, 10_000);
                        tokio::time::sleep(tokio::time::Duration::from_millis(wait_time as u64))
                            .await;

                        chunk = Self::download_chunk(&client, &url, &filename, start, stop).await;
                        i += 1;
                        if i > max_retries {
                            return Err(ApiError::TooManyRetries(dlerr.into()));
                        }
                        drop(parallel_failure_permit);
                    }
                }
                drop(permit);
                if let Some(p) = progress {
                    p.inc((stop - start) as u64);
                }
                chunk
            }));
        }

        // Output the chained result
        let results: Vec<Result<Result<(), ApiError>, tokio::task::JoinError>> =
            futures::future::join_all(handles).await;
        let results: Result<(), ApiError> = results.into_iter().flatten().collect();
        results?;
        if let Some(p) = progressbar {
            p.finish()
        }
        Ok(filename)
    }
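
    // Note: chunks are fetched concurrently into a preallocated temp file; `max_files` bounds
    // the number of in-flight requests and `parallel_failures` bounds how many chunks may be
    // retrying at once (with exponential backoff, up to `max_retries` attempts each).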

    async fn download_chunk(
        client: &reqwest::Client,
        url: &str,
        filename: &PathBuf,
        start: usize,
        stop: usize,
    ) -> Result<(), ApiError> {
        // Process each socket concurrently.
        let range = format!("bytes={start}-{stop}");
        let mut file = tokio::fs::OpenOptions::new()
            .write(true)
            .open(filename)
            .await?;
        file.seek(SeekFrom::Start(start as u64)).await?;
        let response = client
            .get(url)
            .header(RANGE, range)
            .send()
            .await?
            .error_for_status()?;
        let content = response.bytes().await?;
        file.write_all(&content).await?;
        Ok(())
    }
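
    // Note: each chunk task opens the file independently and seeks to its own offset, so
    // chunks can complete in any order without coordinating their writes.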

    /// This will attempt to fetch the file locally first, then [`Api.download`]
    /// if the file is not present.
    /// ```no_run
    /// # use candle_hub::{api::tokio::ApiBuilder, Repo};
    /// # tokio_test::block_on(async {
    /// let api = ApiBuilder::new().build().unwrap();
    /// let repo = Repo::model("gpt2".to_string());
    /// let local_filename = api.get(&repo, "model.safetensors").await.unwrap();
    /// # })
    /// ```
    pub async fn get(&self, repo: &Repo, filename: &str) -> Result<PathBuf, ApiError> {
        if let Some(path) = self.cache.get(repo, filename) {
            Ok(path)
        } else {
            self.download(repo, filename).await
        }
    }

    /// Downloads a remote file (if not already present) into the cache directory
    /// to be used locally.
    /// This function requires internet access to verify if new versions of the file
    /// exist, even if a file is already present on disk at that location.
    /// ```no_run
    /// # use candle_hub::{api::tokio::ApiBuilder, Repo};
    /// # tokio_test::block_on(async {
    /// let api = ApiBuilder::new().build().unwrap();
    /// let repo = Repo::model("gpt2".to_string());
    /// let local_filename = api.download(&repo, "model.safetensors").await.unwrap();
    /// # })
    /// ```
    pub async fn download(&self, repo: &Repo, filename: &str) -> Result<PathBuf, ApiError> {
        let url = self.url(repo, filename);
        let metadata = self.metadata(&url).await?;

        let blob_path = self.cache.blob_path(repo, &metadata.etag);
        std::fs::create_dir_all(blob_path.parent().unwrap())?;

        let progressbar = if self.progress {
            let progress = ProgressBar::new(metadata.size as u64);
            progress.set_style(
                ProgressStyle::with_template(
                    "{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec} ({eta})",
                )
                .unwrap(), // .progress_chars("━ "),
            );
            let maxlength = 30;
            let message = if filename.len() > maxlength {
                format!("..{}", &filename[filename.len() - maxlength..])
            } else {
                filename.to_string()
            };
            progress.set_message(message);
            Some(progress)
        } else {
            None
        };

        let tmp_filename = self
            .download_tempfile(&url, metadata.size, progressbar)
            .await?;

        if tokio::fs::rename(&tmp_filename, &blob_path).await.is_err() {
            // Renaming may fail if locations are different mount points
            std::fs::File::create(&blob_path)?;
            tokio::fs::copy(tmp_filename, &blob_path).await?;
        }

        let mut pointer_path = self.cache.pointer_path(repo, &metadata.commit_hash);
        pointer_path.push(filename);
        std::fs::create_dir_all(pointer_path.parent().unwrap()).ok();

        symlink_or_rename(&blob_path, &pointer_path)?;
        self.cache.create_ref(repo, &metadata.commit_hash)?;

        Ok(pointer_path)
    }

    /// Get information about the Repo
    /// ```
    /// # use candle_hub::{api::tokio::Api, Repo};
    /// # tokio_test::block_on(async {
    /// let api = Api::new().unwrap();
    /// let repo = Repo::model("gpt2".to_string());
    /// api.info(&repo);
    /// # })
    /// ```
    pub async fn info(&self, repo: &Repo) -> Result<ModelInfo, ApiError> {
        let url = format!("{}/api/{}", self.endpoint, repo.api_url());
        let response = self.client.get(url).send().await?;
        let response = response.error_for_status()?;

        let model_info = response.json().await?;

        Ok(model_info)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::RepoType;
    use rand::{distributions::Alphanumeric, Rng};
    use sha256::try_digest;

    struct TempDir {
        path: PathBuf,
    }

    impl TempDir {
        pub fn new() -> Self {
            let s: String = rand::thread_rng()
                .sample_iter(&Alphanumeric)
                .take(7)
                .map(char::from)
                .collect();
            let mut path = std::env::temp_dir();
            path.push(s);
            std::fs::create_dir(&path).unwrap();
            Self { path }
        }
    }

    impl Drop for TempDir {
        fn drop(&mut self) {
            std::fs::remove_dir_all(&self.path).unwrap()
        }
    }

    #[tokio::test]
    async fn simple() {
        let tmp = TempDir::new();
        let api = ApiBuilder::new()
            .with_progress(false)
            .with_cache_dir(tmp.path.clone())
            .build()
            .unwrap();
        let repo = Repo::new("julien-c/dummy-unknown".to_string(), RepoType::Model);
        let downloaded_path = api.download(&repo, "config.json").await.unwrap();
        assert!(downloaded_path.exists());
        let val = try_digest(&*downloaded_path).unwrap();
        assert_eq!(
            val,
            "b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32"
        );

        // Make sure the file is now seeable without connection
        let cache_path = api.cache.get(&repo, "config.json").unwrap();
        assert_eq!(cache_path, downloaded_path);
    }

    #[tokio::test]
    async fn dataset() {
        let tmp = TempDir::new();
        let api = ApiBuilder::new()
            .with_progress(false)
            .with_cache_dir(tmp.path.clone())
            .build()
            .unwrap();
        let repo = Repo::with_revision(
            "wikitext".to_string(),
            RepoType::Dataset,
            "refs/convert/parquet".to_string(),
        );
        let downloaded_path = api
            .download(&repo, "wikitext-103-v1/wikitext-test.parquet")
            .await
            .unwrap();
        assert!(downloaded_path.exists());
        let val = try_digest(&*downloaded_path).unwrap();
        assert_eq!(
            val,
            "59ce09415ad8aa45a9e34f88cec2548aeb9de9a73fcda9f6b33a86a065f32b90"
        )
    }

    #[tokio::test]
    async fn info() {
        let tmp = TempDir::new();
        let api = ApiBuilder::new()
            .with_progress(false)
            .with_cache_dir(tmp.path.clone())
            .build()
            .unwrap();
        let repo = Repo::with_revision(
            "wikitext".to_string(),
            RepoType::Dataset,
            "refs/convert/parquet".to_string(),
        );
        let model_info = api.info(&repo).await.unwrap();
        assert_eq!(
            model_info,
            ModelInfo {
                siblings: vec![
                    Siblings {
                        rfilename: ".gitattributes".to_string()
                    },
                    Siblings {
                        rfilename: "wikitext-103-raw-v1/wikitext-test.parquet".to_string()
                    },
                    Siblings {
                        rfilename: "wikitext-103-raw-v1/wikitext-train-00000-of-00002.parquet"
                            .to_string()
                    },
                    Siblings {
                        rfilename: "wikitext-103-raw-v1/wikitext-train-00001-of-00002.parquet"
                            .to_string()
                    },
                    Siblings {
                        rfilename: "wikitext-103-raw-v1/wikitext-validation.parquet".to_string()
                    },
                    Siblings {
                        rfilename: "wikitext-103-v1/test/index.duckdb".to_string()
                    },
                    Siblings {
                        rfilename: "wikitext-103-v1/validation/index.duckdb".to_string()
                    },
                    Siblings {
                        rfilename: "wikitext-103-v1/wikitext-test.parquet".to_string()
                    },
                    Siblings {
                        rfilename: "wikitext-103-v1/wikitext-train-00000-of-00002.parquet"
                            .to_string()
                    },
                    Siblings {
                        rfilename: "wikitext-103-v1/wikitext-train-00001-of-00002.parquet"
                            .to_string()
                    },
                    Siblings {
                        rfilename: "wikitext-103-v1/wikitext-validation.parquet".to_string()
                    },
                    Siblings {
                        rfilename: "wikitext-2-raw-v1/test/index.duckdb".to_string()
                    },
                    Siblings {
                        rfilename: "wikitext-2-raw-v1/train/index.duckdb".to_string()
                    },
                    Siblings {
                        rfilename: "wikitext-2-raw-v1/validation/index.duckdb".to_string()
                    },
                    Siblings {
                        rfilename: "wikitext-2-raw-v1/wikitext-test.parquet".to_string()
                    },
                    Siblings {
                        rfilename: "wikitext-2-raw-v1/wikitext-train.parquet".to_string()
                    },
                    Siblings {
                        rfilename: "wikitext-2-raw-v1/wikitext-validation.parquet".to_string()
                    },
                    Siblings {
                        rfilename: "wikitext-2-v1/wikitext-test.parquet".to_string()
                    },
                    Siblings {
                        rfilename: "wikitext-2-v1/wikitext-train.parquet".to_string()
                    },
                    Siblings {
                        rfilename: "wikitext-2-v1/wikitext-validation.parquet".to_string()
                    }
                ],
            }
        )
    }
}
@ -1,197 +0,0 @@
#![deny(missing_docs)]
//! This crate aims to emulate and be compatible with the
//! [huggingface_hub](https://github.com/huggingface/huggingface_hub/) python package.
//!
//! Compatible means the Api should reuse the same files, skipping downloads if
//! they are already present, and whenever this crate downloads or modifies this cache
//! it should be consistent with [huggingface_hub](https://github.com/huggingface/huggingface_hub/).
//!
//! At this time only a limited subset of the functionality is present; the goal is to add new
//! features over time.
use std::io::Write;
use std::path::PathBuf;

/// The actual Api to interact with the hub.
#[cfg(feature = "online")]
pub mod api;

/// The type of repo to interact with
#[derive(Debug, Clone, Copy)]
pub enum RepoType {
    /// This is a model, usually it consists of weight files and some configuration
    /// files
    Model,
    /// This is a dataset, usually contains data within parquet files
    Dataset,
    /// This is a space, usually a demo showcasing a given model or dataset
    Space,
}

/// A local struct used to fetch information from the cache folder.
pub struct Cache {
    path: PathBuf,
}

impl Cache {
    /// Creates a new cache object location
    pub fn new(path: PathBuf) -> Self {
        Self { path }
    }

    /// Returns the location of the cache folder
    pub fn path(&self) -> &PathBuf {
        &self.path
    }

    /// This will get the location of the file within the cache for the remote
    /// `filename`. Will return `None` if file is not already present in cache.
    pub fn get(&self, repo: &Repo, filename: &str) -> Option<PathBuf> {
        let mut commit_path = self.path.clone();
        commit_path.push(repo.folder_name());
        commit_path.push("refs");
        commit_path.push(repo.revision());
        let commit_hash = std::fs::read_to_string(commit_path).ok()?;
        let mut pointer_path = self.pointer_path(repo, &commit_hash);
        pointer_path.push(filename);
        if pointer_path.exists() {
            Some(pointer_path)
        } else {
            None
        }
    }

    /// Creates a reference in the cache directory that points branches to the correct
    /// commits within the blobs.
    pub fn create_ref(&self, repo: &Repo, commit_hash: &str) -> Result<(), std::io::Error> {
        let mut ref_path = self.path.clone();
        ref_path.push(repo.folder_name());
        ref_path.push("refs");
        ref_path.push(repo.revision());
        // Needs to be done like this because revision might contain `/` creating subfolders here.
        std::fs::create_dir_all(ref_path.parent().unwrap())?;
        let mut file1 = std::fs::OpenOptions::new()
            .write(true)
            .create(true)
            .open(&ref_path)?;
        file1.write_all(commit_hash.trim().as_bytes())?;
        Ok(())
    }

    #[cfg(feature = "online")]
    pub(crate) fn blob_path(&self, repo: &Repo, etag: &str) -> PathBuf {
        let mut blob_path = self.path.clone();
        blob_path.push(repo.folder_name());
        blob_path.push("blobs");
        blob_path.push(etag);
        blob_path
    }

    pub(crate) fn pointer_path(&self, repo: &Repo, commit_hash: &str) -> PathBuf {
        let mut pointer_path = self.path.clone();
        pointer_path.push(repo.folder_name());
        pointer_path.push("snapshots");
        pointer_path.push(commit_hash);
        pointer_path
    }
}
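
// Illustrative layout (hypothetical values): for Repo::model("gpt2") at commit <hash> with
// blob etag <etag>, the cache directory holds
//   <cache>/models--gpt2/refs/main                    -> contains "<hash>"
//   <cache>/models--gpt2/blobs/<etag>                 -> the actual file contents
//   <cache>/models--gpt2/snapshots/<hash>/<filename>  -> symlink pointing at the blob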

impl Default for Cache {
    fn default() -> Self {
        let path = match std::env::var("HF_HOME") {
            Ok(home) => home.into(),
            Err(_) => {
                let mut cache = dirs::home_dir().expect("Cache directory cannot be found");
                cache.push(".cache");
                cache.push("huggingface");
                cache.push("hub");
                cache
            }
        };
        Self::new(path)
    }
}

/// The representation of a repo on the hub.
#[allow(dead_code)] // Repo type unused in offline mode
pub struct Repo {
    repo_id: String,
    repo_type: RepoType,
    revision: String,
}

impl Repo {
    /// Repo with the default branch ("main").
    pub fn new(repo_id: String, repo_type: RepoType) -> Self {
        Self::with_revision(repo_id, repo_type, "main".to_string())
    }

    /// Fully qualified Repo
    pub fn with_revision(repo_id: String, repo_type: RepoType, revision: String) -> Self {
        Self {
            repo_id,
            repo_type,
            revision,
        }
    }

    /// Shortcut for [`Repo::new`] with [`RepoType::Model`]
    pub fn model(repo_id: String) -> Self {
        Self::new(repo_id, RepoType::Model)
    }

    /// Shortcut for [`Repo::new`] with [`RepoType::Dataset`]
    pub fn dataset(repo_id: String) -> Self {
        Self::new(repo_id, RepoType::Dataset)
    }

    /// Shortcut for [`Repo::new`] with [`RepoType::Space`]
    pub fn space(repo_id: String) -> Self {
        Self::new(repo_id, RepoType::Space)
    }

    /// The normalized folder name of the repo within the cache directory
    pub fn folder_name(&self) -> String {
        let prefix = match self.repo_type {
            RepoType::Model => "models",
            RepoType::Dataset => "datasets",
            RepoType::Space => "spaces",
        };
        format!("{prefix}--{}", self.repo_id).replace('/', "--")
    }
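
    // Illustrative example: Repo::model("julien-c/dummy-unknown".to_string()).folder_name()
    // yields "models--julien-c--dummy-unknown".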

    /// The revision
    pub fn revision(&self) -> &str {
        &self.revision
    }

    /// The actual URL part of the repo
    #[cfg(feature = "online")]
    pub fn url(&self) -> String {
        match self.repo_type {
            RepoType::Model => self.repo_id.to_string(),
            RepoType::Dataset => {
                format!("datasets/{}", self.repo_id)
            }
            RepoType::Space => {
                format!("spaces/{}", self.repo_id)
            }
        }
    }

    /// Revision needs to be url escaped before being used in a URL
    #[cfg(feature = "online")]
    pub fn url_revision(&self) -> String {
        self.revision.replace('/', "%2F")
    }

    /// Used to compute the repo's url part when accessing the metadata of the repo
    #[cfg(feature = "online")]
    pub fn api_url(&self) -> String {
        let prefix = match self.repo_type {
            RepoType::Model => "models",
            RepoType::Dataset => "datasets",
            RepoType::Space => "spaces",
        };
        format!("{prefix}/{}/revision/{}", self.repo_id, self.url_revision())
    }
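
    // Illustrative example: a dataset repo "wikitext" at revision "refs/convert/parquet"
    // gives api_url() == "datasets/wikitext/revision/refs%2Fconvert%2Fparquet".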
}
@ -12,7 +12,7 @@ readme = "README.md"

[dependencies]
candle = { path = "../candle-core" }
candle-hub = { path = "../candle-hub" }
hf-hub = { workspace = true}
candle-nn = { path = "../candle-nn" }
intel-mkl-src = { workspace = true, optional = true, features = ["mkl-dynamic-lp64-iomp"]}
tokenizers = { workspace = true, features = ["onig"] }