From 544018b6d0ffa0a4b0ac6c30de10ec2012765fcb Mon Sep 17 00:00:00 2001
From: laurent
Date: Thu, 22 Feb 2024 10:22:03 +0100
Subject: [PATCH] Explicit caching in llama2.c.

---
 candle-examples/examples/llama2-c/main.rs     | 30 ++++----
 candle-examples/examples/llama2-c/training.rs | 11 +--
 candle-transformers/src/models/llama2_c.rs    | 76 +++++++++++--------
 .../src/models/quantized_llama2_c.rs          | 57 ++++++++------
 4 files changed, 99 insertions(+), 75 deletions(-)

diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index 27ebc80f..1a82bf1f 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -19,7 +19,7 @@ use candle_transformers::generation::LogitsProcessor;
 use std::io::Write;
 use tokenizers::Tokenizer;
 
-use model::{Config, Llama};
+use model::{Cache, Config, Llama};
 use qmodel::QLlama;
 use weights::TransformerWeights;
 
@@ -160,10 +160,10 @@ enum Model {
 }
 
 impl Model {
-    fn forward(&self, xs: &Tensor, pos: usize) -> anyhow::Result<Tensor> {
+    fn forward(&self, xs: &Tensor, pos: usize, cache: &mut Cache) -> anyhow::Result<Tensor> {
         match self {
-            Self::Llama(l) => Ok(l.forward(xs, pos)?),
-            Self::QLlama(l) => Ok(l.forward(xs, pos)?),
+            Self::Llama(l) => Ok(l.forward(xs, pos, cache)?),
+            Self::QLlama(l) => Ok(l.forward(xs, pos, cache)?),
         }
     }
 }
@@ -188,8 +188,8 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
     let config = Config::from_reader(&mut file)?;
     let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
     let vb = weights.var_builder(&config, &device)?;
-    let cache = model::Cache::new(false, &config, vb.pp("rot"))?;
-    let model = Llama::load(vb, &cache, config)?;
+    let mut cache = Cache::new(false, &config, vb.pp("rot"))?;
+    let model = Llama::load(vb, config)?;
 
     let tokens = match &args.pretokenized_dir {
         None => {
@@ -235,7 +235,7 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
     let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
     for inp_tgt in batch_iter {
         let (inp, tgt) = inp_tgt?;
-        let logits = model.forward(&inp, 0)?;
+        let logits = model.forward(&inp, 0, &mut cache)?;
         let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
         println!("{}", loss.to_vec0::<f32>()?);
     }
@@ -261,7 +261,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
     let is_safetensors = config_path
         .extension()
         .map_or(false, |v| v == "safetensors");
-    let (model, config) = if is_gguf {
+    let (model, config, mut cache) = if is_gguf {
         let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?;
         let (_vocab_size, dim) = vb
             .get_no_shape("model.embed_tokens.weight")?
@@ -298,15 +298,15 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
             &device,
         );
         let cache = model::Cache::new(true, &config, fake_vb)?;
-        let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
-        (model, config)
+        let model = Model::QLlama(QLlama::load(vb, config.clone())?);
+        (model, config, cache)
     } else if is_safetensors {
         let config = Config::tiny_15m();
         let tensors = candle::safetensors::load(config_path, &device)?;
         let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
         let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
-        let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
-        (model, config)
+        let model = Model::Llama(Llama::load(vb, config.clone())?);
+        (model, config, cache)
     } else {
         let mut file = std::fs::File::open(config_path)?;
         let config = Config::from_reader(&mut file)?;
@@ -314,8 +314,8 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
         let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
         let vb = weights.var_builder(&config, &device)?;
         let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
-        let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
-        (model, config)
+        let model = Model::Llama(Llama::load(vb, config.clone())?);
+        (model, config, cache)
     };
 
     println!("starting the inference loop");
@@ -338,7 +338,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
         let context_size = if index > 0 { 1 } else { tokens.len() };
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
         let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
-        let logits = model.forward(&input, index_pos)?;
+        let logits = model.forward(&input, index_pos, &mut cache)?;
         let logits = logits.i((0, logits.dim(1)? - 1))?;
         let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() {
             logits
diff --git a/candle-examples/examples/llama2-c/training.rs b/candle-examples/examples/llama2-c/training.rs
index b2aa0889..c83ca43f 100644
--- a/candle-examples/examples/llama2-c/training.rs
+++ b/candle-examples/examples/llama2-c/training.rs
@@ -8,6 +8,7 @@ fn valid_loss(
     model: &Llama,
     args: &crate::TrainingCmd,
     device: &Device,
+    cache: &mut Cache,
 ) -> Result<f64> {
     let iter = DatasetRandomIter::new(dataset, true, model.config.seq_len, device.clone());
     let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
@@ -15,7 +16,7 @@ fn valid_loss(
     let mut cnt = 0usize;
     for inp_tgt in batch_iter.take(50) {
         let (inp, tgt) = inp_tgt?;
-        let logits = model.forward(&inp, 0)?;
+        let logits = model.forward(&inp, 0, cache)?;
         let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
         sum_ce += loss.to_vec0::<f32>()? as f64;
         cnt += 1;
@@ -37,8 +38,8 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
     let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
     let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
 
-    let cache = Cache::new(false, &config, vb.pp("rot"))?;
-    let model = Llama::load(vb, &cache, config)?;
+    let mut cache = Cache::new(false, &config, vb.pp("rot"))?;
+    let model = Llama::load(vb, config)?;
     let params = candle_nn::ParamsAdamW {
         lr: args.learning_rate,
         ..Default::default()
@@ -46,14 +47,14 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
     let mut opt = candle_nn::AdamW::new(varmap.all_vars(), params)?;
     for (batch_index, batch) in batch_iter.enumerate() {
         let (inp, tgt) = batch?;
-        let logits = model.forward(&inp, 0)?;
+        let logits = model.forward(&inp, 0, &mut cache)?;
         let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
         opt.backward_step(&loss)?;
 
         if batch_index > 0 && batch_index % 100 == 0 {
             // TODO: Add a way to deactivate the backprop graph tracking when computing the
             // validation loss.
-            let loss = valid_loss(&dataset, &model, args, &device)?;
+            let loss = valid_loss(&dataset, &model, args, &device, &mut cache)?;
             println!("{batch_index} {loss}");
         }
         if batch_index > 0 && batch_index % 1000 == 0 {
diff --git a/candle-transformers/src/models/llama2_c.rs b/candle-transformers/src/models/llama2_c.rs
index 753770fb..7b4f120b 100644
--- a/candle-transformers/src/models/llama2_c.rs
+++ b/candle-transformers/src/models/llama2_c.rs
@@ -2,7 +2,6 @@ use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::linear_no_bias as linear;
 use candle_nn::{embedding, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
 use std::collections::HashMap;
-use std::sync::{Arc, Mutex};
 
 #[derive(Debug, Clone)]
 pub struct Config {
@@ -70,12 +69,11 @@ impl Config {
     }
 }
 
-#[derive(Clone)]
+#[derive(Debug, Clone)]
 pub struct Cache {
-    masks: Arc<Mutex<HashMap<usize, Tensor>>>,
+    masks: HashMap<usize, Tensor>,
     pub use_kv_cache: bool,
-    #[allow(clippy::type_complexity)]
-    pub kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
+    pub kvs: Vec<Option<(Tensor, Tensor)>>,
     pub cos: Tensor,
     pub sin: Tensor,
     device: Device,
@@ -105,25 +103,24 @@ impl Cache {
         let cos = freq_cis_real.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
         let sin = freq_cis_imag.reshape((cfg.seq_len, cfg.head_size() / 2, 1))?;
         Ok(Self {
-            masks: Arc::new(Mutex::new(HashMap::new())),
+            masks: HashMap::new(),
             use_kv_cache,
-            kvs: Arc::new(Mutex::new(vec![None; cfg.n_layers])),
+            kvs: vec![None; cfg.n_layers],
             cos,
             sin,
             device: vb.device().clone(),
         })
     }
 
-    pub fn mask(&self, t: usize) -> Result<Tensor> {
-        let mut masks = self.masks.lock().unwrap();
-        if let Some(mask) = masks.get(&t) {
+    pub fn mask(&mut self, t: usize) -> Result<Tensor> {
+        if let Some(mask) = self.masks.get(&t) {
             Ok(mask.clone())
         } else {
             let mask: Vec<_> = (0..t)
                 .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
                 .collect();
             let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
-            masks.insert(t, mask.clone());
+            self.masks.insert(t, mask.clone());
             Ok(mask)
         }
     }
@@ -133,6 +130,7 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
     xs / (xs.neg()?.exp()? + 1.0)?
 }
 
+#[derive(Debug, Clone)]
 struct CausalSelfAttention {
     q_proj: Linear,
     k_proj: Linear,
@@ -141,14 +139,13 @@ struct CausalSelfAttention {
     n_head: usize,
     n_key_value_head: usize,
     head_dim: usize,
-    cache: Cache,
 }
 
 impl CausalSelfAttention {
-    fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
+    fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result<Tensor> {
         let (b_sz, seq_len, h, n_embd) = x.dims4()?;
-        let cos = self.cache.cos.i(index_pos..index_pos + seq_len)?;
-        let sin = self.cache.sin.i(index_pos..index_pos + seq_len)?;
+        let cos = cache.cos.i(index_pos..index_pos + seq_len)?;
+        let sin = cache.sin.i(index_pos..index_pos + seq_len)?;
         let cos = cos.unsqueeze(1)?;
         let sin = sin.unsqueeze(1)?;
         let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
@@ -162,7 +159,13 @@ impl CausalSelfAttention {
         Ok(rope)
     }
 
-    fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
+    fn forward(
+        &self,
+        x: &Tensor,
+        index_pos: usize,
+        block_idx: usize,
+        cache: &mut Cache,
+    ) -> Result<Tensor> {
         let (b_sz, seq_len, n_embd) = x.dims3()?;
         let q = self.q_proj.forward(x)?;
         let k = self.k_proj.forward(x)?;
@@ -172,16 +175,15 @@ impl CausalSelfAttention {
         let k = k.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
         let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
 
-        let q = self.apply_rotary_emb(&q, index_pos)?;
-        let mut k = self.apply_rotary_emb(&k, index_pos)?;
+        let q = self.apply_rotary_emb(&q, index_pos, cache)?;
+        let mut k = self.apply_rotary_emb(&k, index_pos, cache)?;
 
-        if self.cache.use_kv_cache {
-            let mut cache = self.cache.kvs.lock().unwrap();
-            if let Some((cache_k, cache_v)) = &cache[block_idx] {
+        if cache.use_kv_cache {
+            if let Some((cache_k, cache_v)) = &cache.kvs[block_idx] {
                 k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
                 v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
             }
-            cache[block_idx] = Some((k.clone(), v.clone()))
+            cache.kvs[block_idx] = Some((k.clone(), v.clone()))
         }
 
         let k = self.repeat_kv(k)?;
@@ -192,7 +194,7 @@ impl CausalSelfAttention {
         let v = v.transpose(1, 2)?.contiguous()?;
 
         let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
-        let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
+        let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
         let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
         let att = candle_nn::ops::softmax(&att, D::Minus1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
@@ -216,7 +218,7 @@ impl CausalSelfAttention {
         }
     }
 
-    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
+    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let size_in = cfg.dim;
         let size_q = (cfg.dim / cfg.n_heads) * cfg.n_heads;
         let size_kv = (cfg.dim / cfg.n_heads) * cfg.n_kv_heads;
@@ -232,7 +234,6 @@ impl CausalSelfAttention {
             n_head: cfg.n_heads,
             n_key_value_head: cfg.n_kv_heads,
             head_dim: cfg.dim / cfg.n_heads,
-            cache: cache.clone(),
         })
     }
 }
@@ -244,6 +245,7 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
     Ok(m)
 }
 
+#[derive(Debug, Clone)]
 struct Mlp {
     c_fc1: Linear,
     c_fc2: Linear,
@@ -274,6 +276,7 @@ impl Mlp {
     }
 }
 
+#[derive(Debug, Clone)]
 struct Block {
     rms_1: RmsNorm,
     attn: CausalSelfAttention,
@@ -291,17 +294,23 @@ impl Block {
         }
     }
 
-    fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
+    fn forward(
+        &self,
+        x: &Tensor,
+        index_pos: usize,
+        block_idx: usize,
+        cache: &mut Cache,
+    ) -> Result<Tensor> {
         let residual = x;
         let x = self.rms_1.forward(x)?;
-        let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
+        let x = (self.attn.forward(&x, index_pos, block_idx, cache)? + residual)?;
         let residual = &x;
         let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
         Ok(x)
     }
 
-    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
-        let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
+    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg)?;
         let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
         let input_layernorm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
         let post_attention_layernorm =
@@ -315,6 +324,7 @@ impl Block {
     }
 }
 
+#[derive(Debug, Clone)]
 pub struct Llama {
     wte: Embedding,
     blocks: Vec<Block>,
@@ -324,23 +334,23 @@ pub struct Llama {
 }
 
 impl Llama {
-    pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
+    pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result<Tensor> {
         let (_b_sz, _seq_len) = x.dims2()?;
         let mut x = self.wte.forward(x)?;
         for (block_idx, block) in self.blocks.iter().enumerate() {
-            x = block.forward(&x, index_pos, block_idx)?;
+            x = block.forward(&x, index_pos, block_idx, cache)?;
         }
         let x = self.ln_f.forward(&x)?;
         let logits = self.lm_head.forward(&x)?;
         logits.to_dtype(DType::F32)
     }
 
-    pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> {
+    pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> {
         let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
         let ln_f = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
         let blocks: Vec<_> = (0..cfg.n_layers)
-            .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, &cfg).unwrap())
+            .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), &cfg).unwrap())
             .collect();
         Ok(Self {
             wte,
diff --git a/candle-transformers/src/models/quantized_llama2_c.rs b/candle-transformers/src/models/quantized_llama2_c.rs
index 68ebee0d..b43ca9ff 100644
--- a/candle-transformers/src/models/quantized_llama2_c.rs
+++ b/candle-transformers/src/models/quantized_llama2_c.rs
@@ -7,6 +7,7 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
     xs / (xs.neg()?.exp()? + 1.0)?
 }
 
+#[derive(Debug, Clone)]
 struct CausalSelfAttention {
     q_proj: Linear,
     k_proj: Linear,
@@ -15,14 +16,13 @@ struct CausalSelfAttention {
     n_head: usize,
     n_key_value_head: usize,
     head_dim: usize,
-    cache: Cache,
 }
 
 impl CausalSelfAttention {
-    fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
+    fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize, cache: &Cache) -> Result<Tensor> {
         let (b_sz, seq_len, h, n_embd) = x.dims4()?;
-        let cos = self.cache.cos.i(index_pos..index_pos + seq_len)?;
-        let sin = self.cache.sin.i(index_pos..index_pos + seq_len)?;
+        let cos = cache.cos.i(index_pos..index_pos + seq_len)?;
+        let sin = cache.sin.i(index_pos..index_pos + seq_len)?;
         let cos = cos.unsqueeze(1)?;
         let sin = sin.unsqueeze(1)?;
         let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
@@ -36,7 +36,13 @@ impl CausalSelfAttention {
         Ok(rope)
     }
 
-    fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
+    fn forward(
+        &self,
+        x: &Tensor,
+        index_pos: usize,
+        block_idx: usize,
+        cache: &mut Cache,
+    ) -> Result<Tensor> {
         let (b_sz, seq_len, n_embd) = x.dims3()?;
         let q = self.q_proj.forward(x)?;
         let k = self.k_proj.forward(x)?;
@@ -46,16 +52,15 @@ impl CausalSelfAttention {
         let k = k.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
         let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
 
-        let q = self.apply_rotary_emb(&q, index_pos)?;
-        let mut k = self.apply_rotary_emb(&k, index_pos)?;
+        let q = self.apply_rotary_emb(&q, index_pos, cache)?;
+        let mut k = self.apply_rotary_emb(&k, index_pos, cache)?;
 
-        if self.cache.use_kv_cache {
-            let mut cache = self.cache.kvs.lock().unwrap();
-            if let Some((cache_k, cache_v)) = &cache[block_idx] {
+        if cache.use_kv_cache {
+            if let Some((cache_k, cache_v)) = &cache.kvs[block_idx] {
                 k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
                 v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
             }
-            cache[block_idx] = Some((k.clone(), v.clone()))
+            cache.kvs[block_idx] = Some((k.clone(), v.clone()))
         }
 
         let k = self.repeat_kv(k)?;
@@ -66,7 +71,7 @@ impl CausalSelfAttention {
         let v = v.transpose(1, 2)?.contiguous()?;
 
         let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
-        let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
+        let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?;
         let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
         let att = candle_nn::ops::softmax(&att, D::Minus1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
@@ -90,7 +95,7 @@ impl CausalSelfAttention {
         }
     }
 
-    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
+    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
         let size_in = cfg.dim;
         let size_q = (cfg.dim / cfg.n_heads) * cfg.n_heads;
         let size_kv = (cfg.dim / cfg.n_heads) * cfg.n_kv_heads;
@@ -106,7 +111,6 @@ impl CausalSelfAttention {
             n_head: cfg.n_heads,
             n_key_value_head: cfg.n_kv_heads,
             head_dim: cfg.dim / cfg.n_heads,
-            cache: cache.clone(),
         })
     }
 }
@@ -118,6 +122,7 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
     Ok(m)
 }
 
+#[derive(Debug, Clone)]
 struct Mlp {
     c_fc1: Linear,
     c_fc2: Linear,
@@ -148,6 +153,7 @@ impl Mlp {
     }
 }
 
+#[derive(Debug, Clone)]
 struct Block {
     rms_1: RmsNorm,
     attn: CausalSelfAttention,
@@ -165,17 +171,23 @@ impl Block {
         }
     }
 
-    fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
+    fn forward(
+        &self,
+        x: &Tensor,
+        index_pos: usize,
+        block_idx: usize,
+        cache: &mut Cache,
+    ) -> Result<Tensor> {
         let residual = x;
         let x = self.rms_1.forward(x)?;
-        let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
+        let x = (self.attn.forward(&x, index_pos, block_idx, cache)? + residual)?;
         let residual = &x;
         let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
         Ok(x)
     }
 
-    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
-        let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
+    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let attn = CausalSelfAttention::load(vb.pp("self_attn"), cfg)?;
         let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
         let input_layernorm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
         let post_attention_layernorm =
@@ -189,6 +201,7 @@ impl Block {
     }
 }
 
+#[derive(Debug, Clone)]
 pub struct QLlama {
     wte: Embedding,
     blocks: Vec<Block>,
@@ -198,23 +211,23 @@ pub struct QLlama {
 }
 
 impl QLlama {
-    pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
+    pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result<Tensor> {
         let (_b_sz, _seq_len) = x.dims2()?;
         let mut x = self.wte.forward(x)?;
         for (block_idx, block) in self.blocks.iter().enumerate() {
-            x = block.forward(&x, index_pos, block_idx)?;
+            x = block.forward(&x, index_pos, block_idx, cache)?;
        }
         let x = self.ln_f.forward(&x)?;
         let logits = self.lm_head.forward(&x)?;
         logits.to_dtype(DType::F32)
     }
 
-    pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> {
+    pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> {
         let wte = Embedding::new(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
         let ln_f = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
         let blocks: Vec<_> = (0..cfg.n_layers)
-            .map(|i| Block::load(vb.pp(format!("model.layers.{i}")), cache, &cfg).unwrap())
+            .map(|i| Block::load(vb.pp(format!("model.layers.{i}")), &cfg).unwrap())
             .collect();
         Ok(Self {
             wte,
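
Usage note: a minimal sketch of the new calling convention, assuming the crate names used in the files above (the example code imports the core crate as `candle`). The `generate_step` helper is hypothetical and only illustrates the pattern introduced by this patch: the Cache is created explicitly next to the model and threaded through `forward` as `&mut Cache`, instead of living inside the model behind Arc/Mutex state.

    use candle::{Device, Result, Tensor};
    use candle_transformers::models::llama2_c::{Cache, Llama};

    // One decoding step: build the token-id tensor and run the model while
    // borrowing the KV cache mutably for the duration of the call.
    fn generate_step(
        model: &Llama,
        cache: &mut Cache,
        ctxt: &[u32],
        index_pos: usize,
        device: &Device,
    ) -> Result<Tensor> {
        let input = Tensor::new(ctxt, device)?.unsqueeze(0)?;
        model.forward(&input, index_pos, cache)
    }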