Complete the mixformer implementation. (#930)

* Complete the mixformers implementation.

* Tweak the attention.

* Add the phi-1.5 example.

* Improve the phi example.

* Bugfix.

* Get the phi example to work.
This commit is contained in:
Laurent Mazare
2023-09-22 20:03:16 +01:00
committed by GitHub
parent a46b1b4657
commit df6f5240ba
3 changed files with 321 additions and 31 deletions

View File

@ -0,0 +1,23 @@
# candle-starcoder: code generation model
[phi-1.5](https://huggingface.co/microsoft/phi-1_5).
## Running some example
```bash
$ cargo run --example phi --release -- --prompt "def print_prime(n): "
def print_prime(n):
print("Printing prime numbers")
for i in range(2, n+1):
if is_prime(i):
print(i)
def is_prime(n):
if n <= 1:
return False
for i in range(2, int(math.sqrt(n))+1):
if n % i == 0:
return False
return True
```

View File

@ -0,0 +1,163 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as Model};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
struct TextGeneration {
model: Model,
device: Device,
tokenizer: Tokenizer,
logits_processor: LogitsProcessor,
}
impl TextGeneration {
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer,
logits_processor,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
println!("starting the inference loop");
print!("{prompt}");
std::io::stdout().flush()?;
let mut tokens = self
.tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let mut new_tokens = vec![];
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input)?;
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
new_tokens.push(next_token);
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
print!("{token}");
std::io::stdout().flush()?;
}
let dt = start_gen.elapsed();
println!(
"{sample_len} tokens generated ({:.3} token/s)",
sample_len as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, default_value_t = 100)]
sample_len: usize,
#[arg(long, default_value = "microsoft/phi-1_5")]
model_id: String,
#[arg(long, default_value = "refs/pr/18")]
revision: String,
#[arg(long)]
weight_file: Option<String>,
}
fn main() -> Result<()> {
let args = Args::parse();
let start = std::time::Instant::now();
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(
args.model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = repo.get("tokenizer.json")?;
let filenames = match args.weight_file {
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
None => ["model.safetensors"]
.iter()
.map(|f| repo.get(f))
.collect::<std::result::Result<Vec<_>, _>>()?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let weights = filenames
.iter()
.map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f)? }))
.collect::<Result<Vec<_>>>()?;
let weights = weights
.iter()
.map(|f| Ok(f.deserialize()?))
.collect::<Result<Vec<_>>>()?;
let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?;
let vb = VarBuilder::from_safetensors(weights, DType::F32, &device);
let config = Config::v1_5();
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -1,10 +1,11 @@
#![allow(unused)]
/// MixFormer model. /// MixFormer model.
/// https://huggingface.co/microsoft/phi-1_5 /// https://huggingface.co/microsoft/phi-1_5
/// https://arxiv.org/abs/2309.05463 /// https://arxiv.org/abs/2309.05463
use candle::{DType, Device, Module, Result, Tensor, D}; use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder}; use candle_nn::{Activation, VarBuilder};
const MAX_SEQ_LEN: usize = 4096;
// https://huggingface.co/microsoft/phi-1_5/blob/main/configuration_mixformer_sequential.py // https://huggingface.co/microsoft/phi-1_5/blob/main/configuration_mixformer_sequential.py
#[derive(Debug, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub struct Config { pub struct Config {
@ -21,8 +22,8 @@ pub struct Config {
pad_vocab_size_multiple: usize, pad_vocab_size_multiple: usize,
} }
impl Default for Config { impl Config {
fn default() -> Self { pub fn v1() -> Self {
Self { Self {
vocab_size: 50304, vocab_size: 50304,
n_positions: 2048, n_positions: 2048,
@ -37,6 +38,22 @@ impl Default for Config {
pad_vocab_size_multiple: 64, pad_vocab_size_multiple: 64,
} }
} }
pub fn v1_5() -> Self {
Self {
vocab_size: 51200,
n_positions: 2048,
n_embd: 2048,
n_layer: 24,
n_inner: None,
n_head: 32,
rotary_dim: usize::min(32, 2048 / 32),
activation_function: Activation::Gelu,
layer_norm_epsilon: 1e-5,
tie_word_embeddings: false,
pad_vocab_size_multiple: 64,
}
}
} }
#[derive(Debug)] #[derive(Debug)]
@ -58,7 +75,70 @@ impl Module for Embedding {
} }
#[derive(Debug)] #[derive(Debug)]
struct RotaryEmbedding {} struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
}
impl RotaryEmbedding {
fn new(dim: usize, max_seq_len: usize, dev: &Device) -> Result<Self> {
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(DType::F32)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
})
}
fn apply_rotary_emb_qkv(
&self,
qkv: &Tensor,
seqlen_offset: usize,
) -> Result<(Tensor, Tensor, Tensor)> {
let (_b_size, seqlen, three, _, _headdim) = qkv.dims5()?;
if three != 3 {
candle::bail!("unexpected shape for qkv {:?}", qkv.shape())
}
let (_rotary_seqlen, rotary_dim) = self.cos.dims2()?;
let rotary_dim = rotary_dim * 2;
let q_rot = qkv.i((.., .., 0, .., ..rotary_dim))?;
let q_pass = qkv.i((.., .., 0, .., rotary_dim..))?;
let k_rot = qkv.i((.., .., 1, .., ..rotary_dim))?;
let k_pass = qkv.i((.., .., 1, .., rotary_dim..))?;
let q12 = q_rot.chunk(2, D::Minus1)?;
let k12 = k_rot.chunk(2, D::Minus1)?;
let (q1, q2) = (&q12[0], &q12[1]);
let (k1, k2) = (&k12[0], &k12[1]);
let c = self.cos.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?;
let s = self.sin.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?;
let q_rot = Tensor::cat(
&[
(q1.broadcast_mul(&c)? - q2.broadcast_mul(&s)?)?,
(q1.broadcast_mul(&s)? + q2.broadcast_mul(&c)?)?,
],
D::Minus1,
)?;
let k_rot = Tensor::cat(
&[
(k1.broadcast_mul(&c)? - k2.broadcast_mul(&s)?)?,
(k1.broadcast_mul(&s)? + k2.broadcast_mul(&c)?)?,
],
D::Minus1,
)?;
let q = Tensor::cat(&[&q_rot, &q_pass], D::Minus1)?;
let k = Tensor::cat(&[&k_rot, &k_pass], D::Minus1)?;
let v = qkv.i((.., .., 2))?;
Ok((q, k, v))
}
}
#[derive(Debug)] #[derive(Debug)]
#[allow(clippy::upper_case_acronyms)] #[allow(clippy::upper_case_acronyms)]
@ -87,18 +167,6 @@ impl Module for MLP {
} }
} }
#[derive(Debug)]
struct SelfAttention {
causal: bool,
softmax_scale: f64,
}
#[derive(Debug)]
struct CrossAttention {
causal: bool,
softmax_scale: f64,
}
#[derive(Debug)] #[derive(Debug)]
struct CausalLMHead { struct CausalLMHead {
ln: candle_nn::LayerNorm, ln: candle_nn::LayerNorm,
@ -126,7 +194,10 @@ impl Module for CausalLMHead {
struct MHA { struct MHA {
wqkv: candle_nn::Linear, wqkv: candle_nn::Linear,
out_proj: candle_nn::Linear, out_proj: candle_nn::Linear,
rotary_emb: RotaryEmbedding,
kv_cache: Option<(Tensor, Tensor)>,
head_dim: usize, head_dim: usize,
softmax_scale: f64,
} }
impl MHA { impl MHA {
@ -135,23 +206,59 @@ impl MHA {
let op_size = cfg.n_embd; let op_size = cfg.n_embd;
let wqkv = candle_nn::linear(cfg.n_embd, 3 * op_size, vb.pp("Wqkv"))?; let wqkv = candle_nn::linear(cfg.n_embd, 3 * op_size, vb.pp("Wqkv"))?;
let out_proj = candle_nn::linear(op_size, cfg.n_embd, vb.pp("out_proj"))?; let out_proj = candle_nn::linear(op_size, cfg.n_embd, vb.pp("out_proj"))?;
let rotary_emb = RotaryEmbedding::new(cfg.rotary_dim, MAX_SEQ_LEN, vb.device())?;
let softmax_scale = 1f64 / (head_dim as f64).sqrt();
Ok(Self { Ok(Self {
wqkv, wqkv,
out_proj, out_proj,
head_dim, head_dim,
kv_cache: None,
rotary_emb,
softmax_scale,
}) })
} }
}
impl Module for MHA { fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
fn forward(&self, xs: &Tensor) -> Result<Tensor> { let (b_size, seq_len, _n_embd) = xs.dims3()?;
let (b_size, seq_len, n_embd) = xs.dims3()?;
let qkv = self let qkv = self
.wqkv .wqkv
.forward(xs)? .forward(xs)?
.reshape((b_size, seq_len, 3, (), self.head_dim))?; .reshape((b_size, seq_len, 3, (), self.head_dim))?;
let context: Tensor = qkv; // TODO let seqlen_offset = match &self.kv_cache {
context.flatten_from(D::Minus2)?.apply(&self.out_proj) None => 0,
Some((prev_k, _)) => prev_k.dim(1)?,
};
// In the python implementation, a single tensor is returned with the third axis of size 3.
let (q, k, v) = self.rotary_emb.apply_rotary_emb_qkv(&qkv, seqlen_offset)?;
let (k, v) = match &self.kv_cache {
None => (k, v),
Some((prev_k, prev_v)) => {
let k = Tensor::cat(&[prev_k, &k], 1)?;
let v = Tensor::cat(&[prev_v, &v], 1)?;
(k, v)
}
};
self.kv_cache = Some((k.clone(), v.clone()));
// scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
let q = q.transpose(1, 2)?.flatten_to(1)?; // b*h, t, d
let k = k.transpose(1, 2)?.flatten_to(1)?; // b*h, s, d
let v = v.transpose(1, 2)?.flatten_to(1)?; // b*h, s, d
let attn_weights = (q.matmul(&k.t()?)? * self.softmax_scale)?; // b*h, t, s
// TODO: Add the causal mask.
// causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0, device=scores.device), 1)
// scores = scores + causal_mask.to(dtype=scores.dtype)
let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
// output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
// attn_weights: b*h,t,s, v: b*h,s,d
let attn_output = attn_weights.matmul(&v)?;
// b*h,t,d
let attn_output = attn_output
.reshape((b_size, (), seq_len, self.head_dim))?
.transpose(1, 2)?
.flatten_from(D::Minus2)?;
attn_output.apply(&self.out_proj)
} }
} }
@ -169,10 +276,8 @@ impl ParallelBlock {
let mlp = MLP::new(cfg, vb.pp("mlp"))?; let mlp = MLP::new(cfg, vb.pp("mlp"))?;
Ok(Self { ln, mixer, mlp }) Ok(Self { ln, mixer, mlp })
} }
}
impl Module for ParallelBlock { fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let residual = xs; let residual = xs;
let xs = xs.apply(&self.ln)?; let xs = xs.apply(&self.ln)?;
let attn_outputs = self.mixer.forward(&xs)?; let attn_outputs = self.mixer.forward(&xs)?;
@ -204,14 +309,13 @@ impl MixFormerSequentialForCausalLM {
head, head,
}) })
} }
}
impl Module for MixFormerSequentialForCausalLM { pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
fn forward(&self, xs: &Tensor) -> Result<Tensor> { let (_b_size, seq_len) = xs.dims2()?;
let mut xs = xs.apply(&self.embedding)?; let mut xs = xs.apply(&self.embedding)?;
for block in self.blocks.iter() { for block in self.blocks.iter_mut() {
xs = block.forward(&xs)? xs = block.forward(&xs)?
} }
xs.apply(&self.head) xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1)
} }
} }