mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 02:16:37 +00:00
Sketch the candle-transformers crate. (#147)
* Sketch the candle-transformers crate. * Format the empty files.
This commit is contained in:
@ -4,11 +4,11 @@
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle::{DType, Device, Tensor, D};
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_hub::{api::sync::Api, Repo, RepoType};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use clap::Parser;
|
||||
use rand::{distributions::Distribution, SeedableRng};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
mod model;
|
||||
@ -21,10 +21,9 @@ const DTYPE: DType = DType::BF16;
|
||||
|
||||
struct TextGeneration {
|
||||
model: Falcon,
|
||||
rng: rand::rngs::StdRng,
|
||||
device: Device,
|
||||
temperature: Option<f64>,
|
||||
tokenizer: Tokenizer,
|
||||
logits_processor: LogitsProcessor,
|
||||
}
|
||||
|
||||
impl TextGeneration {
|
||||
@ -32,14 +31,14 @@ impl TextGeneration {
|
||||
model: Falcon,
|
||||
tokenizer: Tokenizer,
|
||||
seed: u64,
|
||||
temperature: Option<f64>,
|
||||
temp: Option<f64>,
|
||||
device: &Device,
|
||||
) -> Self {
|
||||
let logits_processor = LogitsProcessor::new(seed, temp);
|
||||
Self {
|
||||
model,
|
||||
tokenizer,
|
||||
rng: rand::rngs::StdRng::seed_from_u64(seed),
|
||||
temperature,
|
||||
logits_processor,
|
||||
device: device.clone(),
|
||||
}
|
||||
}
|
||||
@ -67,20 +66,7 @@ impl TextGeneration {
|
||||
let logits = self.model.forward(&input)?;
|
||||
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
|
||||
let next_token = if let Some(temperature) = self.temperature {
|
||||
let prs = (&logits / temperature)?.softmax(D::Minus1)?;
|
||||
let logits_v: Vec<f32> = prs.to_vec1()?;
|
||||
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
|
||||
distr.sample(&mut self.rng) as u32
|
||||
} else {
|
||||
let logits_v: Vec<f32> = logits.to_vec1()?;
|
||||
logits_v
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, u), (_, v)| u.total_cmp(v))
|
||||
.map(|(i, _)| i as u32)
|
||||
.unwrap()
|
||||
};
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
new_tokens.push(next_token);
|
||||
println!("> {:?}", start_gen.elapsed());
|
||||
|
Reference in New Issue
Block a user