mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Add from_iter and arange, use it in the doctests. (#145)
This commit is contained in:
@ -53,7 +53,7 @@ impl DType {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait WithDType: Sized + Copy + num_traits::NumAssign + 'static {
|
pub trait WithDType: Sized + Copy + num_traits::NumAssign + std::cmp::PartialOrd + 'static {
|
||||||
const DTYPE: DType;
|
const DTYPE: DType;
|
||||||
|
|
||||||
fn from_f64(v: f64) -> Self;
|
fn from_f64(v: f64) -> Self;
|
||||||
|
@ -5,8 +5,8 @@
|
|||||||
//! # use candle::Error;
|
//! # use candle::Error;
|
||||||
//! # fn main() -> Result<(), Error>{
|
//! # fn main() -> Result<(), Error>{
|
||||||
//!
|
//!
|
||||||
//! let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
|
//! let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
|
||||||
//! let b = Tensor::zeros((3, 4), DType::F32, &Device::Cpu)?;
|
//! let b = Tensor::arange(0f32, 12f32, &Device::Cpu)?.reshape((3, 4))?;
|
||||||
//!
|
//!
|
||||||
//! let c = a.matmul(&b)?;
|
//! let c = a.matmul(&b)?;
|
||||||
//! # Ok(())}
|
//! # Ok(())}
|
||||||
|
@ -39,8 +39,8 @@ impl AsRef<Tensor> for Tensor {
|
|||||||
/// ```rust
|
/// ```rust
|
||||||
/// use candle::{Tensor, DType, Device};
|
/// use candle::{Tensor, DType, Device};
|
||||||
///
|
///
|
||||||
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
|
/// let a = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((2, 3))?;
|
||||||
/// let b = Tensor::zeros((3, 4), DType::F32, &Device::Cpu)?;
|
/// let b = Tensor::arange(0f32, 12f32, &Device::Cpu)?.reshape((3, 4))?;
|
||||||
///
|
///
|
||||||
/// let c = a.matmul(&b)?;
|
/// let c = a.matmul(&b)?;
|
||||||
/// # Ok::<(), candle::Error>(())
|
/// # Ok::<(), candle::Error>(())
|
||||||
@ -314,6 +314,40 @@ impl Tensor {
|
|||||||
Self::new_impl(array, shape, device, true)
|
Self::new_impl(array, shape, device, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create a new 1D tensor from an iterator.
|
||||||
|
pub fn from_iter<D: crate::WithDType>(
|
||||||
|
iter: impl IntoIterator<Item = D>,
|
||||||
|
device: &Device,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let data = iter.into_iter().collect::<Vec<_>>();
|
||||||
|
let len = data.len();
|
||||||
|
Self::from_vec_impl(data, len, device, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new 1D tensor with values from the interval `[start, end)` taken with a common
|
||||||
|
/// difference `1` from `start`.
|
||||||
|
pub fn arange<D: crate::WithDType>(start: D, end: D, device: &Device) -> Result<Self> {
|
||||||
|
Self::arange_step(start, end, D::one(), device)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new 1D tensor with values from the interval `[start, end)` taken with a common
|
||||||
|
/// difference `step` from `start`.
|
||||||
|
pub fn arange_step<D: crate::WithDType>(
|
||||||
|
start: D,
|
||||||
|
end: D,
|
||||||
|
step: D,
|
||||||
|
device: &Device,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let mut data = vec![];
|
||||||
|
let mut current = start;
|
||||||
|
while current < end {
|
||||||
|
data.push(current);
|
||||||
|
current += step;
|
||||||
|
}
|
||||||
|
let len = data.len();
|
||||||
|
Self::from_vec_impl(data, len, device, false)
|
||||||
|
}
|
||||||
|
|
||||||
fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>(
|
fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>(
|
||||||
data: Vec<D>,
|
data: Vec<D>,
|
||||||
shape: S,
|
shape: S,
|
||||||
|
@ -209,7 +209,6 @@ fn main() -> Result<()> {
|
|||||||
index_pos += ctxt.len();
|
index_pos += ctxt.len();
|
||||||
|
|
||||||
let next_token = if let Some(temperature) = args.temperature {
|
let next_token = if let Some(temperature) = args.temperature {
|
||||||
println!("Sampling with temperature {temperature:?}");
|
|
||||||
let prs = (&logits / temperature)?.softmax(D::Minus1)?;
|
let prs = (&logits / temperature)?.softmax(D::Minus1)?;
|
||||||
let logits_v: Vec<f32> = prs.to_vec1()?;
|
let logits_v: Vec<f32> = prs.to_vec1()?;
|
||||||
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
|
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
|
||||||
|
Reference in New Issue
Block a user