mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 18:48:51 +00:00
@@ -30,7 +30,7 @@ impl Tensor {
         let mut current_dim = 0;
         for (i, indexer) in indexers.iter().enumerate() {
             x = match indexer {
-                TensorIndexer::Select(n) => x.get(*n)?,
+                TensorIndexer::Select(n) => x.narrow(i, *n, 1)?.squeeze(i)?,
                 TensorIndexer::Narrow(left_bound, right_bound) => {
                     let start = match left_bound {
                         Bound::Included(n) => *n,
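The Select arm now narrows to a length-1 slice along the indexed dimension and squeezes it away, instead of relying on get. A rough standalone sketch of that equivalence (the tensor values, shapes, and Device::Cpu setup below are illustrative only, not part of the diff):

use candle::{Device, Tensor};

fn main() -> candle::Result<()> {
    // A small (2, 3) tensor to index into.
    let t = Tensor::new(&[[1u32, 2, 3], [4, 5, 6]], &Device::Cpu)?;
    // narrow(1, 2, 1) keeps a length-1 slice of column 2 -> shape (2, 1);
    // squeeze(1) then drops that unit dimension -> shape (2,).
    let col = t.narrow(1, 2, 1)?.squeeze(1)?;
    assert_eq!(col.dims(), &[2]);
    Ok(())
}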
@@ -12,8 +12,6 @@
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
 
-// TODO: This does not use a batch dimension. If adding it back, be cautious about the
-// transposition operations.
 use anyhow::{Error as E, Result};
 use clap::Parser;
 use rand::{distributions::Distribution, SeedableRng};
@@ -200,13 +198,14 @@ fn main() -> Result<()> {
             tokens.len()
         };
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
-        let input = Tensor::new(ctxt, &device)?;
+        let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
         let freqs_cis = if cache.use_kv_cache {
             freqs_cis.narrow(1, index_pos, ctxt.len())?
         } else {
             freqs_cis.clone()
         };
         let logits = llama.forward(&input, &freqs_cis)?;
+        let logits = logits.squeeze(0)?;
         index_pos += ctxt.len();
 
         let next_token = if let Some(temperature) = args.temperature {
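The unsqueeze(0) / squeeze(0) pair is what bridges the single-sequence sampling loop to the now batched model: the prompt tokens gain a leading batch dimension of 1 before the forward pass, and the returned logits drop it again before sampling. A shape-only sketch (illustrative; the zeros tensor merely stands in for the model output, and the vocab size of 32 is made up):

use candle::{DType, Device, Tensor};

fn main() -> candle::Result<()> {
    let ctxt: Vec<u32> = vec![1, 2, 3, 4];
    // (seq_len,) token ids -> (1, seq_len) batched input.
    let input = Tensor::new(ctxt.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
    assert_eq!(input.dims(), &[1, 4]);
    // The batched forward pass in this diff returns last-position logits of
    // shape (batch, vocab); a zeros tensor stands in for that output here.
    let logits = Tensor::zeros((1usize, 32usize), DType::F32, &Device::Cpu)?;
    // squeeze(0) strips the unit batch dimension back off for sampling.
    let logits = logits.squeeze(0)?;
    assert_eq!(logits.dims(), &[32]);
    Ok(())
}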
@@ -1,5 +1,5 @@
-use candle::{DType, Device, Result, Tensor, D};
-use candle_nn::{Linear, VarBuilder};
+use candle::{DType, Device, IndexOp, Result, Tensor, D};
+use candle_nn::{Embedding, Linear, VarBuilder};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
 
@@ -71,23 +71,9 @@ fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
     Ok(Linear::new(weight, None))
 }
 
-struct Embedding {
-    embeddings: Tensor,
-}
-
-impl Embedding {
-    fn new(embeddings: Tensor) -> Self {
-        Self { embeddings }
-    }
-
-    fn load(cfg: &Config, vb: VarBuilder) -> Result<Self> {
-        let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?;
-        Ok(Self::new(embeddings))
-    }
-
-    fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
-        Tensor::embedding(indexes, &self.embeddings)
-    }
-}
+fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
+    let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?;
+    Ok(Embedding::new(embeddings, cfg.hidden_size))
+}
 
 struct RmsNorm {
@@ -108,15 +94,15 @@ impl RmsNorm {
         let in_dtype = x.dtype();
         // This is a no-op if x's dtype is already f32.
         let x = x.to_dtype(DType::F32)?;
-        let (seq_len, hidden_size) = x.shape().r2()?;
-        let norm_x = ((&x * &x)?.sum(&[1])? / hidden_size as f64)?;
-        let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
+        let (b_sz, seq_len, hidden_size) = x.shape().r3()?;
+        let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
+        let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
         let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
         let size = self.scale.shape().r1()?;
         let scale = self
             .scale
             .to_dtype(DType::F32)?
-            .broadcast_as((seq_len, size))?;
+            .broadcast_as((b_sz, seq_len, size))?;
         let x = (scale * x_normed)?;
         let x = x.to_dtype(in_dtype)?;
         Ok(x)
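With the extra batch axis, the RMS statistic is taken over the last dimension (index 2) of a (b_sz, seq_len, hidden_size) activation and broadcast back to the full shape. A hedged sketch of that shape flow with dummy inputs (this is not the RmsNorm struct itself, and it assumes sum keeps the reduced axis as size 1, as the broadcast_as in the diff requires):

use candle::{DType, Device, Tensor};

fn main() -> candle::Result<()> {
    let (b_sz, seq_len, hidden_size) = (2usize, 3usize, 4usize);
    let x = Tensor::ones((b_sz, seq_len, hidden_size), DType::F32, &Device::Cpu)?;
    // Mean of squares over the hidden axis, then broadcast back over all three axes.
    let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
    let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
    let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
    assert_eq!(x_normed.dims(), &[b_sz, seq_len, hidden_size]);
    Ok(())
}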
@@ -143,8 +129,8 @@ impl CausalSelfAttention {
     fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
         let mut dims = x.dims().to_vec();
         let fcis_dims = freqs_cis.dims();
-        let freqs_cis = if dims[1] < fcis_dims[1] {
-            freqs_cis.narrow(1, 0, dims[1])?
+        let freqs_cis = if dims[2] < fcis_dims[1] {
+            freqs_cis.narrow(1, 0, dims[2])?
         } else {
             freqs_cis.clone()
         };
@@ -169,35 +155,34 @@ impl CausalSelfAttention {
 
     fn forward(&self, x: &Tensor, freqs_cis: &Tensor, block_idx: usize) -> Result<Tensor> {
         let x_dtype = x.dtype();
-        let (t, c) = x.shape().r2()?;
+        let (b_sz, seq_len, n_embd) = x.shape().r3()?;
         let qkv = self.c_attn.forward(x)?;
         let qkv = qkv.to_dtype(DType::F32)?;
-        let n_embd = c;
-        let q = qkv.narrow(1, 0, n_embd)?;
-        let k = qkv.narrow(1, n_embd, n_embd)?;
-        let v = qkv.narrow(1, 2 * n_embd, n_embd)?;
-        let target_dim = [t, self.n_head, c / self.n_head];
-        let k = k.reshape(target_dim.as_slice())?.transpose(0, 1)?;
-        let q = q.reshape(target_dim.as_slice())?.transpose(0, 1)?;
-        let mut v = v.reshape(target_dim.as_slice())?.transpose(0, 1)?;
+        let q = qkv.narrow(D::Minus1, 0, n_embd)?;
+        let k = qkv.narrow(D::Minus1, n_embd, n_embd)?;
+        let v = qkv.narrow(D::Minus1, 2 * n_embd, n_embd)?;
+        let target_dim = [b_sz, seq_len, self.n_head, n_embd / self.n_head];
+        let k = k.reshape(target_dim.as_slice())?.transpose(1, 2)?;
+        let q = q.reshape(target_dim.as_slice())?.transpose(1, 2)?;
+        let mut v = v.reshape(target_dim.as_slice())?.transpose(1, 2)?;
         let q = self.apply_rotary_emb(&q, freqs_cis)?;
         let mut k = self.apply_rotary_emb(&k, freqs_cis)?;
 
         if self.cache.use_kv_cache {
             let mut cache = self.cache.kvs.lock().unwrap();
             if let Some((cache_k, cache_v)) = &cache[block_idx] {
-                k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
-                v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
+                k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
+                v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
                 let k_seq_len = k.dims()[1];
                 if k_seq_len > MAX_SEQ_LEN {
                     k = k
-                        .narrow(1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
+                        .narrow(D::Minus1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
                         .contiguous()?
                 }
                 let v_seq_len = v.dims()[1];
                 if v_seq_len > 2 * MAX_SEQ_LEN {
                     v = v
-                        .narrow(1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
+                        .narrow(D::Minus1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
                         .contiguous()?
                 }
             }
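The attention rewrite follows the usual batched multi-head layout: qkv is (b_sz, seq_len, 3 * n_embd), each narrow along the last axis gives (b_sz, seq_len, n_embd), the reshape splits heads out to (b_sz, seq_len, n_head, head_dim), and transpose(1, 2) moves the head axis in front of the sequence axis for the matmuls. A minimal standalone sketch of that reshuffling for q (the head count and sizes are illustrative, not the model's configuration):

use candle::{DType, Device, Tensor, D};

fn main() -> candle::Result<()> {
    let (b_sz, seq_len, n_embd, n_head) = (2usize, 5usize, 8usize, 2usize);
    let qkv = Tensor::ones((b_sz, seq_len, 3 * n_embd), DType::F32, &Device::Cpu)?;
    // Split the fused projection along the last axis, then split heads out.
    let q = qkv.narrow(D::Minus1, 0, n_embd)?;
    let q = q
        .reshape((b_sz, seq_len, n_head, n_embd / n_head))?
        .transpose(1, 2)?; // -> (b_sz, n_head, seq_len, head_dim)
    assert_eq!(q.dims(), &[b_sz, n_head, seq_len, n_embd / n_head]);
    Ok(())
}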
@@ -205,12 +190,12 @@ impl CausalSelfAttention {
         }
 
         let att = (q.matmul(&k.t()?)? / (k.dim(D::Minus1)? as f64).sqrt())?;
-        let mask = self.cache.mask(t)?.broadcast_as(att.shape())?;
+        let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
         let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
         let att = att.softmax(D::Minus1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = att.matmul(&v.contiguous()?)?;
-        let y = y.transpose(0, 1)?.reshape(&[t, c])?;
+        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
         let y = y.to_dtype(x_dtype)?;
         let y = self.c_proj.forward(&y)?;
         Ok(y)
@@ -336,23 +321,19 @@ impl Llama {
     }
 
     pub fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
-        // TODO: Support for mini-batches? (i.e. r2)
-        let t = x.shape().r1()?;
+        let (_b_sz, seq_len) = x.shape().r2()?;
         let mut x = self.wte.forward(x)?;
         for (block_idx, block) in self.blocks.iter().enumerate() {
             x = block.forward(&x, freqs_cis, block_idx)?;
         }
         let x = self.ln_f.forward(&x)?;
-        let x = x.narrow(0, t - 1, 1)?;
+        let x = x.i((.., seq_len - 1, ..))?;
         let logits = self.lm_head.forward(&x)?;
-        let logits = logits.to_dtype(DType::F32)?;
-        let (b, vocab_size) = logits.shape().r2()?;
-        assert_eq!(b, 1);
-        logits.reshape(vocab_size)
+        logits.to_dtype(DType::F32)
     }
 
     pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
-        let wte = Embedding::load(cfg, vb.pp("model.embed_tokens"))?;
+        let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
         let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
         let norm = RmsNorm::load(cfg.hidden_size, vb.pp("model.norm"))?;
         let blocks: Vec<_> = (0..cfg.n_layer)
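The new last-position lookup relies on the IndexOp import added at the top of this file: x.i((.., seq_len - 1, ..)) keeps the batch and hidden axes while selecting (and squeezing) the final sequence position, which is what lets lm_head produce per-batch logits without the old reshape and assert. A hedged usage sketch (shapes are illustrative):

use candle::{DType, Device, IndexOp, Tensor};

fn main() -> candle::Result<()> {
    let (b_sz, seq_len, hidden) = (1usize, 7usize, 16usize);
    let x = Tensor::ones((b_sz, seq_len, hidden), DType::F32, &Device::Cpu)?;
    // `..` keeps an axis as-is; a bare index selects and squeezes it away,
    // matching the Select arm change in the indexer hunk above.
    let last = x.i((.., seq_len - 1, ..))?;
    assert_eq!(last.dims(), &[b_sz, hidden]);
    Ok(())
}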