diff --git a/candle-core/src/pickle.rs b/candle-core/src/pickle.rs index 8b13b50b..2ca0daaf 100644 --- a/candle-core/src/pickle.rs +++ b/candle-core/src/pickle.rs @@ -816,7 +816,7 @@ impl PthTensors { /// # Arguments /// * `path` - Path to the pth file. /// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file -/// contains multiple objects and the state_dict is the one we are interested in. +/// contains multiple objects and the state_dict is the one we are interested in. pub fn read_all_with_key>( path: P, key: Option<&str>, diff --git a/candle-examples/examples/mamba-minimal/model.rs b/candle-examples/examples/mamba-minimal/model.rs index 7ebea76a..56563086 100644 --- a/candle-examples/examples/mamba-minimal/model.rs +++ b/candle-examples/examples/mamba-minimal/model.rs @@ -21,7 +21,7 @@ impl Config { } fn dt_rank(&self) -> usize { - (self.d_model + 15) / 16 + self.d_model.div_ceil(16) } fn d_conv(&self) -> usize { diff --git a/candle-nn/src/loss.rs b/candle-nn/src/loss.rs index 03e8524d..7fc349fa 100644 --- a/candle-nn/src/loss.rs +++ b/candle-nn/src/loss.rs @@ -7,7 +7,7 @@ use candle::{Result, Tensor}; /// Arguments /// /// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number -/// of categories. This is expected to contain log probabilities. +/// of categories. This is expected to contain log probabilities. /// * [target]: The ground truth labels as a tensor of u32 of dimension `N`. /// /// The resulting tensor is a scalar containing the average value over the batch. @@ -34,7 +34,7 @@ pub fn nll(inp: &Tensor, target: &Tensor) -> Result { /// Arguments /// /// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number -/// of categories. This is expected to raw logits. +/// of categories. This is expected to raw logits. /// * [target]: The ground truth labels as a tensor of u32 of dimension `N`. /// /// The resulting tensor is a scalar containing the average value over the batch. @@ -56,9 +56,9 @@ pub fn mse(inp: &Tensor, target: &Tensor) -> Result { /// Arguments /// /// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number -/// of categories. This is expected to raw logits. +/// of categories. This is expected to raw logits. /// * [target]: The ground truth labels as a tensor of u32 of dimension `N, C` where `N` is the batch size and `C` the number -/// of categories. +/// of categories. /// /// The resulting tensor is a scalar containing the average value over the batch. pub fn binary_cross_entropy_with_logit(inp: &Tensor, target: &Tensor) -> Result { diff --git a/candle-transformers/src/models/dac.rs b/candle-transformers/src/models/dac.rs index 78728b4d..d8465567 100644 --- a/candle-transformers/src/models/dac.rs +++ b/candle-transformers/src/models/dac.rs @@ -104,7 +104,7 @@ impl EncoderBlock { let snake1 = Snake1d::new(dim / 2, vb.pp(3))?; let cfg1 = Conv1dConfig { stride, - padding: (stride + 1) / 2, + padding: stride.div_ceil(2), ..Default::default() }; let conv1 = encodec::conv1d_weight_norm(dim / 2, dim, 2 * stride, cfg1, vb.pp(4))?; @@ -196,7 +196,7 @@ impl DecoderBlock { let snake1 = Snake1d::new(in_dim, vb.pp(0))?; let cfg = ConvTranspose1dConfig { stride, - padding: (stride + 1) / 2, + padding: stride.div_ceil(2), ..Default::default() }; let conv_tr1 = encodec::conv_transpose1d_weight_norm( diff --git a/candle-transformers/src/models/flux/sampling.rs b/candle-transformers/src/models/flux/sampling.rs index f3f0eafd..cdfef043 100644 --- a/candle-transformers/src/models/flux/sampling.rs +++ b/candle-transformers/src/models/flux/sampling.rs @@ -6,8 +6,8 @@ pub fn get_noise( width: usize, device: &Device, ) -> Result { - let height = (height + 15) / 16 * 2; - let width = (width + 15) / 16 * 2; + let height = height.div_ceil(16) * 2; + let width = width.div_ceil(16) * 2; Tensor::randn(0f32, 1., (num_samples, 16, height, width), device) } @@ -84,8 +84,8 @@ pub fn get_schedule(num_steps: usize, shift: Option<(usize, f64, f64)>) -> Vec Result { let (b, _h_w, c_ph_pw) = xs.dims3()?; - let height = (height + 15) / 16; - let width = (width + 15) / 16; + let height = height.div_ceil(16); + let width = width.div_ceil(16); xs.reshape((b, height, width, c_ph_pw / 4, 2, 2))? // (b, h, w, c, ph, pw) .permute((0, 3, 1, 4, 2, 5))? // (b, c, h, ph, w, pw) .reshape((b, c_ph_pw / 4, height * 2, width * 2)) diff --git a/candle-transformers/src/models/mamba.rs b/candle-transformers/src/models/mamba.rs index a29f2619..dfae0af3 100644 --- a/candle-transformers/src/models/mamba.rs +++ b/candle-transformers/src/models/mamba.rs @@ -27,7 +27,7 @@ impl Config { } fn dt_rank(&self) -> usize { - (self.d_model + 15) / 16 + self.d_model.div_ceil(16) } fn d_inner(&self) -> usize { diff --git a/candle-transformers/src/models/metavoice.rs b/candle-transformers/src/models/metavoice.rs index 92d3ffba..66896388 100644 --- a/candle-transformers/src/models/metavoice.rs +++ b/candle-transformers/src/models/metavoice.rs @@ -716,7 +716,7 @@ pub mod transformer { None => { let hidden_dim = self.dim * 4; let n_hidden = ((2 * hidden_dim) as f64 / 3.) as usize; - (n_hidden + 255) / 256 * 256 + n_hidden.div_ceil(256) * 256 } } } diff --git a/candle-transformers/src/models/whisper/audio.rs b/candle-transformers/src/models/whisper/audio.rs index 8490533c..cd04e16f 100644 --- a/candle-transformers/src/models/whisper/audio.rs +++ b/candle-transformers/src/models/whisper/audio.rs @@ -198,7 +198,7 @@ pub fn log_mel_spectrogram_( let samples = { let mut samples_padded = samples.to_vec(); let to_add = n_len * fft_step - samples.len(); - samples_padded.extend(std::iter::repeat(zero).take(to_add)); + samples_padded.extend(std::iter::repeat_n(zero, to_add)); samples_padded }; diff --git a/candle-wasm-examples/whisper/src/audio.rs b/candle-wasm-examples/whisper/src/audio.rs index b87f7df1..d3c0bb7e 100644 --- a/candle-wasm-examples/whisper/src/audio.rs +++ b/candle-wasm-examples/whisper/src/audio.rs @@ -177,7 +177,7 @@ fn log_mel_spectrogram_( let samples = { let mut samples_padded = samples.to_vec(); let to_add = n_len * fft_step - samples.len(); - samples_padded.extend(std::iter::repeat(zero).take(to_add)); + samples_padded.extend(std::iter::repeat_n(zero, to_add)); samples_padded };