Sketch the candle-transformers crate. (#147)

* Sketch the candle-transformers crate.

* Format the empty files.
This commit is contained in:
Laurent Mazare
2023-07-12 13:49:31 +01:00
committed by GitHub
parent eae646d322
commit ba35d895e7
12 changed files with 90 additions and 40 deletions

View File

@ -4,11 +4,11 @@
extern crate intel_mkl_src;
use anyhow::{Error as E, Result};
use candle::{DType, Device, Tensor, D};
use candle::{DType, Device, Tensor};
use candle_hub::{api::sync::Api, Repo, RepoType};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use clap::Parser;
use rand::{distributions::Distribution, SeedableRng};
use tokenizers::Tokenizer;
mod model;
@ -21,10 +21,9 @@ const DTYPE: DType = DType::BF16;
struct TextGeneration {
model: Falcon,
rng: rand::rngs::StdRng,
device: Device,
temperature: Option<f64>,
tokenizer: Tokenizer,
logits_processor: LogitsProcessor,
}
impl TextGeneration {
@ -32,14 +31,14 @@ impl TextGeneration {
model: Falcon,
tokenizer: Tokenizer,
seed: u64,
temperature: Option<f64>,
temp: Option<f64>,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp);
Self {
model,
tokenizer,
rng: rand::rngs::StdRng::seed_from_u64(seed),
temperature,
logits_processor,
device: device.clone(),
}
}
@ -67,20 +66,7 @@ impl TextGeneration {
let logits = self.model.forward(&input)?;
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
let next_token = if let Some(temperature) = self.temperature {
let prs = (&logits / temperature)?.softmax(D::Minus1)?;
let logits_v: Vec<f32> = prs.to_vec1()?;
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
distr.sample(&mut self.rng) as u32
} else {
let logits_v: Vec<f32> = logits.to_vec1()?;
logits_v
.iter()
.enumerate()
.max_by(|(_, u), (_, v)| u.total_cmp(v))
.map(|(i, _)| i as u32)
.unwrap()
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
new_tokens.push(next_token);
println!("> {:?}", start_gen.elapsed());

View File

@ -14,11 +14,11 @@ extern crate intel_mkl_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use rand::{distributions::Distribution, SeedableRng};
use candle::{DType, Device, Tensor, D};
use candle_hub::{api::sync::Api, Repo, RepoType};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
mod model;
use model::{Config, Llama};
@ -185,8 +185,8 @@ fn main() -> Result<()> {
println!("pre-computing the positional embeddings");
let freqs_cis = precompute_freqs_cis(&config, &device)?;
println!("starting the inference loop");
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
let mut new_tokens = vec![];
let mut rng = rand::rngs::StdRng::seed_from_u64(args.seed);
let start_gen = std::time::Instant::now();
let mut index_pos = 0;
for index in 0..args.sample_len {
@ -207,21 +207,7 @@ fn main() -> Result<()> {
let logits = logits.squeeze(0)?;
index_pos += ctxt.len();
let next_token = if let Some(temperature) = args.temperature {
let prs = (&logits / temperature)?.softmax(D::Minus1)?;
let logits_v: Vec<f32> = prs.to_vec1()?;
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
distr.sample(&mut rng) as u32
} else {
let logits_v: Vec<f32> = logits.to_vec1()?;
logits_v
.iter()
.enumerate()
.max_by(|(_, u), (_, v)| u.total_cmp(v))
.map(|(i, _)| i as u32)
.unwrap()
};
let next_token = logits_processor.sample(&logits)?;
tokens.push(next_token);
new_tokens.push(next_token);
println!("> {:?}", start_gen.elapsed());