From 20599172acbefe7e77895c403e87141c4e1cb274 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Wed, 12 Jul 2023 12:03:01 +0100 Subject: [PATCH] Add from_iter and arange, use it in the doctests. (#145) --- candle-core/src/dtype.rs | 2 +- candle-core/src/lib.rs | 4 +-- candle-core/src/tensor.rs | 38 ++++++++++++++++++++++++-- candle-examples/examples/llama/main.rs | 1 - 4 files changed, 39 insertions(+), 6 deletions(-) diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs index 632abcc4..3a3f59f5 100644 --- a/candle-core/src/dtype.rs +++ b/candle-core/src/dtype.rs @@ -53,7 +53,7 @@ impl DType { } } -pub trait WithDType: Sized + Copy + num_traits::NumAssign + 'static { +pub trait WithDType: Sized + Copy + num_traits::NumAssign + std::cmp::PartialOrd + 'static { const DTYPE: DType; fn from_f64(v: f64) -> Self; diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index a51c8e29..0108c198 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -5,8 +5,8 @@ //! # use candle::Error; //! # fn main() -> Result<(), Error>{ //! -//! let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; -//! let b = Tensor::zeros((3, 4), DType::F32, &Device::Cpu)?; +//! let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?; +//! let b = Tensor::arange(0f32, 12f32, &Device::Cpu)?.reshape((3, 4))?; //! //! let c = a.matmul(&b)?; //! # Ok(())} diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index f9a6ebb5..947c6b77 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -39,8 +39,8 @@ impl AsRef for Tensor { /// ```rust /// use candle::{Tensor, DType, Device}; /// -/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?; -/// let b = Tensor::zeros((3, 4), DType::F32, &Device::Cpu)?; +/// let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?; +/// let b = Tensor::arange(0f32, 12f32, &Device::Cpu)?.reshape((3, 4))?; /// /// let c = a.matmul(&b)?; /// # Ok::<(), candle::Error>(()) @@ -314,6 +314,40 @@ impl Tensor { Self::new_impl(array, shape, device, true) } + /// Create a new 1D tensor from an iterator. + pub fn from_iter( + iter: impl IntoIterator, + device: &Device, + ) -> Result { + let data = iter.into_iter().collect::>(); + let len = data.len(); + Self::from_vec_impl(data, len, device, false) + } + + /// Create a new 1D tensor with values from the interval `[start, end)` taken with a common + /// difference `1` from `start`. + pub fn arange(start: D, end: D, device: &Device) -> Result { + Self::arange_step(start, end, D::one(), device) + } + + /// Create a new 1D tensor with values from the interval `[start, end)` taken with a common + /// difference `step` from `start`. + pub fn arange_step( + start: D, + end: D, + step: D, + device: &Device, + ) -> Result { + let mut data = vec![]; + let mut current = start; + while current < end { + data.push(current); + current += step; + } + let len = data.len(); + Self::from_vec_impl(data, len, device, false) + } + fn from_vec_impl, D: crate::WithDType>( data: Vec, shape: S, diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs index ac13dfee..3e8d2b1a 100644 --- a/candle-examples/examples/llama/main.rs +++ b/candle-examples/examples/llama/main.rs @@ -209,7 +209,6 @@ fn main() -> Result<()> { index_pos += ctxt.len(); let next_token = if let Some(temperature) = args.temperature { - println!("Sampling with temperature {temperature:?}"); let prs = (&logits / temperature)?.softmax(D::Minus1)?; let logits_v: Vec = prs.to_vec1()?; let distr = rand::distributions::WeightedIndex::new(&logits_v)?;