mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Complete the mixformer implementation. (#930)
* Complete the mixformers implementation. * Tweak the attention. * Add the phi-1.5 example. * Improve the phi example. * Bugfix. * Get the phi example to work.
This commit is contained in:
23
candle-examples/examples/phi/README.md
Normal file
23
candle-examples/examples/phi/README.md
Normal file
@ -0,0 +1,23 @@
|
||||
# candle-starcoder: code generation model
|
||||
|
||||
[phi-1.5](https://huggingface.co/microsoft/phi-1_5).
|
||||
|
||||
## Running some example
|
||||
|
||||
```bash
|
||||
$ cargo run --example phi --release -- --prompt "def print_prime(n): "
|
||||
|
||||
def print_prime(n):
|
||||
print("Printing prime numbers")
|
||||
for i in range(2, n+1):
|
||||
if is_prime(i):
|
||||
print(i)
|
||||
|
||||
def is_prime(n):
|
||||
if n <= 1:
|
||||
return False
|
||||
for i in range(2, int(math.sqrt(n))+1):
|
||||
if n % i == 0:
|
||||
return False
|
||||
return True
|
||||
```
|
163
candle-examples/examples/phi/main.rs
Normal file
163
candle-examples/examples/phi/main.rs
Normal file
@ -0,0 +1,163 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::Parser;
|
||||
|
||||
use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as Model};
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
struct TextGeneration {
|
||||
model: Model,
|
||||
device: Device,
|
||||
tokenizer: Tokenizer,
|
||||
logits_processor: LogitsProcessor,
|
||||
}
|
||||
|
||||
impl TextGeneration {
|
||||
fn new(
|
||||
model: Model,
|
||||
tokenizer: Tokenizer,
|
||||
seed: u64,
|
||||
temp: Option<f64>,
|
||||
top_p: Option<f64>,
|
||||
device: &Device,
|
||||
) -> Self {
|
||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
Self {
|
||||
model,
|
||||
tokenizer,
|
||||
logits_processor,
|
||||
device: device.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||
use std::io::Write;
|
||||
println!("starting the inference loop");
|
||||
print!("{prompt}");
|
||||
std::io::stdout().flush()?;
|
||||
let mut tokens = self
|
||||
.tokenizer
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
let mut new_tokens = vec![];
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0..sample_len {
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||
let logits = self.model.forward(&input)?;
|
||||
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
new_tokens.push(next_token);
|
||||
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
|
||||
print!("{token}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
println!(
|
||||
"{sample_len} tokens generated ({:.3} token/s)",
|
||||
sample_len as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
#[arg(long)]
|
||||
prompt: String,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(long, default_value_t = 100)]
|
||||
sample_len: usize,
|
||||
|
||||
#[arg(long, default_value = "microsoft/phi-1_5")]
|
||||
model_id: String,
|
||||
|
||||
#[arg(long, default_value = "refs/pr/18")]
|
||||
revision: String,
|
||||
|
||||
#[arg(long)]
|
||||
weight_file: Option<String>,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let api = Api::new()?;
|
||||
let repo = api.repo(Repo::with_revision(
|
||||
args.model_id,
|
||||
RepoType::Model,
|
||||
args.revision,
|
||||
));
|
||||
let tokenizer_filename = repo.get("tokenizer.json")?;
|
||||
let filenames = match args.weight_file {
|
||||
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
|
||||
None => ["model.safetensors"]
|
||||
.iter()
|
||||
.map(|f| repo.get(f))
|
||||
.collect::<std::result::Result<Vec<_>, _>>()?,
|
||||
};
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let weights = filenames
|
||||
.iter()
|
||||
.map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f)? }))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let weights = weights
|
||||
.iter()
|
||||
.map(|f| Ok(f.deserialize()?))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let vb = VarBuilder::from_safetensors(weights, DType::F32, &device);
|
||||
let config = Config::v1_5();
|
||||
let model = Model::new(&config, vb)?;
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let mut pipeline = TextGeneration::new(
|
||||
model,
|
||||
tokenizer,
|
||||
args.seed,
|
||||
args.temperature,
|
||||
args.top_p,
|
||||
&device,
|
||||
);
|
||||
pipeline.run(&args.prompt, args.sample_len)?;
|
||||
Ok(())
|
||||
}
|
@ -1,10 +1,11 @@
|
||||
#![allow(unused)]
|
||||
/// MixFormer model.
|
||||
/// https://huggingface.co/microsoft/phi-1_5
|
||||
/// https://arxiv.org/abs/2309.05463
|
||||
use candle::{DType, Device, Module, Result, Tensor, D};
|
||||
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
|
||||
use candle_nn::{Activation, VarBuilder};
|
||||
|
||||
const MAX_SEQ_LEN: usize = 4096;
|
||||
|
||||
// https://huggingface.co/microsoft/phi-1_5/blob/main/configuration_mixformer_sequential.py
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct Config {
|
||||
@ -21,8 +22,8 @@ pub struct Config {
|
||||
pad_vocab_size_multiple: usize,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
impl Config {
|
||||
pub fn v1() -> Self {
|
||||
Self {
|
||||
vocab_size: 50304,
|
||||
n_positions: 2048,
|
||||
@ -37,6 +38,22 @@ impl Default for Config {
|
||||
pad_vocab_size_multiple: 64,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn v1_5() -> Self {
|
||||
Self {
|
||||
vocab_size: 51200,
|
||||
n_positions: 2048,
|
||||
n_embd: 2048,
|
||||
n_layer: 24,
|
||||
n_inner: None,
|
||||
n_head: 32,
|
||||
rotary_dim: usize::min(32, 2048 / 32),
|
||||
activation_function: Activation::Gelu,
|
||||
layer_norm_epsilon: 1e-5,
|
||||
tie_word_embeddings: false,
|
||||
pad_vocab_size_multiple: 64,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@ -58,7 +75,70 @@ impl Module for Embedding {
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct RotaryEmbedding {}
|
||||
struct RotaryEmbedding {
|
||||
sin: Tensor,
|
||||
cos: Tensor,
|
||||
}
|
||||
|
||||
impl RotaryEmbedding {
|
||||
fn new(dim: usize, max_seq_len: usize, dev: &Device) -> Result<Self> {
|
||||
let inv_freq: Vec<_> = (0..dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
|
||||
.collect();
|
||||
let inv_freq_len = inv_freq.len();
|
||||
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
|
||||
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
||||
.to_dtype(DType::F32)?
|
||||
.reshape((max_seq_len, 1))?;
|
||||
let freqs = t.matmul(&inv_freq)?;
|
||||
Ok(Self {
|
||||
sin: freqs.sin()?,
|
||||
cos: freqs.cos()?,
|
||||
})
|
||||
}
|
||||
|
||||
fn apply_rotary_emb_qkv(
|
||||
&self,
|
||||
qkv: &Tensor,
|
||||
seqlen_offset: usize,
|
||||
) -> Result<(Tensor, Tensor, Tensor)> {
|
||||
let (_b_size, seqlen, three, _, _headdim) = qkv.dims5()?;
|
||||
if three != 3 {
|
||||
candle::bail!("unexpected shape for qkv {:?}", qkv.shape())
|
||||
}
|
||||
let (_rotary_seqlen, rotary_dim) = self.cos.dims2()?;
|
||||
let rotary_dim = rotary_dim * 2;
|
||||
let q_rot = qkv.i((.., .., 0, .., ..rotary_dim))?;
|
||||
let q_pass = qkv.i((.., .., 0, .., rotary_dim..))?;
|
||||
let k_rot = qkv.i((.., .., 1, .., ..rotary_dim))?;
|
||||
let k_pass = qkv.i((.., .., 1, .., rotary_dim..))?;
|
||||
let q12 = q_rot.chunk(2, D::Minus1)?;
|
||||
let k12 = k_rot.chunk(2, D::Minus1)?;
|
||||
let (q1, q2) = (&q12[0], &q12[1]);
|
||||
let (k1, k2) = (&k12[0], &k12[1]);
|
||||
let c = self.cos.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?;
|
||||
let s = self.sin.narrow(0, seqlen_offset, seqlen)?.unsqueeze(1)?;
|
||||
let q_rot = Tensor::cat(
|
||||
&[
|
||||
(q1.broadcast_mul(&c)? - q2.broadcast_mul(&s)?)?,
|
||||
(q1.broadcast_mul(&s)? + q2.broadcast_mul(&c)?)?,
|
||||
],
|
||||
D::Minus1,
|
||||
)?;
|
||||
let k_rot = Tensor::cat(
|
||||
&[
|
||||
(k1.broadcast_mul(&c)? - k2.broadcast_mul(&s)?)?,
|
||||
(k1.broadcast_mul(&s)? + k2.broadcast_mul(&c)?)?,
|
||||
],
|
||||
D::Minus1,
|
||||
)?;
|
||||
let q = Tensor::cat(&[&q_rot, &q_pass], D::Minus1)?;
|
||||
let k = Tensor::cat(&[&k_rot, &k_pass], D::Minus1)?;
|
||||
let v = qkv.i((.., .., 2))?;
|
||||
Ok((q, k, v))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[allow(clippy::upper_case_acronyms)]
|
||||
@ -87,18 +167,6 @@ impl Module for MLP {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct SelfAttention {
|
||||
causal: bool,
|
||||
softmax_scale: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct CrossAttention {
|
||||
causal: bool,
|
||||
softmax_scale: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct CausalLMHead {
|
||||
ln: candle_nn::LayerNorm,
|
||||
@ -126,7 +194,10 @@ impl Module for CausalLMHead {
|
||||
struct MHA {
|
||||
wqkv: candle_nn::Linear,
|
||||
out_proj: candle_nn::Linear,
|
||||
rotary_emb: RotaryEmbedding,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
head_dim: usize,
|
||||
softmax_scale: f64,
|
||||
}
|
||||
|
||||
impl MHA {
|
||||
@ -135,23 +206,59 @@ impl MHA {
|
||||
let op_size = cfg.n_embd;
|
||||
let wqkv = candle_nn::linear(cfg.n_embd, 3 * op_size, vb.pp("Wqkv"))?;
|
||||
let out_proj = candle_nn::linear(op_size, cfg.n_embd, vb.pp("out_proj"))?;
|
||||
let rotary_emb = RotaryEmbedding::new(cfg.rotary_dim, MAX_SEQ_LEN, vb.device())?;
|
||||
let softmax_scale = 1f64 / (head_dim as f64).sqrt();
|
||||
Ok(Self {
|
||||
wqkv,
|
||||
out_proj,
|
||||
head_dim,
|
||||
kv_cache: None,
|
||||
rotary_emb,
|
||||
softmax_scale,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for MHA {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (b_size, seq_len, n_embd) = xs.dims3()?;
|
||||
fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (b_size, seq_len, _n_embd) = xs.dims3()?;
|
||||
let qkv = self
|
||||
.wqkv
|
||||
.forward(xs)?
|
||||
.reshape((b_size, seq_len, 3, (), self.head_dim))?;
|
||||
let context: Tensor = qkv; // TODO
|
||||
context.flatten_from(D::Minus2)?.apply(&self.out_proj)
|
||||
let seqlen_offset = match &self.kv_cache {
|
||||
None => 0,
|
||||
Some((prev_k, _)) => prev_k.dim(1)?,
|
||||
};
|
||||
// In the python implementation, a single tensor is returned with the third axis of size 3.
|
||||
let (q, k, v) = self.rotary_emb.apply_rotary_emb_qkv(&qkv, seqlen_offset)?;
|
||||
let (k, v) = match &self.kv_cache {
|
||||
None => (k, v),
|
||||
Some((prev_k, prev_v)) => {
|
||||
let k = Tensor::cat(&[prev_k, &k], 1)?;
|
||||
let v = Tensor::cat(&[prev_v, &v], 1)?;
|
||||
(k, v)
|
||||
}
|
||||
};
|
||||
self.kv_cache = Some((k.clone(), v.clone()));
|
||||
// scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
|
||||
let q = q.transpose(1, 2)?.flatten_to(1)?; // b*h, t, d
|
||||
let k = k.transpose(1, 2)?.flatten_to(1)?; // b*h, s, d
|
||||
let v = v.transpose(1, 2)?.flatten_to(1)?; // b*h, s, d
|
||||
let attn_weights = (q.matmul(&k.t()?)? * self.softmax_scale)?; // b*h, t, s
|
||||
|
||||
// TODO: Add the causal mask.
|
||||
// causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0, device=scores.device), 1)
|
||||
// scores = scores + causal_mask.to(dtype=scores.dtype)
|
||||
let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
|
||||
|
||||
// output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
|
||||
// attn_weights: b*h,t,s, v: b*h,s,d
|
||||
let attn_output = attn_weights.matmul(&v)?;
|
||||
// b*h,t,d
|
||||
let attn_output = attn_output
|
||||
.reshape((b_size, (), seq_len, self.head_dim))?
|
||||
.transpose(1, 2)?
|
||||
.flatten_from(D::Minus2)?;
|
||||
attn_output.apply(&self.out_proj)
|
||||
}
|
||||
}
|
||||
|
||||
@ -169,10 +276,8 @@ impl ParallelBlock {
|
||||
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
|
||||
Ok(Self { ln, mixer, mlp })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for ParallelBlock {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
|
||||
let residual = xs;
|
||||
let xs = xs.apply(&self.ln)?;
|
||||
let attn_outputs = self.mixer.forward(&xs)?;
|
||||
@ -204,14 +309,13 @@ impl MixFormerSequentialForCausalLM {
|
||||
head,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for MixFormerSequentialForCausalLM {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (_b_size, seq_len) = xs.dims2()?;
|
||||
let mut xs = xs.apply(&self.embedding)?;
|
||||
for block in self.blocks.iter() {
|
||||
for block in self.blocks.iter_mut() {
|
||||
xs = block.forward(&xs)?
|
||||
}
|
||||
xs.apply(&self.head)
|
||||
xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user