mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Refactor the llama example to make it more in sync with the other ones. (#139)
* Refactor the llama example to make it more in sync with the other ones. * Make clippy happy. * Properly load the safetensor weights. * Get llama back to a working state for the safetensors case.
This commit is contained in:
@ -20,11 +20,10 @@ use rand::{distributions::Distribution, SeedableRng};
|
|||||||
|
|
||||||
use candle::{DType, Device, Tensor, D};
|
use candle::{DType, Device, Tensor, D};
|
||||||
use candle_hub::{api::sync::Api, Repo, RepoType};
|
use candle_hub::{api::sync::Api, Repo, RepoType};
|
||||||
use std::collections::HashMap;
|
use candle_nn::VarBuilder;
|
||||||
use std::sync::{Arc, Mutex};
|
|
||||||
|
|
||||||
mod var_store;
|
mod model;
|
||||||
mod weights;
|
use model::{Config, Llama};
|
||||||
|
|
||||||
const MAX_SEQ_LEN: usize = 4096;
|
const MAX_SEQ_LEN: usize = 4096;
|
||||||
#[cfg(feature = "mkl")]
|
#[cfg(feature = "mkl")]
|
||||||
@ -83,337 +82,6 @@ Whate'er it bodes, henceforward will I bear
|
|||||||
Upon my target three fair-shining suns.
|
Upon my target three fair-shining suns.
|
||||||
";
|
";
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
struct Config {
|
|
||||||
block_size: usize,
|
|
||||||
vocab_size: usize,
|
|
||||||
n_layer: usize,
|
|
||||||
n_head: usize,
|
|
||||||
n_embd: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
impl Config {
|
|
||||||
fn config_7b() -> Self {
|
|
||||||
Self {
|
|
||||||
block_size: 4096,
|
|
||||||
vocab_size: 32000,
|
|
||||||
n_layer: 32,
|
|
||||||
n_head: 32,
|
|
||||||
n_embd: 4096,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn config_13b() -> Self {
|
|
||||||
Self {
|
|
||||||
block_size: 4096,
|
|
||||||
vocab_size: 32000,
|
|
||||||
n_layer: 40,
|
|
||||||
n_head: 40,
|
|
||||||
n_embd: 5120,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn config_30b() -> Self {
|
|
||||||
Self {
|
|
||||||
block_size: 4096,
|
|
||||||
vocab_size: 32000,
|
|
||||||
n_layer: 60,
|
|
||||||
n_head: 52,
|
|
||||||
n_embd: 6656,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn config_65b() -> Self {
|
|
||||||
Self {
|
|
||||||
block_size: 4096,
|
|
||||||
vocab_size: 32000,
|
|
||||||
n_layer: 80,
|
|
||||||
n_head: 64,
|
|
||||||
n_embd: 8192,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct Embedding {
|
|
||||||
embeddings: Tensor,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Embedding {
|
|
||||||
fn new(embeddings: Tensor) -> Self {
|
|
||||||
Self { embeddings }
|
|
||||||
}
|
|
||||||
|
|
||||||
fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
|
|
||||||
let embeddings = self.embeddings.to_dtype(DTYPE)?;
|
|
||||||
Ok(Tensor::embedding(indexes, &embeddings)?)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct Linear {
|
|
||||||
weight: Tensor,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Linear {
|
|
||||||
fn new(weight: Tensor) -> Self {
|
|
||||||
Self { weight }
|
|
||||||
}
|
|
||||||
|
|
||||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
|
||||||
let weight = self.weight.to_dtype(DTYPE)?;
|
|
||||||
let x = x.matmul(&weight.t()?)?;
|
|
||||||
Ok(x)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct RmsNorm {
|
|
||||||
scale: Tensor,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl RmsNorm {
|
|
||||||
fn new(scale: Tensor) -> Self {
|
|
||||||
Self { scale }
|
|
||||||
}
|
|
||||||
|
|
||||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
|
||||||
// This is a no-op if x's dtype is already f32.
|
|
||||||
let x = x.to_dtype(DType::F32)?;
|
|
||||||
let (seq_len, hidden_size) = x.shape().r2()?;
|
|
||||||
let norm_x = ((&x * &x)?.sum(&[1])? / hidden_size as f64)?;
|
|
||||||
let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
|
|
||||||
let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
|
|
||||||
let size = self.scale.shape().r1()?;
|
|
||||||
let scale = self
|
|
||||||
.scale
|
|
||||||
.to_dtype(DType::F32)?
|
|
||||||
.broadcast_as((seq_len, size))?;
|
|
||||||
let x = (scale * x_normed)?;
|
|
||||||
let x = x.to_dtype(DTYPE)?;
|
|
||||||
Ok(x)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct Mlp {
|
|
||||||
c_fc1: Linear,
|
|
||||||
c_fc2: Linear,
|
|
||||||
c_proj: Linear,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn silu(xs: &Tensor) -> Result<Tensor> {
|
|
||||||
Ok((xs / (xs.neg()?.exp()? + 1.0)?)?)
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Mlp {
|
|
||||||
fn new(c_fc1: Linear, c_fc2: Linear, c_proj: Linear) -> Self {
|
|
||||||
Self {
|
|
||||||
c_fc1,
|
|
||||||
c_fc2,
|
|
||||||
c_proj,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
|
||||||
let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
|
|
||||||
self.c_proj.forward(&x)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
|
|
||||||
let shape = mask.shape();
|
|
||||||
let on_true = Tensor::new(on_true, &on_false.device())?.broadcast_as(shape.dims())?;
|
|
||||||
let m = mask.where_cond(&on_true, on_false)?;
|
|
||||||
Ok(m)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
struct Cache {
|
|
||||||
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
|
|
||||||
use_kv_cache: bool,
|
|
||||||
#[allow(clippy::type_complexity)]
|
|
||||||
kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
|
|
||||||
device: Device,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Cache {
|
|
||||||
fn new(use_kv_cache: bool, config: &Config, device: &Device) -> Self {
|
|
||||||
Self {
|
|
||||||
masks: Arc::new(Mutex::new(HashMap::new())),
|
|
||||||
use_kv_cache,
|
|
||||||
kvs: Arc::new(Mutex::new(vec![None; config.n_layer])),
|
|
||||||
device: device.clone(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn mask(&self, t: usize) -> Result<Tensor> {
|
|
||||||
let mut masks = self.masks.lock().unwrap();
|
|
||||||
if let Some(mask) = masks.get(&t) {
|
|
||||||
Ok(mask.clone())
|
|
||||||
} else {
|
|
||||||
// TODO: If we support bool or u8 tensors, this would be better.
|
|
||||||
let mask: Vec<_> = (0..t)
|
|
||||||
.flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
|
|
||||||
.collect();
|
|
||||||
let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
|
|
||||||
masks.insert(t, mask.clone());
|
|
||||||
Ok(mask)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct CausalSelfAttention {
|
|
||||||
c_attn: Linear,
|
|
||||||
c_proj: Linear,
|
|
||||||
n_head: usize,
|
|
||||||
cache: Cache,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl CausalSelfAttention {
|
|
||||||
fn new(c_attn: Linear, c_proj: Linear, n_head: usize, cache: &Cache) -> Self {
|
|
||||||
Self {
|
|
||||||
c_attn,
|
|
||||||
c_proj,
|
|
||||||
n_head,
|
|
||||||
cache: cache.clone(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
|
|
||||||
let mut dims = x.dims().to_vec();
|
|
||||||
let fcis_dims = freqs_cis.dims();
|
|
||||||
let freqs_cis = if dims[1] < fcis_dims[1] {
|
|
||||||
freqs_cis.narrow(1, 0, dims[1])?
|
|
||||||
} else {
|
|
||||||
freqs_cis.clone()
|
|
||||||
};
|
|
||||||
let v = dims.pop().unwrap();
|
|
||||||
dims.push(v / 2);
|
|
||||||
dims.push(2);
|
|
||||||
let x = x.reshape(dims)?;
|
|
||||||
let re_x = x.narrow(D::Minus1, 0, 1)?;
|
|
||||||
let im_x = x.narrow(D::Minus1, 1, 1)?;
|
|
||||||
let re_f = freqs_cis
|
|
||||||
.narrow(D::Minus1, 0, 1)?
|
|
||||||
.broadcast_as(re_x.shape())?;
|
|
||||||
let im_f = freqs_cis
|
|
||||||
.narrow(D::Minus1, 1, 1)?
|
|
||||||
.broadcast_as(im_x.shape())?;
|
|
||||||
let re = ((&re_x * &re_f)? - (&im_x * &im_f)?)?;
|
|
||||||
let im = ((&re_x * &im_f)? + (&im_x * &re_f)?)?;
|
|
||||||
let rope = Tensor::cat(&[&re, &im], D::Minus1)?;
|
|
||||||
let rope = rope.flatten_from(D::Minus2)?;
|
|
||||||
Ok(rope)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn forward(&self, x: &Tensor, freqs_cis: &Tensor, block_idx: usize) -> Result<Tensor> {
|
|
||||||
let (t, c) = x.shape().r2()?;
|
|
||||||
let qkv = self.c_attn.forward(x)?;
|
|
||||||
let qkv = qkv.to_dtype(DType::F32)?;
|
|
||||||
let n_embd = c;
|
|
||||||
let q = qkv.narrow(1, 0, n_embd)?;
|
|
||||||
let k = qkv.narrow(1, n_embd, n_embd)?;
|
|
||||||
let v = qkv.narrow(1, 2 * n_embd, n_embd)?;
|
|
||||||
let target_dim = [t, self.n_head, c / self.n_head];
|
|
||||||
let k = k.reshape(target_dim.as_slice())?.transpose(0, 1)?;
|
|
||||||
let q = q.reshape(target_dim.as_slice())?.transpose(0, 1)?;
|
|
||||||
let mut v = v.reshape(target_dim.as_slice())?.transpose(0, 1)?;
|
|
||||||
let q = self.apply_rotary_emb(&q, freqs_cis)?;
|
|
||||||
let mut k = self.apply_rotary_emb(&k, freqs_cis)?;
|
|
||||||
|
|
||||||
if self.cache.use_kv_cache {
|
|
||||||
let mut cache = self.cache.kvs.lock().unwrap();
|
|
||||||
if let Some((cache_k, cache_v)) = &cache[block_idx] {
|
|
||||||
k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
|
|
||||||
v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
|
|
||||||
let k_seq_len = k.dims()[1];
|
|
||||||
if k_seq_len > MAX_SEQ_LEN {
|
|
||||||
k = k
|
|
||||||
.narrow(1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
|
|
||||||
.contiguous()?
|
|
||||||
}
|
|
||||||
let v_seq_len = v.dims()[1];
|
|
||||||
if v_seq_len > 2 * MAX_SEQ_LEN {
|
|
||||||
v = v
|
|
||||||
.narrow(1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
|
|
||||||
.contiguous()?
|
|
||||||
}
|
|
||||||
}
|
|
||||||
cache[block_idx] = Some((k.clone(), v.clone()))
|
|
||||||
}
|
|
||||||
|
|
||||||
let att = (q.matmul(&k.t()?)? / (k.dim(D::Minus1)? as f64).sqrt())?;
|
|
||||||
let mask = self.cache.mask(t)?.broadcast_as(att.shape())?;
|
|
||||||
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
|
|
||||||
let att = att.softmax(D::Minus1)?;
|
|
||||||
// Convert to contiguous as matmul doesn't support strided vs for now.
|
|
||||||
let y = att.matmul(&v.contiguous()?)?;
|
|
||||||
let y = y.transpose(0, 1)?.reshape(&[t, c])?;
|
|
||||||
let y = y.to_dtype(DTYPE)?;
|
|
||||||
let y = self.c_proj.forward(&y)?;
|
|
||||||
Ok(y)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct Block {
|
|
||||||
rms_1: RmsNorm,
|
|
||||||
attn: CausalSelfAttention,
|
|
||||||
rms_2: RmsNorm,
|
|
||||||
mlp: Mlp,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Block {
|
|
||||||
fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
|
|
||||||
Self {
|
|
||||||
rms_1,
|
|
||||||
attn,
|
|
||||||
rms_2,
|
|
||||||
mlp,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn forward(&self, x: &Tensor, freqs_cis: &Tensor, block_idx: usize) -> Result<Tensor> {
|
|
||||||
let x = (self
|
|
||||||
.attn
|
|
||||||
.forward(&self.rms_1.forward(x)?, freqs_cis, block_idx)?
|
|
||||||
+ x)?;
|
|
||||||
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + x)?;
|
|
||||||
Ok(x)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct Llama {
|
|
||||||
wte: Embedding,
|
|
||||||
blocks: Vec<Block>,
|
|
||||||
ln_f: RmsNorm,
|
|
||||||
lm_head: Linear,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Llama {
|
|
||||||
fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self {
|
|
||||||
Self {
|
|
||||||
wte,
|
|
||||||
blocks,
|
|
||||||
ln_f,
|
|
||||||
lm_head,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
|
|
||||||
// TODO: Support for mini-batches? (i.e. r2)
|
|
||||||
let t = x.shape().r1()?;
|
|
||||||
let mut x = self.wte.forward(x)?;
|
|
||||||
for (block_idx, block) in self.blocks.iter().enumerate() {
|
|
||||||
x = block.forward(&x, freqs_cis, block_idx)?;
|
|
||||||
}
|
|
||||||
let x = self.ln_f.forward(&x)?;
|
|
||||||
let x = x.narrow(0, t - 1, 1)?;
|
|
||||||
let logits = self.lm_head.forward(&x)?;
|
|
||||||
let logits = logits.to_dtype(DType::F32)?;
|
|
||||||
let (b, vocab_size) = logits.shape().r2()?;
|
|
||||||
assert_eq!(b, 1);
|
|
||||||
Ok(logits.reshape(vocab_size)?)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn precompute_freqs_cis(config: &Config, device: &Device) -> Result<Tensor> {
|
fn precompute_freqs_cis(config: &Config, device: &Device) -> Result<Tensor> {
|
||||||
let n_elem = config.n_embd / config.n_head;
|
let n_elem = config.n_embd / config.n_head;
|
||||||
let theta: Vec<_> = (0..n_elem)
|
let theta: Vec<_> = (0..n_elem)
|
||||||
@ -474,19 +142,15 @@ fn main() -> Result<()> {
|
|||||||
Device::new_cuda(0)?
|
Device::new_cuda(0)?
|
||||||
};
|
};
|
||||||
let config = Config::config_7b();
|
let config = Config::config_7b();
|
||||||
let cache = Cache::new(!args.no_kv_cache, &config, &device);
|
let cache = model::Cache::new(!args.no_kv_cache, &config, &device);
|
||||||
let start = std::time::Instant::now();
|
|
||||||
let (llama, tokenizer_filename) = match args.npy {
|
let (llama, tokenizer_filename) = match args.npy {
|
||||||
Some(npy) => {
|
Some(_) => {
|
||||||
println!("building the model (NPY)");
|
todo!("fix numpy handling if we continue supporting it")
|
||||||
let weights = Llama::load_npy(&device, &npy, &cache, &config)?;
|
|
||||||
let token_path = std::path::Path::new("llama-tokenizer.json").to_path_buf();
|
|
||||||
(weights, token_path)
|
|
||||||
}
|
}
|
||||||
None => {
|
None => {
|
||||||
let api = Api::new()?;
|
let api = Api::new()?;
|
||||||
let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
|
let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
|
||||||
println!("building the model");
|
println!("loading the model weights");
|
||||||
let tokenizer_filename = api.get(&repo, "tokenizer.json")?;
|
let tokenizer_filename = api.get(&repo, "tokenizer.json")?;
|
||||||
let mut filenames = vec![];
|
let mut filenames = vec![];
|
||||||
for rfilename in [
|
for rfilename in [
|
||||||
@ -497,14 +161,20 @@ fn main() -> Result<()> {
|
|||||||
filenames.push(filename);
|
filenames.push(filename);
|
||||||
}
|
}
|
||||||
|
|
||||||
println!("building the model (SF)");
|
println!("building the model");
|
||||||
(
|
let handles = filenames
|
||||||
Llama::load(&device, &filenames, &cache, &config)?,
|
.iter()
|
||||||
tokenizer_filename,
|
.map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f.as_path())? }))
|
||||||
)
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
let tensors: Vec<_> = handles
|
||||||
|
.iter()
|
||||||
|
.map(|h| Ok(h.deserialize()?))
|
||||||
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
|
||||||
|
let vb = VarBuilder::from_safetensors(tensors, DTYPE, &device);
|
||||||
|
(Llama::load(vb, &cache, &config)?, tokenizer_filename)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
println!("Loaded in {:?}", start.elapsed());
|
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
|
let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
|
||||||
let mut tokens = tokenizer
|
let mut tokens = tokenizer
|
||||||
|
364
candle-examples/examples/llama/model.rs
Normal file
364
candle-examples/examples/llama/model.rs
Normal file
@ -0,0 +1,364 @@
|
|||||||
|
use candle::{DType, Device, Result, Tensor, D};
|
||||||
|
use candle_nn::{Linear, VarBuilder};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
|
use super::MAX_SEQ_LEN;
|
||||||
|
|
||||||
|
pub struct Config {
|
||||||
|
pub hidden_size: usize,
|
||||||
|
pub intermediate_size: usize,
|
||||||
|
pub vocab_size: usize,
|
||||||
|
pub n_layer: usize,
|
||||||
|
pub n_head: usize,
|
||||||
|
pub n_embd: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Config {
|
||||||
|
pub fn config_7b() -> Self {
|
||||||
|
Self {
|
||||||
|
hidden_size: 4096,
|
||||||
|
intermediate_size: 11008,
|
||||||
|
vocab_size: 32000,
|
||||||
|
n_layer: 32,
|
||||||
|
n_head: 32,
|
||||||
|
n_embd: 4096,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct Cache {
|
||||||
|
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
|
||||||
|
pub use_kv_cache: bool,
|
||||||
|
#[allow(clippy::type_complexity)]
|
||||||
|
kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
|
||||||
|
device: Device,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Cache {
|
||||||
|
pub fn new(use_kv_cache: bool, config: &Config, device: &Device) -> Self {
|
||||||
|
Self {
|
||||||
|
masks: Arc::new(Mutex::new(HashMap::new())),
|
||||||
|
use_kv_cache,
|
||||||
|
kvs: Arc::new(Mutex::new(vec![None; config.n_layer])),
|
||||||
|
device: device.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn mask(&self, t: usize) -> Result<Tensor> {
|
||||||
|
let mut masks = self.masks.lock().unwrap();
|
||||||
|
if let Some(mask) = masks.get(&t) {
|
||||||
|
Ok(mask.clone())
|
||||||
|
} else {
|
||||||
|
// TODO: If we support bool or u8 tensors, this would be better.
|
||||||
|
let mask: Vec<_> = (0..t)
|
||||||
|
.flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
|
||||||
|
.collect();
|
||||||
|
let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
|
||||||
|
masks.insert(t, mask.clone());
|
||||||
|
Ok(mask)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn silu(xs: &Tensor) -> Result<Tensor> {
|
||||||
|
xs / (xs.neg()?.exp()? + 1.0)?
|
||||||
|
}
|
||||||
|
|
||||||
|
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
|
||||||
|
let weight = vb.get((size2, size1), "weight")?;
|
||||||
|
Ok(Linear::new(weight, None))
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Embedding {
|
||||||
|
embeddings: Tensor,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Embedding {
|
||||||
|
fn new(embeddings: Tensor) -> Self {
|
||||||
|
Self { embeddings }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?;
|
||||||
|
Ok(Self::new(embeddings))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
|
||||||
|
Tensor::embedding(indexes, &self.embeddings)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct RmsNorm {
|
||||||
|
scale: Tensor,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RmsNorm {
|
||||||
|
fn load(size: usize, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let scale = vb.get(size, "weight")?;
|
||||||
|
Ok(Self::new(scale))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new(scale: Tensor) -> Self {
|
||||||
|
Self { scale }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
|
let in_dtype = x.dtype();
|
||||||
|
// This is a no-op if x's dtype is already f32.
|
||||||
|
let x = x.to_dtype(DType::F32)?;
|
||||||
|
let (seq_len, hidden_size) = x.shape().r2()?;
|
||||||
|
let norm_x = ((&x * &x)?.sum(&[1])? / hidden_size as f64)?;
|
||||||
|
let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
|
||||||
|
let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
|
||||||
|
let size = self.scale.shape().r1()?;
|
||||||
|
let scale = self
|
||||||
|
.scale
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.broadcast_as((seq_len, size))?;
|
||||||
|
let x = (scale * x_normed)?;
|
||||||
|
let x = x.to_dtype(in_dtype)?;
|
||||||
|
Ok(x)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct CausalSelfAttention {
|
||||||
|
c_attn: Linear,
|
||||||
|
c_proj: Linear,
|
||||||
|
n_head: usize,
|
||||||
|
cache: Cache,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CausalSelfAttention {
|
||||||
|
fn new(c_attn: Linear, c_proj: Linear, n_head: usize, cache: &Cache) -> Self {
|
||||||
|
Self {
|
||||||
|
c_attn,
|
||||||
|
c_proj,
|
||||||
|
n_head,
|
||||||
|
cache: cache.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
|
||||||
|
let mut dims = x.dims().to_vec();
|
||||||
|
let fcis_dims = freqs_cis.dims();
|
||||||
|
let freqs_cis = if dims[1] < fcis_dims[1] {
|
||||||
|
freqs_cis.narrow(1, 0, dims[1])?
|
||||||
|
} else {
|
||||||
|
freqs_cis.clone()
|
||||||
|
};
|
||||||
|
let v = dims.pop().unwrap();
|
||||||
|
dims.push(v / 2);
|
||||||
|
dims.push(2);
|
||||||
|
let x = x.reshape(dims)?;
|
||||||
|
let re_x = x.narrow(D::Minus1, 0, 1)?;
|
||||||
|
let im_x = x.narrow(D::Minus1, 1, 1)?;
|
||||||
|
let re_f = freqs_cis
|
||||||
|
.narrow(D::Minus1, 0, 1)?
|
||||||
|
.broadcast_as(re_x.shape())?;
|
||||||
|
let im_f = freqs_cis
|
||||||
|
.narrow(D::Minus1, 1, 1)?
|
||||||
|
.broadcast_as(im_x.shape())?;
|
||||||
|
let re = ((&re_x * &re_f)? - (&im_x * &im_f)?)?;
|
||||||
|
let im = ((&re_x * &im_f)? + (&im_x * &re_f)?)?;
|
||||||
|
let rope = Tensor::cat(&[&re, &im], D::Minus1)?;
|
||||||
|
let rope = rope.flatten_from(D::Minus2)?;
|
||||||
|
Ok(rope)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, x: &Tensor, freqs_cis: &Tensor, block_idx: usize) -> Result<Tensor> {
|
||||||
|
let x_dtype = x.dtype();
|
||||||
|
let (t, c) = x.shape().r2()?;
|
||||||
|
let qkv = self.c_attn.forward(x)?;
|
||||||
|
let qkv = qkv.to_dtype(DType::F32)?;
|
||||||
|
let n_embd = c;
|
||||||
|
let q = qkv.narrow(1, 0, n_embd)?;
|
||||||
|
let k = qkv.narrow(1, n_embd, n_embd)?;
|
||||||
|
let v = qkv.narrow(1, 2 * n_embd, n_embd)?;
|
||||||
|
let target_dim = [t, self.n_head, c / self.n_head];
|
||||||
|
let k = k.reshape(target_dim.as_slice())?.transpose(0, 1)?;
|
||||||
|
let q = q.reshape(target_dim.as_slice())?.transpose(0, 1)?;
|
||||||
|
let mut v = v.reshape(target_dim.as_slice())?.transpose(0, 1)?;
|
||||||
|
let q = self.apply_rotary_emb(&q, freqs_cis)?;
|
||||||
|
let mut k = self.apply_rotary_emb(&k, freqs_cis)?;
|
||||||
|
|
||||||
|
if self.cache.use_kv_cache {
|
||||||
|
let mut cache = self.cache.kvs.lock().unwrap();
|
||||||
|
if let Some((cache_k, cache_v)) = &cache[block_idx] {
|
||||||
|
k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
|
||||||
|
v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
|
||||||
|
let k_seq_len = k.dims()[1];
|
||||||
|
if k_seq_len > MAX_SEQ_LEN {
|
||||||
|
k = k
|
||||||
|
.narrow(1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
|
||||||
|
.contiguous()?
|
||||||
|
}
|
||||||
|
let v_seq_len = v.dims()[1];
|
||||||
|
if v_seq_len > 2 * MAX_SEQ_LEN {
|
||||||
|
v = v
|
||||||
|
.narrow(1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
|
||||||
|
.contiguous()?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cache[block_idx] = Some((k.clone(), v.clone()))
|
||||||
|
}
|
||||||
|
|
||||||
|
let att = (q.matmul(&k.t()?)? / (k.dim(D::Minus1)? as f64).sqrt())?;
|
||||||
|
let mask = self.cache.mask(t)?.broadcast_as(att.shape())?;
|
||||||
|
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
|
||||||
|
let att = att.softmax(D::Minus1)?;
|
||||||
|
// Convert to contiguous as matmul doesn't support strided vs for now.
|
||||||
|
let y = att.matmul(&v.contiguous()?)?;
|
||||||
|
let y = y.transpose(0, 1)?.reshape(&[t, c])?;
|
||||||
|
let y = y.to_dtype(x_dtype)?;
|
||||||
|
let y = self.c_proj.forward(&y)?;
|
||||||
|
Ok(y)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
|
||||||
|
let size_in = cfg.hidden_size;
|
||||||
|
let size = (cfg.hidden_size / cfg.n_head) * cfg.n_head;
|
||||||
|
let q_proj = vb.get((size_in, size), "q_proj.weight")?;
|
||||||
|
let k_proj = vb.get((size_in, size), "k_proj.weight")?;
|
||||||
|
let v_proj = vb.get((size_in, size), "v_proj.weight")?;
|
||||||
|
// Invert the transformation from:
|
||||||
|
// https://github.com/huggingface/transformers/blob/2642d8d04b14c18199ebe7b35f976da02df61752/src/transformers/models/llama/convert_llama_weights_to_hf.py#L101
|
||||||
|
let n_head = cfg.n_head;
|
||||||
|
let q_proj = q_proj
|
||||||
|
.reshape((n_head, 2, size / n_head / 2, size_in))?
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.reshape((size_in, size))?;
|
||||||
|
let k_proj = k_proj
|
||||||
|
.reshape((n_head, 2, size / n_head / 2, size_in))?
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.reshape((size_in, size))?;
|
||||||
|
let attn_weight = Tensor::cat(&[q_proj, k_proj, v_proj], 0)?;
|
||||||
|
let c_attn = Linear::new(attn_weight, None);
|
||||||
|
let o_proj = linear(size, size_in, vb.pp("o_proj"))?;
|
||||||
|
Ok(Self::new(c_attn, o_proj, cfg.n_head, cache))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
|
||||||
|
let shape = mask.shape();
|
||||||
|
let on_true = Tensor::new(on_true, &on_false.device())?.broadcast_as(shape.dims())?;
|
||||||
|
let m = mask.where_cond(&on_true, on_false)?;
|
||||||
|
Ok(m)
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Mlp {
|
||||||
|
c_fc1: Linear,
|
||||||
|
c_fc2: Linear,
|
||||||
|
c_proj: Linear,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Mlp {
|
||||||
|
fn new(c_fc1: Linear, c_fc2: Linear, c_proj: Linear) -> Self {
|
||||||
|
Self {
|
||||||
|
c_fc1,
|
||||||
|
c_fc2,
|
||||||
|
c_proj,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
|
let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
|
||||||
|
self.c_proj.forward(&x)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
|
let h_size = cfg.hidden_size;
|
||||||
|
let i_size = cfg.intermediate_size;
|
||||||
|
let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?;
|
||||||
|
let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?;
|
||||||
|
let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?;
|
||||||
|
Ok(Self::new(c_fc1, c_fc2, c_proj))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Block {
|
||||||
|
rms_1: RmsNorm,
|
||||||
|
attn: CausalSelfAttention,
|
||||||
|
rms_2: RmsNorm,
|
||||||
|
mlp: Mlp,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Block {
|
||||||
|
fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
|
||||||
|
Self {
|
||||||
|
rms_1,
|
||||||
|
attn,
|
||||||
|
rms_2,
|
||||||
|
mlp,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, x: &Tensor, freqs_cis: &Tensor, block_idx: usize) -> Result<Tensor> {
|
||||||
|
let x = (self
|
||||||
|
.attn
|
||||||
|
.forward(&self.rms_1.forward(x)?, freqs_cis, block_idx)?
|
||||||
|
+ x)?;
|
||||||
|
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + x)?;
|
||||||
|
Ok(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
|
||||||
|
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
|
||||||
|
let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
|
||||||
|
let input_layernorm = RmsNorm::load(cfg.hidden_size, vb.pp("input_layernorm"))?;
|
||||||
|
let post_attention_layernorm =
|
||||||
|
RmsNorm::load(cfg.hidden_size, vb.pp("post_attention_layernorm"))?;
|
||||||
|
Ok(Self::new(
|
||||||
|
input_layernorm,
|
||||||
|
attn,
|
||||||
|
post_attention_layernorm,
|
||||||
|
mlp,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Llama {
|
||||||
|
wte: Embedding,
|
||||||
|
blocks: Vec<Block>,
|
||||||
|
ln_f: RmsNorm,
|
||||||
|
lm_head: Linear,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Llama {
|
||||||
|
fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self {
|
||||||
|
Self {
|
||||||
|
wte,
|
||||||
|
blocks,
|
||||||
|
ln_f,
|
||||||
|
lm_head,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
|
||||||
|
// TODO: Support for mini-batches? (i.e. r2)
|
||||||
|
let t = x.shape().r1()?;
|
||||||
|
let mut x = self.wte.forward(x)?;
|
||||||
|
for (block_idx, block) in self.blocks.iter().enumerate() {
|
||||||
|
x = block.forward(&x, freqs_cis, block_idx)?;
|
||||||
|
}
|
||||||
|
let x = self.ln_f.forward(&x)?;
|
||||||
|
let x = x.narrow(0, t - 1, 1)?;
|
||||||
|
let logits = self.lm_head.forward(&x)?;
|
||||||
|
let logits = logits.to_dtype(DType::F32)?;
|
||||||
|
let (b, vocab_size) = logits.shape().r2()?;
|
||||||
|
assert_eq!(b, 1);
|
||||||
|
logits.reshape(vocab_size)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
|
||||||
|
let wte = Embedding::load(cfg, vb.pp("model.embed_tokens"))?;
|
||||||
|
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
|
||||||
|
let norm = RmsNorm::load(cfg.hidden_size, vb.pp("model.norm"))?;
|
||||||
|
let blocks: Vec<_> = (0..cfg.n_layer)
|
||||||
|
.map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, cfg).unwrap())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(Self::new(wte, blocks, norm, lm_head))
|
||||||
|
}
|
||||||
|
}
|
@ -1,137 +0,0 @@
|
|||||||
use super::*;
|
|
||||||
use candle::{Device, Result, Tensor};
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::sync::{Arc, Mutex};
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct VarBuilder {
|
|
||||||
path: Vec<String>,
|
|
||||||
default_device: Device,
|
|
||||||
tensors: Arc<Mutex<HashMap<String, Tensor>>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl VarBuilder {
|
|
||||||
pub fn new(device: &Device, tensors: HashMap<String, Tensor>) -> Self {
|
|
||||||
Self {
|
|
||||||
path: vec![],
|
|
||||||
tensors: Arc::new(Mutex::new(tensors)),
|
|
||||||
default_device: device.clone(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_and_remove(&self, s: &str) -> Result<Tensor> {
|
|
||||||
let path = format!("{}.{s}", self.path.join("."));
|
|
||||||
let mut tensors = self.tensors.as_ref().lock().unwrap();
|
|
||||||
let parameter = match tensors.remove(&path) {
|
|
||||||
Some(tensor) => tensor.to_device(&self.default_device)?,
|
|
||||||
None => panic!("cannot find tensor for {path}"),
|
|
||||||
};
|
|
||||||
Ok(parameter)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S: ToString> std::ops::Div<S> for &VarBuilder {
|
|
||||||
type Output = VarBuilder;
|
|
||||||
|
|
||||||
fn div(self, rhs: S) -> VarBuilder {
|
|
||||||
let mut path = self.path.clone();
|
|
||||||
path.push(rhs.to_string());
|
|
||||||
VarBuilder {
|
|
||||||
path,
|
|
||||||
default_device: self.default_device.clone(),
|
|
||||||
tensors: self.tensors.clone(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S: ToString> std::ops::Div<S> for VarBuilder {
|
|
||||||
type Output = VarBuilder;
|
|
||||||
|
|
||||||
fn div(self, rhs: S) -> VarBuilder {
|
|
||||||
&self / rhs
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Embedding {
|
|
||||||
fn load_npy(vb: VarBuilder) -> Result<Self> {
|
|
||||||
let embeddings = vb.get_and_remove("weight")?.to_dtype(DTYPE)?;
|
|
||||||
Ok(Self { embeddings })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Linear {
|
|
||||||
fn load_npy(vb: VarBuilder) -> Result<Self> {
|
|
||||||
let weight = vb.get_and_remove("weight")?.to_dtype(DTYPE)?.t()?;
|
|
||||||
Ok(Self { weight })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl RmsNorm {
|
|
||||||
fn load_npy(vb: VarBuilder) -> Result<Self> {
|
|
||||||
let scale = vb.get_and_remove("scale")?.to_dtype(DTYPE)?;
|
|
||||||
Ok(Self::new(scale))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl CausalSelfAttention {
|
|
||||||
fn load_npy(vb: VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
|
|
||||||
let c_attn = Linear::load_npy(&vb / "c_attn")?;
|
|
||||||
let c_proj = Linear::load_npy(&vb / "c_proj")?;
|
|
||||||
Ok(Self::new(c_attn, c_proj, config.n_head, cache))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Mlp {
|
|
||||||
fn load_npy(vb: VarBuilder) -> Result<Self> {
|
|
||||||
let c_fc1 = Linear::load_npy(&vb / "c_fc1")?;
|
|
||||||
let c_fc2 = Linear::load_npy(&vb / "c_fc2")?;
|
|
||||||
let c_proj = Linear::load_npy(&vb / "c_proj")?;
|
|
||||||
Ok(Self::new(c_fc1, c_fc2, c_proj))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Block {
|
|
||||||
fn load_npy(vb: VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
|
|
||||||
let attn = CausalSelfAttention::load_npy(&vb / "attn", cache, config)?;
|
|
||||||
let mlp = Mlp::load_npy(&vb / "mlp")?;
|
|
||||||
let input_layernorm = RmsNorm::load_npy(&vb / "rms_1")?;
|
|
||||||
let post_attention_layernorm = RmsNorm::load_npy(&vb / "rms_2")?;
|
|
||||||
Ok(Self::new(
|
|
||||||
input_layernorm,
|
|
||||||
attn,
|
|
||||||
post_attention_layernorm,
|
|
||||||
mlp,
|
|
||||||
))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Llama {
|
|
||||||
pub fn load_npy(
|
|
||||||
device: &Device,
|
|
||||||
filename: &str,
|
|
||||||
cache: &Cache,
|
|
||||||
config: &Config,
|
|
||||||
) -> anyhow::Result<Self> {
|
|
||||||
let weight_path = std::path::Path::new(filename);
|
|
||||||
let weights = if weight_path.exists() {
|
|
||||||
println!("loading weights from {weight_path:?}");
|
|
||||||
let start_load = std::time::Instant::now();
|
|
||||||
let tensors = Tensor::read_npz(weight_path)?;
|
|
||||||
println!("loaded weights in {:?}", start_load.elapsed());
|
|
||||||
let tensors: std::collections::HashMap<String, Tensor> = tensors.into_iter().collect();
|
|
||||||
tensors
|
|
||||||
} else {
|
|
||||||
anyhow::bail!("cannot find {weight_path:?}")
|
|
||||||
};
|
|
||||||
let vb = VarBuilder::new(device, weights);
|
|
||||||
|
|
||||||
let wte = Embedding::load_npy(&vb / "transformer" / "wte")?;
|
|
||||||
let lm_head = Linear::load_npy(&vb / "lm_head")?;
|
|
||||||
let norm = RmsNorm::load_npy(&vb / "transformer" / "ln_f")?;
|
|
||||||
let blocks: Vec<_> = (0..config.n_layer)
|
|
||||||
.map(|i| Block::load_npy(&vb / "transformer" / "h" / i, cache, config).unwrap())
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
Ok(Self::new(wte, blocks, norm, lm_head))
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,127 +0,0 @@
|
|||||||
use super::*;
|
|
||||||
use candle::{safetensors::SafeTensors, Device, Result, Tensor};
|
|
||||||
use std::path::PathBuf;
|
|
||||||
|
|
||||||
pub struct VarBuilder<'a> {
|
|
||||||
routing: HashMap<String, usize>,
|
|
||||||
safetensors: Vec<SafeTensors<'a>>,
|
|
||||||
device: Device,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'a> VarBuilder<'a> {
|
|
||||||
pub fn new(safetensors: Vec<SafeTensors<'a>>, device: Device) -> Self {
|
|
||||||
let mut routing = HashMap::new();
|
|
||||||
for (index, sf) in safetensors.iter().enumerate() {
|
|
||||||
for k in sf.names() {
|
|
||||||
routing.insert(k.to_string(), index);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Self {
|
|
||||||
safetensors,
|
|
||||||
device,
|
|
||||||
routing,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get(&self, tensor_name: &str) -> Result<Tensor> {
|
|
||||||
// Unwrap or 0 just to let the proper error flow.
|
|
||||||
let index = self.routing.get(tensor_name).unwrap_or(&0);
|
|
||||||
self.safetensors[*index]
|
|
||||||
.tensor(tensor_name, &self.device)?
|
|
||||||
.to_dtype(DTYPE)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Linear {
|
|
||||||
fn load(prefix: &str, vb: &VarBuilder) -> Result<Self> {
|
|
||||||
let weight = vb.get(&format!("{prefix}.weight"))?;
|
|
||||||
Ok(Self::new(weight))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn load_multi(prefixes: &[&str], vb: &VarBuilder) -> Result<Self> {
|
|
||||||
let weights: Vec<_> = prefixes
|
|
||||||
.iter()
|
|
||||||
.map(|p| vb.get(&format!("{p}.weight")).unwrap())
|
|
||||||
.collect();
|
|
||||||
let weight = Tensor::cat(&weights, 0)?;
|
|
||||||
Ok(Self::new(weight))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl RmsNorm {
|
|
||||||
fn load(prefix: &str, vb: &VarBuilder) -> Result<Self> {
|
|
||||||
let scale = vb.get(&format!("{prefix}.weight"))?;
|
|
||||||
Ok(Self::new(scale))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl CausalSelfAttention {
|
|
||||||
fn load(prefix: &str, vb: &VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
|
|
||||||
let c_attn = Linear::load_multi(
|
|
||||||
&[
|
|
||||||
&format!("{prefix}.q_proj"),
|
|
||||||
&format!("{prefix}.k_proj"),
|
|
||||||
&format!("{prefix}.v_proj"),
|
|
||||||
],
|
|
||||||
vb,
|
|
||||||
)?;
|
|
||||||
let o_proj = Linear::load(&format!("{prefix}.o_proj"), vb)?;
|
|
||||||
Ok(Self::new(c_attn, o_proj, config.n_head, cache))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Mlp {
|
|
||||||
fn load(prefix: &str, vb: &VarBuilder) -> Result<Self> {
|
|
||||||
let c_fc1 = Linear::load(&format!("{prefix}.gate_proj"), vb)?;
|
|
||||||
let c_fc2 = Linear::load(&format!("{prefix}.up_proj"), vb)?;
|
|
||||||
let c_proj = Linear::load(&format!("{prefix}.down_proj"), vb)?;
|
|
||||||
Ok(Self::new(c_fc1, c_fc2, c_proj))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Block {
|
|
||||||
fn load(prefix: &str, vb: &VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
|
|
||||||
let attn = CausalSelfAttention::load(&format!("{prefix}.self_attn"), vb, cache, config)?;
|
|
||||||
let mlp = Mlp::load(&format!("{prefix}.mlp"), vb)?;
|
|
||||||
let input_layernorm = RmsNorm::load(&format!("{prefix}.input_layernorm"), vb)?;
|
|
||||||
let post_attention_layernorm =
|
|
||||||
RmsNorm::load(&format!("{prefix}.post_attention_layernorm"), vb)?;
|
|
||||||
Ok(Self::new(
|
|
||||||
input_layernorm,
|
|
||||||
attn,
|
|
||||||
post_attention_layernorm,
|
|
||||||
mlp,
|
|
||||||
))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Llama {
|
|
||||||
pub fn load(
|
|
||||||
device: &Device,
|
|
||||||
filenames: &[PathBuf],
|
|
||||||
cache: &Cache,
|
|
||||||
config: &Config,
|
|
||||||
) -> Result<Self> {
|
|
||||||
let handles: Vec<_> = filenames
|
|
||||||
.iter()
|
|
||||||
.map(|f| unsafe { candle::safetensors::MmapedFile::new(f) })
|
|
||||||
.collect::<Result<Vec<_>>>()?;
|
|
||||||
let tensors: Vec<_> = handles
|
|
||||||
.iter()
|
|
||||||
.map(|h| h.deserialize())
|
|
||||||
.collect::<Result<Vec<_>>>()?;
|
|
||||||
|
|
||||||
let vb = VarBuilder::new(tensors, device.clone());
|
|
||||||
|
|
||||||
let embedding = vb.get("model.embed_tokens.weight")?;
|
|
||||||
let wte = Embedding::new(embedding);
|
|
||||||
let lm_head = Linear::load("lm_head", &vb)?;
|
|
||||||
let norm = RmsNorm::load("model.norm", &vb)?;
|
|
||||||
let blocks: Vec<_> = (0..config.n_layer)
|
|
||||||
.map(|i| Block::load(&format!("model.layers.{i}"), &vb, cache, config).unwrap())
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
Ok(Self::new(wte, blocks, norm, lm_head))
|
|
||||||
}
|
|
||||||
}
|
|
Reference in New Issue
Block a user