Sketch the candle-transformers crate. (#147)

* Sketch the candle-transformers crate.

* Format the empty files.
This commit is contained in:
Laurent Mazare
2023-07-12 13:49:31 +01:00
committed by GitHub
parent eae646d322
commit ba35d895e7
12 changed files with 90 additions and 40 deletions

View File

@ -0,0 +1,35 @@
use candle::{DType, Error, Result, Tensor, D};
use rand::{distributions::Distribution, SeedableRng};
pub struct LogitsProcessor {
rng: rand::rngs::StdRng,
temperature: Option<f64>,
}
impl LogitsProcessor {
pub fn new(seed: u64, temperature: Option<f64>) -> Self {
Self {
rng: rand::rngs::StdRng::seed_from_u64(seed),
temperature,
}
}
pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
let logits = logits.to_dtype(DType::F32)?;
let next_token = if let Some(temperature) = self.temperature {
let prs = (&logits / temperature)?.softmax(D::Minus1)?;
let prs: Vec<f32> = prs.to_vec1()?;
let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;
distr.sample(&mut self.rng) as u32
} else {
let logits_v: Vec<f32> = logits.to_vec1()?;
logits_v
.iter()
.enumerate()
.max_by(|(_, u), (_, v)| u.total_cmp(v))
.map(|(i, _)| i as u32)
.unwrap()
};
Ok(next_token)
}
}