mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Lint fixes introduced with Rust 1.83 (#2646)
* Fixes for lint errors introduced with Rust 1.83 * rustfmt * Fix more lints. --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:

committed by
GitHub

parent
23ed8a9ded
commit
54e7fc3c97
@ -66,7 +66,7 @@ impl Map2U8 for Cmp {
|
||||
|
||||
struct WCond<'a, T: IntDType>(&'a [T], &'a Layout);
|
||||
|
||||
impl<'a, I: IntDType> Map2 for WCond<'a, I> {
|
||||
impl<I: IntDType> Map2 for WCond<'_, I> {
|
||||
const OP: &'static str = "where";
|
||||
#[inline(always)]
|
||||
fn f<T: WithDType>(&self, t: &[T], t_l: &Layout, f: &[T], f_l: &Layout) -> Result<Vec<T>> {
|
||||
@ -216,7 +216,7 @@ struct ReduceSum<'a> {
|
||||
reduce_dims_and_stride: Vec<(usize, usize)>,
|
||||
}
|
||||
|
||||
impl<'a> ReduceSum<'a> {
|
||||
impl ReduceSum<'_> {
|
||||
#[inline(always)]
|
||||
fn fold_impl<T>(&self, src: &[T], src_l: &Layout, start_elt: T) -> Result<Vec<T>>
|
||||
where
|
||||
@ -281,7 +281,7 @@ impl<'a> ReduceSum<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Map1 for ReduceSum<'a> {
|
||||
impl Map1 for ReduceSum<'_> {
|
||||
#[inline(always)]
|
||||
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
|
||||
self.fold_impl(src, src_l, T::zero())
|
||||
@ -454,7 +454,7 @@ struct Gather<'a, I: IntDType> {
|
||||
dim: usize,
|
||||
}
|
||||
|
||||
impl<'a, I: IntDType> Map1 for Gather<'a, I> {
|
||||
impl<I: IntDType> Map1 for Gather<'_, I> {
|
||||
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
|
||||
let ids = match self.ids_l.contiguous_offsets() {
|
||||
Some((a, b)) => &self.ids[a..b],
|
||||
@ -507,7 +507,7 @@ struct IndexSelect<'a, T: IntDType> {
|
||||
dim: usize,
|
||||
}
|
||||
|
||||
impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> {
|
||||
impl<I: IntDType> Map1 for IndexSelect<'_, I> {
|
||||
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
|
||||
let src = match layout.contiguous_offsets() {
|
||||
Some((a, b)) => &src[a..b],
|
||||
@ -560,7 +560,7 @@ struct ScatterAdd<'a, I: IntDType> {
|
||||
dim: usize,
|
||||
}
|
||||
|
||||
impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
|
||||
impl<I: IntDType> Map2 for ScatterAdd<'_, I> {
|
||||
const OP: &'static str = "scatter-add";
|
||||
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
|
||||
let dst_len = l1.shape().elem_count();
|
||||
@ -616,7 +616,7 @@ struct IndexAdd<'a, I: IntDType> {
|
||||
dim: usize,
|
||||
}
|
||||
|
||||
impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
|
||||
impl<I: IntDType> Map2 for IndexAdd<'_, I> {
|
||||
const OP: &'static str = "index-add";
|
||||
// https://pytorch.org/docs/stable/generated/torch.Tensor.index_add_.html#torch.Tensor.index_add_
|
||||
// v1, l1 -> self
|
||||
@ -736,7 +736,7 @@ fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l
|
||||
|
||||
struct Conv1D<'a>(&'a crate::conv::ParamsConv1D);
|
||||
|
||||
impl<'a> Map2 for Conv1D<'a> {
|
||||
impl Map2 for Conv1D<'_> {
|
||||
const OP: &'static str = "conv1d";
|
||||
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
|
||||
let p = self.0;
|
||||
@ -960,7 +960,7 @@ impl Map1 for Col2Im1D {
|
||||
|
||||
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
|
||||
|
||||
impl<'a> Map2 for ConvTranspose1D<'a> {
|
||||
impl Map2 for ConvTranspose1D<'_> {
|
||||
const OP: &'static str = "conv_transpose1d";
|
||||
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
|
||||
let p = self.0;
|
||||
@ -1029,7 +1029,7 @@ impl<'a> Map2 for ConvTranspose1D<'a> {
|
||||
|
||||
struct Conv2D<'a>(&'a crate::conv::ParamsConv2D);
|
||||
|
||||
impl<'a> Map2 for Conv2D<'a> {
|
||||
impl Map2 for Conv2D<'_> {
|
||||
const OP: &'static str = "conv2d";
|
||||
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
|
||||
let p = self.0;
|
||||
@ -1117,7 +1117,7 @@ impl<'a> Map2 for Conv2D<'a> {
|
||||
|
||||
struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);
|
||||
|
||||
impl<'a> Map2 for ConvTranspose2D<'a> {
|
||||
impl Map2 for ConvTranspose2D<'_> {
|
||||
const OP: &'static str = "conv_transpose2d";
|
||||
fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
|
||||
let p = self.0;
|
||||
|
@ -457,7 +457,7 @@ impl Content {
|
||||
Some(Value::I32(v)) if *v >= 0 => *v as u64,
|
||||
_ => DEFAULT_ALIGNMENT,
|
||||
};
|
||||
let tensor_data_offset = (position + alignment - 1) / alignment * alignment;
|
||||
let tensor_data_offset = position.div_ceil(alignment) * alignment;
|
||||
Ok(Self {
|
||||
magic,
|
||||
metadata,
|
||||
|
@ -1850,8 +1850,8 @@ pub fn matmul<T: GgmlType>(
|
||||
crate::bail!("unexpected lhs length {} {mkn:?}", lhs.len());
|
||||
}
|
||||
|
||||
let k_in_lhs_blocks = (k + T::BLCK_SIZE - 1) / T::BLCK_SIZE;
|
||||
let k_in_rhs_blocks = (k + T::VecDotType::BLCK_SIZE - 1) / T::VecDotType::BLCK_SIZE;
|
||||
let k_in_lhs_blocks = k.div_ceil(T::BLCK_SIZE);
|
||||
let k_in_rhs_blocks = k.div_ceil(T::VecDotType::BLCK_SIZE);
|
||||
// TODO: Do not make this copy if the DotType is f32.
|
||||
// TODO: Pre-allocate this.
|
||||
let mut lhs_b = vec![T::VecDotType::zeros(); m * k_in_lhs_blocks];
|
||||
|
@ -182,7 +182,7 @@ pub trait Load {
|
||||
fn load(&self, device: &Device) -> Result<Tensor>;
|
||||
}
|
||||
|
||||
impl<'a> Load for st::TensorView<'a> {
|
||||
impl Load for st::TensorView<'_> {
|
||||
fn load(&self, device: &Device) -> Result<Tensor> {
|
||||
convert(self, device)
|
||||
}
|
||||
|
@ -32,7 +32,7 @@ impl<'a> StridedIndex<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Iterator for StridedIndex<'a> {
|
||||
impl Iterator for StridedIndex<'_> {
|
||||
type Item = usize;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
|
@ -87,7 +87,7 @@ impl<'a> DatasetRandomIter<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Iterator for DatasetRandomIter<'a> {
|
||||
impl Iterator for DatasetRandomIter<'_> {
|
||||
type Item = Result<(Tensor, Tensor)>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
|
@ -17,7 +17,7 @@ pub struct Config {
|
||||
impl Config {
|
||||
fn vocab_size(&self) -> usize {
|
||||
let pad = self.pad_vocab_size_multiple;
|
||||
(self.vocab_size + pad - 1) / pad * pad
|
||||
self.vocab_size.div_ceil(pad) * pad
|
||||
}
|
||||
|
||||
fn dt_rank(&self) -> usize {
|
||||
|
@ -6,7 +6,6 @@ pub const IMAGENET_STD: [f32; 3] = [0.229f32, 0.224, 0.225];
|
||||
/// Loads an image from disk using the image crate at the requested resolution,
|
||||
/// using the given std and mean parameters.
|
||||
/// This returns a tensor with shape (3, res, res). imagenet normalization is applied.
|
||||
|
||||
pub fn load_image_with_std_mean<P: AsRef<std::path::Path>>(
|
||||
p: P,
|
||||
res: usize,
|
||||
|
@ -372,7 +372,7 @@ pub fn call_unary_contiguous_tiled(
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
let tile_size = 2;
|
||||
let tiles = (length + tile_size - 1) / tile_size;
|
||||
let tiles = length.div_ceil(tile_size);
|
||||
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
@ -594,7 +594,7 @@ pub fn call_reduce_contiguous(
|
||||
|
||||
let width = std::cmp::min(
|
||||
pipeline.max_total_threads_per_threadgroup(),
|
||||
(elements_to_sum as u64 + 2 - 1) / 2,
|
||||
(elements_to_sum as u64).div_ceil(2),
|
||||
)
|
||||
.next_power_of_two();
|
||||
|
||||
@ -1735,7 +1735,7 @@ pub fn call_sdpa_full(
|
||||
}
|
||||
};
|
||||
|
||||
let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name)?;
|
||||
let pipeline = kernels.load_pipeline(device, Source::Sdpa, name)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
@ -1759,16 +1759,16 @@ pub fn call_sdpa_full(
|
||||
let ldo = dk;
|
||||
|
||||
let tn = 1;
|
||||
let tm = (m + BM - 1) / BM;
|
||||
let tm = m.div_ceil(BM);
|
||||
|
||||
let b_stride_q = dk * qseq;
|
||||
let b_stride_k = dk * qseq;
|
||||
let b_stride_v = dk * qseq;
|
||||
let b_stride_o = dk * qseq;
|
||||
let swizzle_log = 0;
|
||||
let gemm_n_iterations_aligned = (n + BN - 1) / BN;
|
||||
let gemm_k_iterations_aligned = (k + bk - 1) / bk;
|
||||
let gemm_sv_m_block_iterations = (m + BM - 1) / BM;
|
||||
let gemm_n_iterations_aligned = n.div_ceil(BN);
|
||||
let gemm_k_iterations_aligned = k.div_ceil(*bk);
|
||||
let gemm_sv_m_block_iterations = m.div_ceil(BM);
|
||||
let batch_ndim = batch_shape.len();
|
||||
|
||||
let alpha = if softcapping != 1. {
|
||||
@ -1906,7 +1906,7 @@ pub fn call_sdpa_vector(
|
||||
alpha
|
||||
};
|
||||
|
||||
let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name)?;
|
||||
let pipeline = kernels.load_pipeline(device, Source::Sdpa, name)?;
|
||||
let encoder = ep.encoder();
|
||||
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
@ -1933,7 +1933,7 @@ pub fn call_sdpa_vector(
|
||||
let grid_dims = MTLSize {
|
||||
width: 1,
|
||||
height: b as u64,
|
||||
depth: 1 as u64,
|
||||
depth: 1_u64,
|
||||
};
|
||||
let group_dims = MTLSize {
|
||||
width: 1024,
|
||||
@ -2320,7 +2320,7 @@ pub fn call_quantized_matmul_mv_t(
|
||||
}
|
||||
|
||||
fn divide(m: usize, b: usize) -> NSUInteger {
|
||||
((m + b - 1) / b) as NSUInteger
|
||||
m.div_ceil(b) as NSUInteger
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
|
@ -8,7 +8,7 @@ use std::ffi::c_void;
|
||||
pub(crate) fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
|
||||
let size = length as u64;
|
||||
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size);
|
||||
let count = (size + width - 1) / width;
|
||||
let count = size.div_ceil(width);
|
||||
let thread_group_count = MTLSize {
|
||||
width: count,
|
||||
height: 1,
|
||||
@ -128,7 +128,7 @@ impl EncoderParam for (&Buffer, usize) {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> EncoderParam for &BufferOffset<'a> {
|
||||
impl EncoderParam for &BufferOffset<'_> {
|
||||
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
||||
encoder.set_buffer(position, Some(data.buffer), data.offset_in_bytes as u64);
|
||||
}
|
||||
@ -169,7 +169,7 @@ pub struct WrappedEncoder<'a> {
|
||||
end_encoding_on_drop: bool,
|
||||
}
|
||||
|
||||
impl<'a> Drop for WrappedEncoder<'a> {
|
||||
impl Drop for WrappedEncoder<'_> {
|
||||
fn drop(&mut self) {
|
||||
if self.end_encoding_on_drop {
|
||||
self.inner.end_encoding()
|
||||
@ -177,14 +177,15 @@ impl<'a> Drop for WrappedEncoder<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> AsRef<metal::ComputeCommandEncoderRef> for WrappedEncoder<'a> {
|
||||
impl AsRef<metal::ComputeCommandEncoderRef> for WrappedEncoder<'_> {
|
||||
fn as_ref(&self) -> &metal::ComputeCommandEncoderRef {
|
||||
self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl EncoderProvider for &metal::CommandBuffer {
|
||||
type Encoder<'a> = WrappedEncoder<'a>
|
||||
type Encoder<'a>
|
||||
= WrappedEncoder<'a>
|
||||
where
|
||||
Self: 'a;
|
||||
fn encoder(&self) -> Self::Encoder<'_> {
|
||||
@ -196,7 +197,8 @@ impl EncoderProvider for &metal::CommandBuffer {
|
||||
}
|
||||
|
||||
impl EncoderProvider for &metal::CommandBufferRef {
|
||||
type Encoder<'a> = WrappedEncoder<'a>
|
||||
type Encoder<'a>
|
||||
= WrappedEncoder<'a>
|
||||
where
|
||||
Self: 'a;
|
||||
fn encoder(&self) -> Self::Encoder<'_> {
|
||||
@ -208,7 +210,8 @@ impl EncoderProvider for &metal::CommandBufferRef {
|
||||
}
|
||||
|
||||
impl EncoderProvider for &ComputeCommandEncoderRef {
|
||||
type Encoder<'a> = WrappedEncoder<'a>
|
||||
type Encoder<'a>
|
||||
= WrappedEncoder<'a>
|
||||
where
|
||||
Self: 'a;
|
||||
fn encoder(&self) -> Self::Encoder<'_> {
|
||||
|
@ -9,7 +9,7 @@ pub struct Func<'a> {
|
||||
f: Arc<dyn 'a + Fn(&Tensor) -> Result<Tensor> + Send + Sync>,
|
||||
}
|
||||
|
||||
impl<'a> std::fmt::Debug for Func<'a> {
|
||||
impl std::fmt::Debug for Func<'_> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
write!(f, "func")
|
||||
}
|
||||
@ -22,7 +22,7 @@ where
|
||||
Func { f: Arc::new(f) }
|
||||
}
|
||||
|
||||
impl<'a> super::Module for Func<'a> {
|
||||
impl super::Module for Func<'_> {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
(*self.f)(xs)
|
||||
}
|
||||
@ -44,7 +44,7 @@ pub struct FuncT<'a> {
|
||||
f: Arc<dyn 'a + Fn(&Tensor, bool) -> Result<Tensor> + Send + Sync>,
|
||||
}
|
||||
|
||||
impl<'a> std::fmt::Debug for FuncT<'a> {
|
||||
impl std::fmt::Debug for FuncT<'_> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
write!(f, "func")
|
||||
}
|
||||
@ -57,7 +57,7 @@ where
|
||||
FuncT { f: Arc::new(f) }
|
||||
}
|
||||
|
||||
impl<'a> super::ModuleT for FuncT<'a> {
|
||||
impl super::ModuleT for FuncT<'_> {
|
||||
fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
|
||||
(*self.f)(xs, train)
|
||||
}
|
||||
|
@ -20,7 +20,7 @@ pub struct VarBuilderArgs<'a, B: Backend> {
|
||||
_phantom: std::marker::PhantomData<&'a B>,
|
||||
}
|
||||
|
||||
impl<'a, B: Backend> Clone for VarBuilderArgs<'a, B> {
|
||||
impl<B: Backend> Clone for VarBuilderArgs<'_, B> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
data: self.data.clone(),
|
||||
@ -76,7 +76,7 @@ pub trait SimpleBackend: Send + Sync {
|
||||
fn contains_tensor(&self, name: &str) -> bool;
|
||||
}
|
||||
|
||||
impl<'a> Backend for Box<dyn SimpleBackend + 'a> {
|
||||
impl Backend for Box<dyn SimpleBackend + '_> {
|
||||
type Hints = crate::Init;
|
||||
fn get(
|
||||
&self,
|
||||
@ -94,7 +94,7 @@ impl<'a> Backend for Box<dyn SimpleBackend + 'a> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, B: Backend> VarBuilderArgs<'a, B> {
|
||||
impl<B: Backend> VarBuilderArgs<'_, B> {
|
||||
pub fn new_with_args(backend: B, dtype: DType, dev: &Device) -> Self {
|
||||
let data = TensorData {
|
||||
backend,
|
||||
@ -286,7 +286,7 @@ pub struct SafeTensorWithRouting<'a> {
|
||||
safetensors: Vec<SafeTensors<'a>>,
|
||||
}
|
||||
|
||||
impl<'a> SimpleBackend for SafeTensorWithRouting<'a> {
|
||||
impl SimpleBackend for SafeTensorWithRouting<'_> {
|
||||
fn get(
|
||||
&self,
|
||||
s: Shape,
|
||||
@ -439,7 +439,7 @@ impl SimpleBackend for candle::safetensors::BufferedSafetensors {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> SimpleBackend for candle::safetensors::SliceSafetensors<'a> {
|
||||
impl SimpleBackend for candle::safetensors::SliceSafetensors<'_> {
|
||||
fn get(
|
||||
&self,
|
||||
s: Shape,
|
||||
@ -732,7 +732,7 @@ pub struct Rename<'a, R: Renamer> {
|
||||
renamer: R,
|
||||
}
|
||||
|
||||
impl<'a, R: Renamer + Sync + Send> SimpleBackend for Rename<'a, R> {
|
||||
impl<R: Renamer + Sync + Send> SimpleBackend for Rename<'_, R> {
|
||||
fn get(
|
||||
&self,
|
||||
s: Shape,
|
||||
|
@ -276,7 +276,7 @@ impl PyTensor {
|
||||
/// &RETURNS&: _ArrayLike
|
||||
fn values(&self, py: Python<'_>) -> PyResult<PyObject> {
|
||||
struct M<'a>(Python<'a>);
|
||||
impl<'a> MapDType for M<'a> {
|
||||
impl MapDType for M<'_> {
|
||||
type Output = PyObject;
|
||||
fn f<T: PyWithDType>(&self, t: &Tensor) -> PyResult<Self::Output> {
|
||||
match t.rank() {
|
||||
|
@ -21,8 +21,8 @@ fn conv2d_same(
|
||||
let module = candle_nn::func(move |xs| {
|
||||
let ih = xs.dim(2)?;
|
||||
let iw = xs.dim(3)?;
|
||||
let oh = (ih + s - 1) / s;
|
||||
let ow = (iw + s - 1) / s;
|
||||
let oh = ih.div_ceil(s);
|
||||
let ow = iw.div_ceil(s);
|
||||
let pad_h = usize::max((oh - 1) * s + k - ih, 0);
|
||||
let pad_w = usize::max((ow - 1) * s + k - iw, 0);
|
||||
if pad_h > 0 || pad_w > 0 {
|
||||
|
@ -543,7 +543,7 @@ impl<'a> DepthAnythingV2<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Module for DepthAnythingV2<'a> {
|
||||
impl Module for DepthAnythingV2<'_> {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let features = self.pretrained.get_intermediate_layers(
|
||||
xs,
|
||||
|
@ -125,8 +125,8 @@ impl Module for Conv2DSame {
|
||||
let s = self.s;
|
||||
let k = self.k;
|
||||
let (_, _, ih, iw) = xs.dims4()?;
|
||||
let oh = (ih + s - 1) / s;
|
||||
let ow = (iw + s - 1) / s;
|
||||
let oh = ih.div_ceil(s);
|
||||
let ow = iw.div_ceil(s);
|
||||
let pad_h = usize::max((oh - 1) * s + k - ih, 0);
|
||||
let pad_w = usize::max((ow - 1) * s + k - iw, 0);
|
||||
if pad_h > 0 || pad_w > 0 {
|
||||
|
@ -89,7 +89,7 @@ impl Config {
|
||||
|
||||
fn frame_rate(&self) -> usize {
|
||||
let hop_length: usize = self.upsampling_ratios.iter().product();
|
||||
(self.sampling_rate + hop_length - 1) / hop_length
|
||||
self.sampling_rate.div_ceil(hop_length)
|
||||
}
|
||||
|
||||
fn num_quantizers(&self) -> usize {
|
||||
|
@ -23,7 +23,7 @@ pub struct Config {
|
||||
impl Config {
|
||||
fn vocab_size(&self) -> usize {
|
||||
let pad = self.pad_vocab_size_multiple;
|
||||
(self.vocab_size + pad - 1) / pad * pad
|
||||
self.vocab_size.div_ceil(pad) * pad
|
||||
}
|
||||
|
||||
fn dt_rank(&self) -> usize {
|
||||
|
@ -21,7 +21,7 @@ struct LinearInterpolator<'x, 'y> {
|
||||
cache: usize,
|
||||
}
|
||||
|
||||
impl<'x, 'y> LinearInterpolator<'x, 'y> {
|
||||
impl LinearInterpolator<'_, '_> {
|
||||
fn accel_find(&mut self, x: f64) -> usize {
|
||||
let xidx = self.cache;
|
||||
if x < self.xp[xidx] {
|
||||
|
Reference in New Issue
Block a user