From eae646d3224059019ff4beeaa78285cdac88eb12 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Wed, 12 Jul 2023 12:12:34 +0100 Subject: [PATCH] Use arange in the examples. (#146) --- candle-examples/examples/falcon/model.rs | 3 +-- candle-examples/examples/llama/main.rs | 7 +++---- candle-examples/examples/whisper/model.rs | 5 +++-- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/candle-examples/examples/falcon/model.rs b/candle-examples/examples/falcon/model.rs index f97fe219..82c5d4b2 100644 --- a/candle-examples/examples/falcon/model.rs +++ b/candle-examples/examples/falcon/model.rs @@ -166,8 +166,7 @@ impl FalconRotaryEmbedding { } _ => {} } - let t: Vec<_> = (0..seq_len).map(|c| c as u32).collect(); - let t = Tensor::new(t.as_slice(), device)?.to_dtype(dtype)?; + let t = Tensor::arange(0, seq_len as u32, device)?.to_dtype(dtype)?; let inv_freq = self.inv_freq.to_dtype(dtype)?; let freqs = t.unsqueeze(1)?.matmul(&inv_freq.unsqueeze(0)?)?; let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?; diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs index 3e8d2b1a..d21094a4 100644 --- a/candle-examples/examples/llama/main.rs +++ b/candle-examples/examples/llama/main.rs @@ -86,11 +86,10 @@ fn precompute_freqs_cis(config: &Config, device: &Device) -> Result { .step_by(2) .map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32)) .collect(); - let arange: Vec<_> = (0..MAX_SEQ_LEN).map(|c| c as f32).collect(); let theta = Tensor::new(theta.as_slice(), device)?; - let arange = Tensor::new(arange.as_slice(), device)?; - let idx_theta = arange - .reshape((arange.elem_count(), 1))? + let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)? + .to_dtype(DType::F32)? + .reshape((MAX_SEQ_LEN, 1))? .matmul(&theta.reshape((1, theta.elem_count()))?)?; let shape = [1, MAX_SEQ_LEN, n_elem / 2, 1]; let idx_theta_cos = idx_theta.cos()?.reshape(&shape)?; diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs index ece8b2d8..d4553e79 100644 --- a/candle-examples/examples/whisper/model.rs +++ b/candle-examples/examples/whisper/model.rs @@ -217,9 +217,10 @@ fn sinusoids(length: usize, channels: usize) -> Result { let inv_timescales: Vec<_> = (0..channels / 2) .map(|i| (i as f32 * (-log_timescale_increment)).exp()) .collect(); - let arange: Vec<_> = (0..length).map(|c| c as f32).collect(); let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?; - let arange = Tensor::new(arange.as_slice(), &Device::Cpu)?.unsqueeze(1)?; + let arange = Tensor::arange(0, length as u32, &Device::Cpu)? + .to_dtype(candle::DType::F32)? + .unsqueeze(1)?; let sh = (length, channels / 2); let scaled_time = (arange.broadcast_as(sh)? * inv_timescales.broadcast_as(sh)?)?; let sincos = Tensor::cat(&[scaled_time.sin()?, scaled_time.cos()?], 1)?;