Add from_iter and arange, use it in the doctests. (#145)

This commit is contained in:
Laurent Mazare
2023-07-12 12:03:01 +01:00
committed by GitHub
parent b3b39cca92
commit 20599172ac
4 changed files with 39 additions and 6 deletions

View File

@ -53,7 +53,7 @@ impl DType {
}
}
pub trait WithDType: Sized + Copy + num_traits::NumAssign + 'static {
pub trait WithDType: Sized + Copy + num_traits::NumAssign + std::cmp::PartialOrd + 'static {
const DTYPE: DType;
fn from_f64(v: f64) -> Self;

View File

@ -5,8 +5,8 @@
//! # use candle::Error;
//! # fn main() -> Result<(), Error>{
//!
//! let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
//! let b = Tensor::zeros((3, 4), DType::F32, &Device::Cpu)?;
//! let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
//! let b = Tensor::arange(0f32, 12f32, &Device::Cpu)?.reshape((3, 4))?;
//!
//! let c = a.matmul(&b)?;
//! # Ok(())}

View File

@ -39,8 +39,8 @@ impl AsRef<Tensor> for Tensor {
/// ```rust
/// use candle::{Tensor, DType, Device};
///
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
/// let b = Tensor::zeros((3, 4), DType::F32, &Device::Cpu)?;
/// let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
/// let b = Tensor::arange(0f32, 12f32, &Device::Cpu)?.reshape((3, 4))?;
///
/// let c = a.matmul(&b)?;
/// # Ok::<(), candle::Error>(())
@ -314,6 +314,40 @@ impl Tensor {
Self::new_impl(array, shape, device, true)
}
/// Create a new 1D tensor from an iterator.
pub fn from_iter<D: crate::WithDType>(
iter: impl IntoIterator<Item = D>,
device: &Device,
) -> Result<Self> {
let data = iter.into_iter().collect::<Vec<_>>();
let len = data.len();
Self::from_vec_impl(data, len, device, false)
}
/// Create a new 1D tensor with values from the interval `[start, end)` taken with a common
/// difference `1` from `start`.
pub fn arange<D: crate::WithDType>(start: D, end: D, device: &Device) -> Result<Self> {
Self::arange_step(start, end, D::one(), device)
}
/// Create a new 1D tensor with values from the interval `[start, end)` taken with a common
/// difference `step` from `start`.
pub fn arange_step<D: crate::WithDType>(
start: D,
end: D,
step: D,
device: &Device,
) -> Result<Self> {
let mut data = vec![];
let mut current = start;
while current < end {
data.push(current);
current += step;
}
let len = data.len();
Self::from_vec_impl(data, len, device, false)
}
fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>(
data: Vec<D>,
shape: S,

View File

@ -209,7 +209,6 @@ fn main() -> Result<()> {
index_pos += ctxt.len();
let next_token = if let Some(temperature) = args.temperature {
println!("Sampling with temperature {temperature:?}");
let prs = (&logits / temperature)?.softmax(D::Minus1)?;
let logits_v: Vec<f32> = prs.to_vec1()?;
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;