mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Simple example fix.
This commit is contained in:
@ -13,6 +13,7 @@
|
|||||||
// transposition operations.
|
// transposition operations.
|
||||||
use anyhow::{Error as E, Result};
|
use anyhow::{Error as E, Result};
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
|
use rand::{distributions::Distribution, thread_rng};
|
||||||
|
|
||||||
use candle::{DType, Device, Tensor};
|
use candle::{DType, Device, Tensor};
|
||||||
use candle_hub::{api::Api, Repo, RepoType};
|
use candle_hub::{api::Api, Repo, RepoType};
|
||||||
@ -137,10 +138,7 @@ impl Embedding {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
|
fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
|
||||||
Ok(Tensor::embedding(
|
Ok(Tensor::embedding(indexes, &self.embeddings)?)
|
||||||
indexes,
|
|
||||||
&self.embeddings.to_dtype(DType::F32)?,
|
|
||||||
)?)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -354,7 +352,8 @@ impl Llama {
|
|||||||
fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
|
fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
|
||||||
// TODO: Support for mini-batches? (i.e. r2)
|
// TODO: Support for mini-batches? (i.e. r2)
|
||||||
let t = x.shape().r1()?;
|
let t = x.shape().r1()?;
|
||||||
let mut x = self.wte.forward(x)?;
|
let x = self.wte.forward(x)?;
|
||||||
|
let mut x = x.to_dtype(DType::F32)?;
|
||||||
for block in self.blocks.iter() {
|
for block in self.blocks.iter() {
|
||||||
x = block.forward(&x, freqs_cis)?;
|
x = block.forward(&x, freqs_cis)?;
|
||||||
}
|
}
|
||||||
@ -399,8 +398,8 @@ struct Args {
|
|||||||
npy: bool,
|
npy: bool,
|
||||||
|
|
||||||
/// The temperature used to generate samples.
|
/// The temperature used to generate samples.
|
||||||
#[arg(long, default_value_t = 1.0)]
|
#[arg(long)]
|
||||||
temperature: f64,
|
temperature: Option<f64>,
|
||||||
|
|
||||||
/// The length of the sample to generate (in tokens).
|
/// The length of the sample to generate (in tokens).
|
||||||
#[arg(long, default_value_t = 100)]
|
#[arg(long, default_value_t = 100)]
|
||||||
@ -420,8 +419,34 @@ async fn main() -> Result<()> {
|
|||||||
};
|
};
|
||||||
let api = Api::new()?;
|
let api = Api::new()?;
|
||||||
let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
|
let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
|
||||||
let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
|
println!("building the model");
|
||||||
println!("Filename {tokenizer_filename:?}");
|
let config = Config::config_7b();
|
||||||
|
let cache = Cache::new(&device);
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let (llama, tokenizer_filename) = if args.npy {
|
||||||
|
println!("building the model (NPY)");
|
||||||
|
(
|
||||||
|
Llama::load_npy(&device, "/data/llama.npz", &cache, &config)?,
|
||||||
|
std::path::Path::new("llama-tokenizer.json").to_path_buf(),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
|
||||||
|
let mut filenames = vec![];
|
||||||
|
for rfilename in [
|
||||||
|
"model-00001-of-00002.safetensors",
|
||||||
|
"model-00002-of-00002.safetensors",
|
||||||
|
] {
|
||||||
|
let filename = api.get(&repo, rfilename).await?;
|
||||||
|
filenames.push(filename);
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("building the model (SF)");
|
||||||
|
(
|
||||||
|
Llama::load(&device, &filenames, &cache, &config)?,
|
||||||
|
tokenizer_filename,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
println!("Loaded in {:?}", start.elapsed());
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
let mut tokens = tokenizer
|
let mut tokens = tokenizer
|
||||||
.encode(START_PROMPT, true)
|
.encode(START_PROMPT, true)
|
||||||
@ -429,55 +454,39 @@ async fn main() -> Result<()> {
|
|||||||
.get_ids()
|
.get_ids()
|
||||||
.to_vec();
|
.to_vec();
|
||||||
|
|
||||||
let mut filenames = vec![];
|
|
||||||
for rfilename in [
|
|
||||||
"model-00001-of-00002.safetensors",
|
|
||||||
"model-00002-of-00002.safetensors",
|
|
||||||
] {
|
|
||||||
let filename = api.get(&repo, rfilename).await?;
|
|
||||||
filenames.push(filename);
|
|
||||||
}
|
|
||||||
|
|
||||||
println!("building the model");
|
|
||||||
let config = Config::config_7b();
|
|
||||||
let cache = Cache::new(&device);
|
|
||||||
let start = std::time::Instant::now();
|
|
||||||
let llama = if args.npy {
|
|
||||||
println!("building the model (NPY)");
|
|
||||||
Llama::load_npy(&device, &filenames, &cache, &config)?
|
|
||||||
} else {
|
|
||||||
println!("building the model (SF)");
|
|
||||||
Llama::load(&device, &filenames, &cache, &config)?
|
|
||||||
};
|
|
||||||
println!("Loaded in {:?}", start.elapsed());
|
|
||||||
|
|
||||||
println!("pre-computing the positional embeddings");
|
println!("pre-computing the positional embeddings");
|
||||||
let freqs_cis = precompute_freqs_cis(&config, &device)?;
|
let freqs_cis = precompute_freqs_cis(&config, &device)?;
|
||||||
println!("starting the inference loop");
|
println!("starting the inference loop");
|
||||||
let mut new_tokens = vec![];
|
let mut new_tokens = vec![];
|
||||||
//let mut rng = thread_rng();
|
let mut rng = thread_rng();
|
||||||
let start_gen = std::time::Instant::now();
|
let start_gen = std::time::Instant::now();
|
||||||
for index in 0..args.sample_len {
|
for index in 0..args.sample_len {
|
||||||
let start_gen = std::time::Instant::now();
|
let start_gen = std::time::Instant::now();
|
||||||
let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
|
let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
|
||||||
let input = Tensor::new(ctxt, &device)?;
|
let input = Tensor::new(ctxt, &device)?;
|
||||||
let logits = llama.forward(&input, &freqs_cis)?;
|
let logits = llama.forward(&input, &freqs_cis)?;
|
||||||
let prs = (&logits / args.temperature)?.softmax(logits.rank() - 1)?;
|
|
||||||
let logits_v: Vec<f32> = prs.to_vec1()?;
|
|
||||||
let next_token = logits_v
|
|
||||||
.iter()
|
|
||||||
.enumerate()
|
|
||||||
.fold((0, logits_v[0]), |(idx_max, val_max), (idx, val)| {
|
|
||||||
if &val_max > val {
|
|
||||||
(idx_max, val_max)
|
|
||||||
} else {
|
|
||||||
(idx, *val)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.0 as u32;
|
|
||||||
// let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
|
|
||||||
|
|
||||||
// let next_token = distr.sample(&mut rng) as u32;
|
let next_token = if let Some(temperature) = args.temperature {
|
||||||
|
println!("Sampling with temperature {temperature:?}");
|
||||||
|
let prs = (&logits / temperature)?.softmax(logits.rank() - 1)?;
|
||||||
|
let logits_v: Vec<f32> = prs.to_vec1()?;
|
||||||
|
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
|
||||||
|
|
||||||
|
distr.sample(&mut rng) as u32
|
||||||
|
} else {
|
||||||
|
let logits_v: Vec<f32> = logits.to_vec1()?;
|
||||||
|
logits_v
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.fold((0, logits_v[0]), |(idx_max, val_max), (idx, val)| {
|
||||||
|
if &val_max > val {
|
||||||
|
(idx_max, val_max)
|
||||||
|
} else {
|
||||||
|
(idx, *val)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.0 as u32
|
||||||
|
};
|
||||||
tokens.push(next_token);
|
tokens.push(next_token);
|
||||||
new_tokens.push(next_token);
|
new_tokens.push(next_token);
|
||||||
println!("> {:?}", start_gen.elapsed());
|
println!("> {:?}", start_gen.elapsed());
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
use super::*;
|
use super::*;
|
||||||
use candle::{DType, Device, Result, Shape, Tensor, WithDType};
|
use candle::{DType, Device, Result, Shape, Tensor, WithDType};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::path::PathBuf;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
@ -142,11 +141,11 @@ impl Block {
|
|||||||
impl Llama {
|
impl Llama {
|
||||||
pub fn load_npy(
|
pub fn load_npy(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
_filenames: &[PathBuf],
|
filename: &str,
|
||||||
cache: &Cache,
|
cache: &Cache,
|
||||||
config: &Config,
|
config: &Config,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let weight_path = std::path::Path::new("/data/llama.npz");
|
let weight_path = std::path::Path::new(filename);
|
||||||
let weights = if weight_path.exists() {
|
let weights = if weight_path.exists() {
|
||||||
println!("loading weights from {weight_path:?}");
|
println!("loading weights from {weight_path:?}");
|
||||||
let start_load = std::time::Instant::now();
|
let start_load = std::time::Instant::now();
|
||||||
|
Reference in New Issue
Block a user