mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Support for MQA for llama v2. (#205)
* Support for MQA for llama v2. * More llama-v2. * Move the rotary embedding precomputation in the cache. * Add a v2 flag. * Use the hf model.
This commit is contained in:
@ -13,7 +13,7 @@ let c = a.matmul(&b)?;
|
|||||||
Check out our [examples](./candle-examples/examples/):
|
Check out our [examples](./candle-examples/examples/):
|
||||||
|
|
||||||
- [Whisper](./candle-examples/examples/whisper/)
|
- [Whisper](./candle-examples/examples/whisper/)
|
||||||
- [Llama](./candle-examples/examples/llama/)
|
- [Llama and Llama-v2](./candle-examples/examples/llama/)
|
||||||
- [Bert](./candle-examples/examples/bert/) (Useful for sentence embeddings)
|
- [Bert](./candle-examples/examples/bert/) (Useful for sentence embeddings)
|
||||||
- [Falcon](./candle-examples/examples/falcon/)
|
- [Falcon](./candle-examples/examples/falcon/)
|
||||||
|
|
||||||
|
@ -15,7 +15,7 @@ extern crate intel_mkl_src;
|
|||||||
use anyhow::{Error as E, Result};
|
use anyhow::{Error as E, Result};
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
|
|
||||||
use candle::{DType, Device, Tensor, D};
|
use candle::{DType, Tensor};
|
||||||
use candle_nn::VarBuilder;
|
use candle_nn::VarBuilder;
|
||||||
use candle_transformers::generation::LogitsProcessor;
|
use candle_transformers::generation::LogitsProcessor;
|
||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
@ -76,23 +76,6 @@ Whate'er it bodes, henceforward will I bear
|
|||||||
Upon my target three fair-shining suns.
|
Upon my target three fair-shining suns.
|
||||||
";
|
";
|
||||||
|
|
||||||
fn precompute_freqs_cis(config: &Config, device: &Device) -> Result<Tensor> {
|
|
||||||
let n_elem = config.n_embd / config.n_head;
|
|
||||||
let theta: Vec<_> = (0..n_elem)
|
|
||||||
.step_by(2)
|
|
||||||
.map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
|
|
||||||
.collect();
|
|
||||||
let theta = Tensor::new(theta.as_slice(), device)?;
|
|
||||||
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
|
|
||||||
.to_dtype(DType::F32)?
|
|
||||||
.reshape((MAX_SEQ_LEN, 1))?
|
|
||||||
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
|
||||||
let shape = [1, MAX_SEQ_LEN, n_elem / 2, 1];
|
|
||||||
let idx_theta_cos = idx_theta.cos()?.reshape(&shape)?;
|
|
||||||
let idx_theta_sin = idx_theta.sin()?.reshape(&shape)?;
|
|
||||||
Ok(Tensor::cat(&[&idx_theta_cos, &idx_theta_sin], D::Minus1)?)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
#[command(author, version, about, long_about = None)]
|
#[command(author, version, about, long_about = None)]
|
||||||
struct Args {
|
struct Args {
|
||||||
@ -127,6 +110,12 @@ struct Args {
|
|||||||
/// Use f32 computations rather than f16.
|
/// Use f32 computations rather than f16.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
use_f32: bool,
|
use_f32: bool,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
model_id: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
v2: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
@ -136,7 +125,7 @@ fn main() -> Result<()> {
|
|||||||
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
let config = Config::config_7b();
|
let config = Config::config_7b();
|
||||||
let cache = model::Cache::new(!args.no_kv_cache, &config, &device);
|
let cache = model::Cache::new(!args.no_kv_cache, &config, &device)?;
|
||||||
let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
|
let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
|
||||||
let (llama, tokenizer_filename) = match args.npy {
|
let (llama, tokenizer_filename) = match args.npy {
|
||||||
Some(filename) => {
|
Some(filename) => {
|
||||||
@ -146,8 +135,15 @@ fn main() -> Result<()> {
|
|||||||
}
|
}
|
||||||
None => {
|
None => {
|
||||||
let api = Api::new()?;
|
let api = Api::new()?;
|
||||||
let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
|
let model_id = args.model_id.unwrap_or_else(|| {
|
||||||
println!("loading the model weights");
|
if args.v2 {
|
||||||
|
"meta-llama/Llama-2-7b-hf".to_string()
|
||||||
|
} else {
|
||||||
|
"Narsil/amall-7b".to_string()
|
||||||
|
}
|
||||||
|
});
|
||||||
|
println!("loading the model weights from {model_id}");
|
||||||
|
let repo = Repo::new(model_id, RepoType::Model);
|
||||||
let tokenizer_filename = api.get(&repo, "tokenizer.json")?;
|
let tokenizer_filename = api.get(&repo, "tokenizer.json")?;
|
||||||
let mut filenames = vec![];
|
let mut filenames = vec![];
|
||||||
for rfilename in [
|
for rfilename in [
|
||||||
@ -180,8 +176,6 @@ fn main() -> Result<()> {
|
|||||||
.get_ids()
|
.get_ids()
|
||||||
.to_vec();
|
.to_vec();
|
||||||
|
|
||||||
println!("pre-computing the positional embeddings");
|
|
||||||
let freqs_cis = precompute_freqs_cis(&config, &device)?;
|
|
||||||
println!("starting the inference loop");
|
println!("starting the inference loop");
|
||||||
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
|
let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
|
||||||
let mut new_tokens = vec![];
|
let mut new_tokens = vec![];
|
||||||
@ -196,12 +190,7 @@ fn main() -> Result<()> {
|
|||||||
};
|
};
|
||||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||||
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
|
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
|
||||||
let freqs_cis = if cache.use_kv_cache {
|
let logits = llama.forward(&input, index_pos)?;
|
||||||
freqs_cis.narrow(1, index_pos, ctxt.len())?
|
|
||||||
} else {
|
|
||||||
freqs_cis.clone()
|
|
||||||
};
|
|
||||||
let logits = llama.forward(&input, &freqs_cis)?;
|
|
||||||
let logits = logits.squeeze(0)?;
|
let logits = logits.squeeze(0)?;
|
||||||
index_pos += ctxt.len();
|
index_pos += ctxt.len();
|
||||||
|
|
||||||
|
@ -12,6 +12,7 @@ pub struct Config {
|
|||||||
pub n_layer: usize,
|
pub n_layer: usize,
|
||||||
pub n_head: usize,
|
pub n_head: usize,
|
||||||
pub n_embd: usize,
|
pub n_embd: usize,
|
||||||
|
pub n_key_value_head: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
@ -23,6 +24,7 @@ impl Config {
|
|||||||
n_layer: 32,
|
n_layer: 32,
|
||||||
n_head: 32,
|
n_head: 32,
|
||||||
n_embd: 4096,
|
n_embd: 4096,
|
||||||
|
n_key_value_head: 32,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -33,17 +35,37 @@ pub struct Cache {
|
|||||||
pub use_kv_cache: bool,
|
pub use_kv_cache: bool,
|
||||||
#[allow(clippy::type_complexity)]
|
#[allow(clippy::type_complexity)]
|
||||||
kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
|
kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
|
||||||
|
cos: Tensor,
|
||||||
|
sin: Tensor,
|
||||||
device: Device,
|
device: Device,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Cache {
|
impl Cache {
|
||||||
pub fn new(use_kv_cache: bool, config: &Config, device: &Device) -> Self {
|
pub fn new(use_kv_cache: bool, config: &Config, device: &Device) -> Result<Self> {
|
||||||
Self {
|
// precompute freqs_cis
|
||||||
|
let n_elem = config.n_embd / config.n_head;
|
||||||
|
let theta: Vec<_> = (0..n_elem)
|
||||||
|
.step_by(2)
|
||||||
|
.map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
|
||||||
|
.collect();
|
||||||
|
let theta = Tensor::new(theta.as_slice(), device)?;
|
||||||
|
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.reshape((MAX_SEQ_LEN, 1))?
|
||||||
|
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
||||||
|
// This is different from the paper, see:
|
||||||
|
// https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
|
||||||
|
let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
|
||||||
|
let cos = idx_theta.cos()?;
|
||||||
|
let sin = idx_theta.sin()?;
|
||||||
|
Ok(Self {
|
||||||
masks: Arc::new(Mutex::new(HashMap::new())),
|
masks: Arc::new(Mutex::new(HashMap::new())),
|
||||||
use_kv_cache,
|
use_kv_cache,
|
||||||
kvs: Arc::new(Mutex::new(vec![None; config.n_layer])),
|
kvs: Arc::new(Mutex::new(vec![None; config.n_layer])),
|
||||||
device: device.clone(),
|
device: device.clone(),
|
||||||
}
|
cos,
|
||||||
|
sin,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn mask(&self, t: usize) -> Result<Tensor> {
|
fn mask(&self, t: usize) -> Result<Tensor> {
|
||||||
@ -97,7 +119,7 @@ impl RmsNorm {
|
|||||||
let (b_sz, seq_len, hidden_size) = x.shape().r3()?;
|
let (b_sz, seq_len, hidden_size) = x.shape().r3()?;
|
||||||
let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
|
let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
|
||||||
let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
|
let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
|
||||||
let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
|
let x_normed = (x / (norm_x + 1e-6)?.sqrt()?)?;
|
||||||
let size = self.scale.shape().r1()?;
|
let size = self.scale.shape().r1()?;
|
||||||
let scale = self
|
let scale = self
|
||||||
.scale
|
.scale
|
||||||
@ -110,63 +132,52 @@ impl RmsNorm {
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct CausalSelfAttention {
|
struct CausalSelfAttention {
|
||||||
c_attn: Linear,
|
q_proj: Linear,
|
||||||
c_proj: Linear,
|
k_proj: Linear,
|
||||||
|
v_proj: Linear,
|
||||||
|
o_proj: Linear,
|
||||||
n_head: usize,
|
n_head: usize,
|
||||||
|
n_key_value_head: usize,
|
||||||
|
head_dim: usize,
|
||||||
cache: Cache,
|
cache: Cache,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CausalSelfAttention {
|
impl CausalSelfAttention {
|
||||||
fn new(c_attn: Linear, c_proj: Linear, n_head: usize, cache: &Cache) -> Self {
|
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||||
Self {
|
let (b_sz, _, seq_len, n_embd) = x.shape().r4()?;
|
||||||
c_attn,
|
let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
|
||||||
c_proj,
|
let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
|
||||||
n_head,
|
let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd))?;
|
||||||
cache: cache.clone(),
|
let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd))?;
|
||||||
}
|
let x1 = x.narrow(D::Minus1, 0, n_embd / 2)?;
|
||||||
}
|
let x2 = x.narrow(D::Minus1, n_embd / 2, n_embd / 2)?;
|
||||||
|
let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
|
||||||
fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
|
let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
|
||||||
let mut dims = x.dims().to_vec();
|
|
||||||
let fcis_dims = freqs_cis.dims();
|
|
||||||
let freqs_cis = if dims[2] < fcis_dims[1] {
|
|
||||||
freqs_cis.narrow(1, 0, dims[2])?
|
|
||||||
} else {
|
|
||||||
freqs_cis.clone()
|
|
||||||
};
|
|
||||||
let v = dims.pop().unwrap();
|
|
||||||
dims.push(v / 2);
|
|
||||||
dims.push(2);
|
|
||||||
let x = x.reshape(dims)?;
|
|
||||||
let re_x = x.narrow(D::Minus1, 0, 1)?;
|
|
||||||
let im_x = x.narrow(D::Minus1, 1, 1)?;
|
|
||||||
let re_f = freqs_cis
|
|
||||||
.narrow(D::Minus1, 0, 1)?
|
|
||||||
.broadcast_as(re_x.shape())?;
|
|
||||||
let im_f = freqs_cis
|
|
||||||
.narrow(D::Minus1, 1, 1)?
|
|
||||||
.broadcast_as(im_x.shape())?;
|
|
||||||
let re = ((&re_x * &re_f)? - (&im_x * &im_f)?)?;
|
|
||||||
let im = ((&re_x * &im_f)? + (&im_x * &re_f)?)?;
|
|
||||||
let rope = Tensor::cat(&[&re, &im], D::Minus1)?;
|
|
||||||
let rope = rope.flatten_from(D::Minus2)?;
|
|
||||||
Ok(rope)
|
Ok(rope)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, x: &Tensor, freqs_cis: &Tensor, block_idx: usize) -> Result<Tensor> {
|
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
|
||||||
let x_dtype = x.dtype();
|
let x_dtype = x.dtype();
|
||||||
let (b_sz, seq_len, n_embd) = x.shape().r3()?;
|
let (b_sz, seq_len, n_embd) = x.shape().r3()?;
|
||||||
let qkv = self.c_attn.forward(x)?;
|
let q = self.q_proj.forward(x)?;
|
||||||
let qkv = qkv.to_dtype(DType::F32)?;
|
let k = self.k_proj.forward(x)?;
|
||||||
let q = qkv.narrow(D::Minus1, 0, n_embd)?;
|
let v = self.v_proj.forward(x)?;
|
||||||
let k = qkv.narrow(D::Minus1, n_embd, n_embd)?;
|
|
||||||
let v = qkv.narrow(D::Minus1, 2 * n_embd, n_embd)?;
|
let q = q
|
||||||
let target_dim = [b_sz, seq_len, self.n_head, n_embd / self.n_head];
|
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
|
||||||
let k = k.reshape(target_dim.as_slice())?.transpose(1, 2)?;
|
.transpose(1, 2)?
|
||||||
let q = q.reshape(target_dim.as_slice())?.transpose(1, 2)?;
|
.to_dtype(DType::F32)?;
|
||||||
let mut v = v.reshape(target_dim.as_slice())?.transpose(1, 2)?;
|
let k = k
|
||||||
let q = self.apply_rotary_emb(&q, freqs_cis)?;
|
.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
|
||||||
let mut k = self.apply_rotary_emb(&k, freqs_cis)?;
|
.transpose(1, 2)?
|
||||||
|
.to_dtype(DType::F32)?;
|
||||||
|
let mut v = v
|
||||||
|
.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.to_dtype(DType::F32)?;
|
||||||
|
|
||||||
|
let q = self.apply_rotary_emb(&q, index_pos)?;
|
||||||
|
let mut k = self.apply_rotary_emb(&k, index_pos)?;
|
||||||
|
|
||||||
if self.cache.use_kv_cache {
|
if self.cache.use_kv_cache {
|
||||||
let mut cache = self.cache.kvs.lock().unwrap();
|
let mut cache = self.cache.kvs.lock().unwrap();
|
||||||
@ -189,7 +200,9 @@ impl CausalSelfAttention {
|
|||||||
cache[block_idx] = Some((k.clone(), v.clone()))
|
cache[block_idx] = Some((k.clone(), v.clone()))
|
||||||
}
|
}
|
||||||
|
|
||||||
let att = (q.matmul(&k.t()?)? / (k.dim(D::Minus1)? as f64).sqrt())?;
|
let k = self.repeat_kv(k)?;
|
||||||
|
let v = self.repeat_kv(v)?;
|
||||||
|
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
|
||||||
let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
|
let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
|
||||||
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
|
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
|
||||||
let att = att.softmax(D::Minus1)?;
|
let att = att.softmax(D::Minus1)?;
|
||||||
@ -197,31 +210,42 @@ impl CausalSelfAttention {
|
|||||||
let y = att.matmul(&v.contiguous()?)?;
|
let y = att.matmul(&v.contiguous()?)?;
|
||||||
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
|
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
|
||||||
let y = y.to_dtype(x_dtype)?;
|
let y = y.to_dtype(x_dtype)?;
|
||||||
let y = self.c_proj.forward(&y)?;
|
let y = self.o_proj.forward(&y)?;
|
||||||
Ok(y)
|
Ok(y)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
|
||||||
|
let n_rep = self.n_head / self.n_key_value_head;
|
||||||
|
if n_rep == 1 {
|
||||||
|
Ok(x)
|
||||||
|
} else {
|
||||||
|
let (b_sz, n_kv_head, seq_len, head_dim) = x.shape().r4()?;
|
||||||
|
let x = x
|
||||||
|
.unsqueeze(2)?
|
||||||
|
.expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
|
||||||
|
.reshape((b_sz, n_kv_head, n_rep, seq_len, head_dim))?;
|
||||||
|
Ok(x)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
|
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
|
||||||
let size_in = cfg.hidden_size;
|
let size_in = cfg.hidden_size;
|
||||||
let size = (cfg.hidden_size / cfg.n_head) * cfg.n_head;
|
let size_q = (cfg.hidden_size / cfg.n_head) * cfg.n_head;
|
||||||
let q_proj = vb.get((size_in, size), "q_proj.weight")?;
|
let size_kv = (cfg.hidden_size / cfg.n_head) * cfg.n_key_value_head;
|
||||||
let k_proj = vb.get((size_in, size), "k_proj.weight")?;
|
let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?;
|
||||||
let v_proj = vb.get((size_in, size), "v_proj.weight")?;
|
let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?;
|
||||||
// Invert the transformation from:
|
let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?;
|
||||||
// https://github.com/huggingface/transformers/blob/2642d8d04b14c18199ebe7b35f976da02df61752/src/transformers/models/llama/convert_llama_weights_to_hf.py#L101
|
let o_proj = linear(size_q, size_in, vb.pp("o_proj"))?;
|
||||||
let n_head = cfg.n_head;
|
Ok(Self {
|
||||||
let q_proj = q_proj
|
q_proj,
|
||||||
.reshape((n_head, 2, size / n_head / 2, size_in))?
|
k_proj,
|
||||||
.transpose(1, 2)?
|
v_proj,
|
||||||
.reshape((size_in, size))?;
|
o_proj,
|
||||||
let k_proj = k_proj
|
n_head: cfg.n_head,
|
||||||
.reshape((n_head, 2, size / n_head / 2, size_in))?
|
n_key_value_head: cfg.n_key_value_head,
|
||||||
.transpose(1, 2)?
|
head_dim: cfg.hidden_size / cfg.n_head,
|
||||||
.reshape((size_in, size))?;
|
cache: cache.clone(),
|
||||||
let attn_weight = Tensor::cat(&[q_proj, k_proj, v_proj], 0)?;
|
})
|
||||||
let c_attn = Linear::new(attn_weight, None);
|
|
||||||
let o_proj = linear(size, size_in, vb.pp("o_proj"))?;
|
|
||||||
Ok(Self::new(c_attn, o_proj, cfg.n_head, cache))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -279,12 +303,12 @@ impl Block {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, x: &Tensor, freqs_cis: &Tensor, block_idx: usize) -> Result<Tensor> {
|
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
|
||||||
let x = (self
|
let residual = x;
|
||||||
.attn
|
let x = self.rms_1.forward(x)?;
|
||||||
.forward(&self.rms_1.forward(x)?, freqs_cis, block_idx)?
|
let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
|
||||||
+ x)?;
|
let residual = &x;
|
||||||
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + x)?;
|
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
|
||||||
Ok(x)
|
Ok(x)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -320,11 +344,11 @@ impl Llama {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
|
pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||||
let (_b_sz, seq_len) = x.shape().r2()?;
|
let (_b_sz, seq_len) = x.shape().r2()?;
|
||||||
let mut x = self.wte.forward(x)?;
|
let mut x = self.wte.forward(x)?;
|
||||||
for (block_idx, block) in self.blocks.iter().enumerate() {
|
for (block_idx, block) in self.blocks.iter().enumerate() {
|
||||||
x = block.forward(&x, freqs_cis, block_idx)?;
|
x = block.forward(&x, index_pos, block_idx)?;
|
||||||
}
|
}
|
||||||
let x = self.ln_f.forward(&x)?;
|
let x = self.ln_f.forward(&x)?;
|
||||||
let x = x.i((.., seq_len - 1, ..))?;
|
let x = x.i((.., seq_len - 1, ..))?;
|
||||||
|
Reference in New Issue
Block a user