Separate quantized phi-3 implementation. (#2157)

* Separate quantized phi-3 implementation.

* Integrate the quantized phi3 model.=

* Small fixes, get the generation to work properly.

* Keep the old llama implementation around.

* Change the default.
This commit is contained in:
Laurent Mazare
2024-05-04 10:14:57 +02:00
committed by GitHub
parent 59b18d974e
commit b13a82a438
7 changed files with 323 additions and 12 deletions

View File

@ -676,9 +676,6 @@ impl BackendStorage for MetalStorage {
}
}
if layout.is_contiguous() {
} else {
}
Ok(Self::new(buffer, device.clone(), el_count, dtype))
}

View File

@ -178,7 +178,7 @@ impl crate::CustomOp1 for ArgSort {
device.metal_device(),
&command_buffer,
kernels,
&name,
name,
nrows,
ncols,
ncols_pad,

View File

@ -13,8 +13,9 @@ use candle::Tensor;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_transformers::models::quantized_llama::ModelWeights as Phi3;
use candle_transformers::models::quantized_llama::ModelWeights as Phi3b;
use candle_transformers::models::quantized_phi::ModelWeights as Phi2;
use candle_transformers::models::quantized_phi3::ModelWeights as Phi3;
const DEFAULT_PROMPT: &str = "Write a function to count prime numbers up to N. ";
@ -24,6 +25,9 @@ enum Which {
Phi2,
#[value(name = "phi-3")]
Phi3,
/// Alternative implementation of phi-3, based on llama.
#[value(name = "phi-3b")]
Phi3b,
}
#[derive(Parser, Debug)]
@ -84,7 +88,7 @@ struct Args {
repeat_last_n: usize,
/// The model size to use.
#[arg(long, default_value = "phi-2")]
#[arg(long, default_value = "phi-3b")]
which: Which,
}
@ -96,7 +100,7 @@ impl Args {
let api = hf_hub::api::sync::Api::new()?;
let repo = match self.which {
Which::Phi2 => "microsoft/phi-2",
Which::Phi3 => "microsoft/Phi-3-mini-4k-instruct",
Which::Phi3 | Which::Phi3b => "microsoft/Phi-3-mini-4k-instruct",
};
let api = api.model(repo.to_string());
api.get("tokenizer.json")?
@ -112,6 +116,11 @@ impl Args {
let (repo, filename, revision) = match self.which {
Which::Phi2 => ("TheBloke/phi-2-GGUF", "phi-2.Q4_K_M.gguf", "main"),
Which::Phi3 => (
"microsoft/Phi-3-mini-4k-instruct-gguf",
"Phi-3-mini-4k-instruct-q4.gguf",
"main",
),
Which::Phi3b => (
"microsoft/Phi-3-mini-4k-instruct-gguf",
"Phi-3-mini-4k-instruct-q4.gguf",
"5eef2ce24766d31909c0b269fe90c817a8f263fb",
@ -145,6 +154,7 @@ fn format_size(size_in_bytes: usize) -> String {
enum Model {
Phi2(Phi2),
Phi3(Phi3),
Phi3b(Phi3b),
}
impl Model {
@ -152,6 +162,7 @@ impl Model {
match self {
Self::Phi2(m) => m.forward(xs, pos),
Self::Phi3(m) => m.forward(xs, pos),
Self::Phi3b(m) => m.forward(xs, pos),
}
}
}
@ -203,6 +214,7 @@ fn main() -> anyhow::Result<()> {
match args.which {
Which::Phi2 => Model::Phi2(Phi2::from_gguf(model, &mut file, &device)?),
Which::Phi3 => Model::Phi3(Phi3::from_gguf(model, &mut file, &device)?),
Which::Phi3b => Model::Phi3b(Phi3b::from_gguf(model, &mut file, &device)?),
}
};
println!("model built");

View File

@ -350,7 +350,7 @@ pub fn call_unary_contiguous_tiled(
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder();
let tile_size = 2;
let tiles = length.div_ceil(tile_size);
let tiles = (length + tile_size - 1) / tile_size;
encoder.set_compute_pipeline_state(&pipeline);

View File

@ -40,6 +40,7 @@ pub mod quantized_mixformer;
pub mod quantized_moondream;
pub mod quantized_mpt;
pub mod quantized_phi;
pub mod quantized_phi3;
pub mod quantized_recurrent_gemma;
pub mod quantized_rwkv_v5;
pub mod quantized_rwkv_v6;

View File

@ -24,19 +24,19 @@ pub struct Config {
}
impl Config {
fn head_dim(&self) -> usize {
pub fn head_dim(&self) -> usize {
self.hidden_size / self.num_attention_heads
}
}
#[derive(Debug, Clone)]
struct RotaryEmbedding {
pub struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
}
impl RotaryEmbedding {
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
pub fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
let dim = cfg.head_dim();
let max_seq_len = cfg.max_position_embeddings;
let inv_freq: Vec<_> = (0..dim)
@ -55,7 +55,7 @@ impl RotaryEmbedding {
})
}
fn apply_rotary_emb_qkv(
pub fn apply_rotary_emb_qkv(
&self,
q: &Tensor,
k: &Tensor,

View File

@ -0,0 +1,301 @@
use std::collections::HashMap;
use candle::quantized::gguf_file;
use candle::quantized::QTensor;
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{Embedding, RmsNorm};
pub const MAX_SEQ_LEN: usize = 4096;
#[derive(Debug, Clone)]
struct QLinear {
inner: candle::quantized::QMatMul,
span: tracing::Span,
}
impl QLinear {
fn new<R: std::io::Read + std::io::Seek>(
ct: &gguf_file::Content,
r: &mut R,
name: &str,
device: &Device,
) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
let w = ct.tensor(r, &format!("{name}.weight"), device)?;
let inner = candle::quantized::QMatMul::from_qtensor(w)?;
Ok(Self { inner, span })
}
}
impl Module for QLinear {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(xs)
}
}
#[derive(Debug, Clone)]
struct Mlp {
ffn_up: QLinear,
ffn_down: QLinear,
i_size: usize,
}
impl Module for Mlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let up_states = xs.apply(&self.ffn_up)?;
let gate = up_states.narrow(D::Minus1, 0, self.i_size)?;
let up_states = up_states.narrow(D::Minus1, self.i_size, self.i_size)?;
let up_states = (up_states * gate.silu()?)?;
up_states.apply(&self.ffn_down)
}
}
fn rms_norm(w: QTensor, eps: f64) -> Result<RmsNorm> {
let w = w.dequantize(&w.device())?;
let rms = RmsNorm::new(w, eps);
Ok(rms)
}
#[derive(Debug, Clone)]
struct LayerWeights {
attn_qkv: QLinear,
attn_output: QLinear,
attn_norm: RmsNorm,
ffn_norm: RmsNorm,
mlp: Mlp,
n_head: usize,
n_kv_head: usize,
head_dim: usize,
cos: Tensor,
sin: Tensor,
neg_inf: Tensor,
kv_cache: Option<(Tensor, Tensor)>,
span_attn: tracing::Span,
span_rot: tracing::Span,
}
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result<Tensor> {
let shape = mask.shape();
let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?;
Ok(m)
}
impl LayerWeights {
fn apply_rotary_emb(&self, xs: &Tensor, index_pos: usize) -> Result<Tensor> {
let _enter = self.span_rot.enter();
let (_b_sz, _h, seq_len, _n_embd) = xs.dims4()?;
let cos = self.cos.narrow(0, index_pos, seq_len)?;
let sin = self.sin.narrow(0, index_pos, seq_len)?;
candle_nn::rotary_emb::rope(&xs.contiguous()?, &cos, &sin)
}
fn forward_attn(
&mut self,
x: &Tensor,
mask: Option<&Tensor>,
index_pos: usize,
) -> Result<Tensor> {
let _enter = self.span_attn.enter();
let (b_sz, seq_len, n_embd) = x.dims3()?;
let qkv = self.attn_qkv.forward(x)?;
let query_pos = self.n_head * self.head_dim;
let q = qkv.narrow(D::Minus1, 0, query_pos)?;
let k = qkv.narrow(D::Minus1, query_pos, self.n_kv_head * self.head_dim)?;
let v = qkv.narrow(
D::Minus1,
query_pos + self.n_kv_head * self.head_dim,
self.n_kv_head * self.head_dim,
)?;
let q = q
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
.transpose(1, 2)?;
let k = k
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
.transpose(1, 2)?;
let v = v
.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
.transpose(1, 2)?;
let q = self.apply_rotary_emb(&q, index_pos)?.contiguous()?;
let k = self.apply_rotary_emb(&k, index_pos)?;
let (k, v) = match &self.kv_cache {
None => (k.contiguous()?, v.contiguous()?),
Some((k_cache, v_cache)) => {
if index_pos == 0 {
(k.contiguous()?, v.contiguous()?)
} else {
let k = Tensor::cat(&[k_cache, &k], 2)?;
let v = Tensor::cat(&[v_cache, &v], 2)?;
(k.contiguous()?, v.contiguous()?)
}
}
};
self.kv_cache = Some((k.clone(), v.clone()));
let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let att = match mask {
None => att,
Some(mask) => {
let mask = mask.broadcast_as(att.shape())?;
masked_fill(&att, &mask, &self.neg_inf)?
}
};
let att = candle_nn::ops::softmax_last_dim(&att)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
let y = att.matmul(&v.contiguous()?)?;
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
let y = self.attn_output.forward(&y)?;
Ok(y)
}
}
#[derive(Debug, Clone)]
pub struct ModelWeights {
tok_embeddings: Embedding,
layers: Vec<LayerWeights>,
output_norm: RmsNorm,
output: QLinear,
masks: HashMap<usize, Tensor>,
span: tracing::Span,
span_output: tracing::Span,
}
fn precomput_freqs_cis(
head_dim: usize,
freq_base: f32,
device: &Device,
) -> Result<(Tensor, Tensor)> {
let theta: Vec<_> = (0..head_dim)
.step_by(2)
.map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
.collect();
let theta = Tensor::new(theta.as_slice(), device)?;
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
.to_dtype(DType::F32)?
.reshape((MAX_SEQ_LEN, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
let cos = idx_theta.cos()?;
let sin = idx_theta.sin()?;
Ok((cos, sin))
}
impl ModelWeights {
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
ct: gguf_file::Content,
reader: &mut R,
device: &Device,
) -> Result<Self> {
let md_get = |s: &str| match ct.metadata.get(s) {
None => candle::bail!("cannot find {s} in metadata"),
Some(v) => Ok(v),
};
// Parameter extraction from metadata.
let head_count = md_get("phi3.attention.head_count")?.to_u32()? as usize;
let head_count_kv = md_get("phi3.attention.head_count_kv")?.to_u32()? as usize;
let block_count = md_get("phi3.block_count")?.to_u32()? as usize;
let embedding_length = md_get("phi3.embedding_length")?.to_u32()? as usize;
let i_size = md_get("phi3.feed_forward_length")?.to_u32()? as usize;
let rope_dim = md_get("phi3.rope.dimension_count")?.to_u32()? as usize;
let rms_eps = md_get("phi3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device)?;
let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
let tok_embeddings = tok_embeddings.dequantize(device)?;
let output_norm = rms_norm(ct.tensor(reader, "output_norm.weight", device)?, rms_eps)?;
let output = QLinear::new(&ct, reader, "output", device)?;
let mut layers = Vec::with_capacity(block_count);
for layer_idx in 0..block_count {
let prefix = format!("blk.{layer_idx}");
let ffn_up = QLinear::new(&ct, reader, &format!("{prefix}.ffn_up"), device)?;
let ffn_down = QLinear::new(&ct, reader, &format!("{prefix}.ffn_down"), device)?;
let mlp = Mlp {
ffn_up,
ffn_down,
i_size,
};
let attn_norm = rms_norm(
ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?,
rms_eps,
)?;
let ffn_norm = rms_norm(
ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?,
rms_eps,
)?;
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
layers.push(LayerWeights {
attn_qkv: QLinear::new(&ct, reader, &format!("{prefix}.attn_qkv"), device)?,
attn_output: QLinear::new(&ct, reader, &format!("{prefix}.attn_output"), device)?,
attn_norm,
ffn_norm,
mlp,
n_head: head_count,
n_kv_head: head_count_kv,
head_dim: embedding_length / head_count,
cos: cos.clone(),
sin: sin.clone(),
neg_inf: neg_inf.clone(),
kv_cache: None,
span_attn,
span_rot,
})
}
let span = tracing::span!(tracing::Level::TRACE, "model");
let span_output = tracing::span!(tracing::Level::TRACE, "output");
Ok(Self {
tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
layers,
output_norm,
output,
masks: HashMap::new(),
span,
span_output,
})
}
fn mask(&mut self, t: usize, device: &Device) -> Result<Tensor> {
if let Some(mask) = self.masks.get(&t) {
Ok(mask.clone())
} else {
let mask: Vec<_> = (0..t)
.flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
.collect();
let mask = Tensor::from_slice(&mask, (t, t), device)?;
self.masks.insert(t, mask.clone());
Ok(mask)
}
}
pub fn forward(&mut self, xs: &Tensor, index_pos: usize) -> Result<Tensor> {
let (_b_sz, seq_len) = xs.dims2()?;
let mask = if seq_len == 1 {
None
} else {
Some(self.mask(seq_len, xs.device())?)
};
let _enter = self.span.enter();
let mut xs = self.tok_embeddings.forward(xs)?;
for layer in self.layers.iter_mut() {
let residual = &xs;
let ys = xs.apply(&layer.attn_norm)?;
let ys = layer.forward_attn(&ys, mask.as_ref(), index_pos)?;
let ys = (ys + residual)?;
let residual = &ys;
let ys = ys.apply(&layer.ffn_norm)?;
let ys = layer.mlp.forward(&ys)?;
xs = (ys + residual)?
}
let xs = xs.apply(&self.output_norm)?.i((.., seq_len - 1, ..))?;
let _enter = self.span_output.enter();
self.output.forward(&xs)
}
}