Add support for Llama 3.1 (#2359)

* Add Llama 3.1 rope

* Clippy

* Format

* Clippy

* Add support for multiple eos tokens

* Untagged either

* Remove either dep and fix settings.json

* Make the max positional embeddings configurable
Eric Buehler
2024-07-26 15:32:26 -04:00
committed by GitHub
parent ddafc61055
commit 0f5cbb08b3
24 changed files with 165 additions and 71 deletions
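
The three features in the commit message all surface as new fields on LlamaConfig / Config: an optional rope_scaling block, an eos_token_id that may be a single id or a list, and a configurable max_position_embeddings. Below is a minimal, self-contained sketch of how a Llama 3.1 style config.json fragment maps onto these fields; it re-declares trimmed-down copies of the types added in this commit (the real rope_type is the Llama3RopeType enum, simplified to a String here), assumes only serde and serde_json as dependencies, and uses illustrative numeric values.

use serde::Deserialize;

// Trimmed-down stand-ins for the types introduced in this commit (illustration only).
#[derive(Debug, Deserialize)]
struct Llama3RopeConfig {
    factor: f32,
    low_freq_factor: f32,
    high_freq_factor: f32,
    original_max_position_embeddings: usize,
    rope_type: String, // the real field is the Llama3RopeType enum
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum LlamaEosToks {
    Single(u32),
    Multiple(Vec<u32>),
}

#[derive(Debug, Deserialize)]
struct PartialConfig {
    eos_token_id: Option<LlamaEosToks>,
    rope_scaling: Option<Llama3RopeConfig>,
    max_position_embeddings: usize,
}

fn main() -> serde_json::Result<()> {
    // Shape of the relevant part of a Llama 3.1 style config.json (values are illustrative).
    let json = r#"{
        "eos_token_id": [128001, 128008, 128009],
        "rope_scaling": {
            "factor": 8.0,
            "low_freq_factor": 1.0,
            "high_freq_factor": 4.0,
            "original_max_position_embeddings": 8192,
            "rope_type": "llama3"
        },
        "max_position_embeddings": 131072
    }"#;
    let cfg: PartialConfig = serde_json::from_str(json)?;
    // The untagged enum accepts either a bare id or a list of ids.
    match &cfg.eos_token_id {
        Some(LlamaEosToks::Single(id)) => println!("eos id: {id}"),
        Some(LlamaEosToks::Multiple(ids)) => println!("eos ids: {ids:?}"),
        None => println!("no eos id in config"),
    }
    println!("rope scaling: {:?}", cfg.rope_scaling);
    println!("max positions: {}", cfg.max_position_embeddings);
    Ok(())
}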


@@ -288,7 +288,7 @@ impl BeitVisionTransformer {
let norm = layer_norm(embed_dim, 1e-6, vb.pp("norm"))?;
let vb_b = vb.pp("blocks");
let blocks = (0..depth)
.map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
.map(|i| Block::new(vb_b.pp(i.to_string()), embed_dim, num_heads))
.collect::<Result<Vec<_>>>()?;
Ok(Self {
patch_embed,
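
The vb_b.pp(&i.to_string()) to vb_b.pp(i.to_string()) edits repeated through the vision and diffusion files below are mechanical clippy fixes: when a builder method is generic over ToString, borrowing the freshly built String is a needless reference. A minimal sketch of the pattern, using a hypothetical prefix helper rather than candle's real VarBuilder::pp:

// `prefix` stands in for a ToString-generic builder method; illustration only.
fn prefix<S: ToString>(base: &str, s: S) -> String {
    format!("{base}.{}", s.to_string())
}

fn main() {
    let i = 3usize;
    // Before: prefix("blocks", &i.to_string()); // compiles, but borrows a temporary String
    // After, as in this commit:
    assert_eq!(prefix("blocks", i.to_string()), "blocks.3");
}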


@@ -249,7 +249,7 @@ impl ClipEncoder {
let vs = vs.pp("layers");
let mut layers: Vec<ClipEncoderLayer> = Vec::new();
for index in 0..c.num_hidden_layers() {
let layer = ClipEncoderLayer::new(vs.pp(&index.to_string()), c)?;
let layer = ClipEncoderLayer::new(vs.pp(index.to_string()), c)?;
layers.push(layer)
}
Ok(ClipEncoder { layers })


@@ -214,7 +214,7 @@ impl DinoVisionTransformer {
let norm = layer_norm(embed_dim, 1e-5, vb.pp("norm"))?;
let vb_b = vb.pp("blocks");
let blocks = (0..depth)
.map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
.map(|i| Block::new(vb_b.pp(i.to_string()), embed_dim, num_heads))
.collect::<Result<Vec<_>>>()?;
Ok(Self {
patch_embed,


@@ -212,7 +212,7 @@ impl DinoVisionTransformer {
let norm = layer_norm(embed_dim, 1e-6, vb.pp("norm"))?;
let vb_b = vb.pp("blocks");
let blocks = (0..depth)
.map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
.map(|i| Block::new(vb_b.pp(i.to_string()), embed_dim, num_heads))
.collect::<Result<Vec<_>>>()?;
Ok(Self {
patch_embed,


@@ -571,7 +571,7 @@ impl<'a> Layer<'a> {
}
fn next(&mut self) -> VarBuilder {
let vb = self.vb.pp(&self.cnt.to_string());
let vb = self.vb.pp(self.cnt.to_string());
self.cnt += 1;
vb
}


@@ -255,14 +255,7 @@ impl EVA2VisionTransformer {
let norm = layer_norm(embed_dim, 1e-6, vb.pp("norm"))?;
let vb_b = vb.pp("blocks");
let blocks = (0..depth)
.map(|i| {
Block::new(
vb_b.pp(&i.to_string()),
embed_dim,
num_heads,
&rot_pos_embed,
)
})
.map(|i| Block::new(vb_b.pp(i.to_string()), embed_dim, num_heads, &rot_pos_embed))
.collect::<Result<Vec<_>>>()?;
Ok(Self {
patch_embed,


@@ -1,9 +1,33 @@
use super::with_tracing::{linear_no_bias as linear, Linear, RmsNorm};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{embedding, Embedding, Module, VarBuilder};
use std::collections::HashMap;
use std::{collections::HashMap, f32::consts::PI};
pub const MAX_SEQ_LEN: usize = 4096;
pub const DEFAULT_MAX_SEQ_LEN: usize = 4096;
#[derive(Debug, Clone, serde::Deserialize, Default)]
pub enum Llama3RopeType {
#[serde(rename = "llama3")]
Llama3,
#[default]
#[serde(rename = "default")]
Default,
}
#[derive(Debug, Clone, serde::Deserialize, Default)]
pub struct Llama3RopeConfig {
pub factor: f32,
pub low_freq_factor: f32,
pub high_freq_factor: f32,
pub original_max_position_embeddings: usize,
pub rope_type: Llama3RopeType,
}
#[derive(Debug, Clone, serde::Deserialize)]
#[serde(untagged)]
pub enum LlamaEosToks {
Single(u32),
Multiple(Vec<u32>),
}
#[derive(Debug, Clone, serde::Deserialize)]
pub struct LlamaConfig {
@@ -17,7 +41,9 @@ pub struct LlamaConfig {
#[serde(default = "default_rope")]
pub rope_theta: f32,
pub bos_token_id: Option<u32>,
pub eos_token_id: Option<u32>,
pub eos_token_id: Option<LlamaEosToks>,
pub rope_scaling: Option<Llama3RopeConfig>,
pub max_position_embeddings: usize,
}
impl LlamaConfig {
@@ -44,6 +70,8 @@ impl LlamaConfig {
use_flash_attn,
bos_token_id: self.bos_token_id,
eos_token_id: self.eos_token_id,
rope_scaling: self.rope_scaling,
max_position_embeddings: self.max_position_embeddings,
}
}
}
@@ -60,7 +88,9 @@ pub struct Config {
pub rms_norm_eps: f64,
pub rope_theta: f32,
pub bos_token_id: Option<u32>,
pub eos_token_id: Option<u32>,
pub eos_token_id: Option<LlamaEosToks>,
pub rope_scaling: Option<Llama3RopeConfig>,
pub max_position_embeddings: usize,
}
impl Config {
@@ -77,6 +107,8 @@ impl Config {
rope_theta: 10_000.0,
bos_token_id: None,
eos_token_id: None,
rope_scaling: None,
max_position_embeddings: DEFAULT_MAX_SEQ_LEN,
}
}
@@ -93,6 +125,8 @@ impl Config {
rope_theta: 10_000.0,
bos_token_id: None,
eos_token_id: None,
rope_scaling: None,
max_position_embeddings: DEFAULT_MAX_SEQ_LEN,
}
}
}
@@ -107,18 +141,54 @@ pub struct Cache {
device: Device,
}
fn calculate_default_inv_freq(cfg: &Config) -> Vec<f32> {
let head_dim = cfg.hidden_size / cfg.num_attention_heads;
(0..head_dim)
.step_by(2)
.map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
.collect()
}
impl Cache {
pub fn new(use_kv_cache: bool, dtype: DType, config: &Config, device: &Device) -> Result<Self> {
// precompute freqs_cis
let n_elem = config.hidden_size / config.num_attention_heads;
let theta: Vec<_> = (0..n_elem)
.step_by(2)
.map(|i| 1f32 / config.rope_theta.powf(i as f32 / n_elem as f32))
.collect();
let theta = Tensor::new(theta.as_slice(), device)?;
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
let theta = match &config.rope_scaling {
None
| Some(Llama3RopeConfig {
rope_type: Llama3RopeType::Default,
..
}) => calculate_default_inv_freq(config),
Some(rope_scaling) => {
let low_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
/ rope_scaling.low_freq_factor;
let high_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
/ rope_scaling.high_freq_factor;
calculate_default_inv_freq(config)
.into_iter()
.map(|freq| {
let wavelen = 2. * PI / freq;
if wavelen < high_freq_wavelen {
freq
} else if wavelen > low_freq_wavelen {
freq / rope_scaling.factor
} else {
let smooth = (rope_scaling.original_max_position_embeddings as f32
/ wavelen
- rope_scaling.low_freq_factor)
/ (rope_scaling.high_freq_factor - rope_scaling.low_freq_factor);
(1. - smooth) * freq / rope_scaling.factor + smooth * freq
}
})
.collect::<Vec<_>>()
}
};
let theta = Tensor::new(theta, device)?;
let idx_theta = Tensor::arange(0, config.max_position_embeddings as u32, device)?
.to_dtype(DType::F32)?
.reshape((MAX_SEQ_LEN, 1))?
.reshape((config.max_position_embeddings, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
// This is different from the paper, see:
// https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
@@ -160,6 +230,7 @@ struct CausalSelfAttention {
use_flash_attn: bool,
span: tracing::Span,
span_rot: tracing::Span,
max_position_embeddings: usize,
}
#[cfg(feature = "flash-attn")]
@@ -220,15 +291,23 @@ impl CausalSelfAttention {
k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
let k_seq_len = k.dims()[1];
if k_seq_len > MAX_SEQ_LEN {
if k_seq_len > self.max_position_embeddings {
k = k
.narrow(D::Minus1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
.narrow(
D::Minus1,
k_seq_len - self.max_position_embeddings,
self.max_position_embeddings,
)?
.contiguous()?
}
let v_seq_len = v.dims()[1];
if v_seq_len > 2 * MAX_SEQ_LEN {
if v_seq_len > 2 * self.max_position_embeddings {
v = v
.narrow(D::Minus1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
.narrow(
D::Minus1,
v_seq_len - self.max_position_embeddings,
self.max_position_embeddings,
)?
.contiguous()?
}
}
@@ -291,6 +370,7 @@ impl CausalSelfAttention {
use_flash_attn: cfg.use_flash_attn,
span,
span_rot,
max_position_embeddings: cfg.max_position_embeddings,
})
}
}
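
The rope change above can be read as a stand-alone transformation of the default inverse frequencies: wavelengths shorter than original_max_position_embeddings / high_freq_factor are kept, wavelengths longer than original_max_position_embeddings / low_freq_factor are divided by factor, and everything in between is interpolated. A self-contained sketch follows; the numeric inputs (theta 500000, head dim 128, factor 8.0, low/high freq factors 1.0 and 4.0, original context 8192) mirror the commonly published Llama 3.1 settings and are used purely as an illustration.

use std::f32::consts::PI;

// Re-statement of the scaling applied in Cache::new, outside of candle, for clarity.
fn scale_inv_freq(
    inv_freq: &[f32],
    factor: f32,
    low_freq_factor: f32,
    high_freq_factor: f32,
    original_max_position_embeddings: usize,
) -> Vec<f32> {
    let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
    let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;
    inv_freq
        .iter()
        .map(|&freq| {
            let wavelen = 2. * PI / freq;
            if wavelen < high_freq_wavelen {
                // High-frequency (short-wavelength) components are left untouched.
                freq
            } else if wavelen > low_freq_wavelen {
                // Low-frequency components are slowed down by the full factor.
                freq / factor
            } else {
                // Components in between are interpolated smoothly.
                let smooth = (original_max_position_embeddings as f32 / wavelen - low_freq_factor)
                    / (high_freq_factor - low_freq_factor);
                (1. - smooth) * freq / factor + smooth * freq
            }
        })
        .collect()
}

fn main() {
    let head_dim = 128usize;
    let rope_theta = 500_000f32;
    let inv_freq: Vec<f32> = (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / rope_theta.powf(i as f32 / head_dim as f32))
        .collect();
    let scaled = scale_inv_freq(&inv_freq, 8.0, 1.0, 4.0, 8192);
    println!("first component (kept): {} -> {}", inv_freq[0], scaled[0]);
    println!("last component (divided by 8): {} -> {}", inv_freq[63], scaled[63]);
}

The same max_position_embeddings field replaces the old hard-coded MAX_SEQ_LEN both in the rotary table built in Cache::new and in the cache-narrowing logic of CausalSelfAttention::forward.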


@@ -2,7 +2,7 @@ use std::collections::HashMap;
use crate::models::{
clip::{text_model::Activation, vision_model::ClipVisionConfig},
llama::Config,
llama::{Config, LlamaEosToks},
};
use serde::{Deserialize, Serialize};
@@ -73,8 +73,10 @@ impl LLaVAConfig {
rms_norm_eps: self.rms_norm_eps as f64,
rope_theta: self.rope_theta,
bos_token_id: Some(self.bos_token_id as u32),
eos_token_id: Some(self.eos_token_id as u32),
eos_token_id: Some(LlamaEosToks::Single(self.eos_token_id as u32)),
use_flash_attn: false,
rope_scaling: None, // Assume we don't have LLaVA for Llama 3.1
max_position_embeddings: self.max_position_embeddings,
}
}
}
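
On the consumer side, code that previously compared a sampled token against a single eos_token_id now has to handle both LlamaEosToks variants. A small sketch of such a check, assuming the type is imported from candle_transformers::models::llama; the is_eos helper itself is hypothetical and not part of this commit.

use candle_transformers::models::llama::LlamaEosToks;

// Hypothetical helper: true if `next_token` matches any configured eos id.
fn is_eos(next_token: u32, eos: Option<&LlamaEosToks>) -> bool {
    match eos {
        Some(LlamaEosToks::Single(id)) => next_token == *id,
        Some(LlamaEosToks::Multiple(ids)) => ids.contains(&next_token),
        None => false,
    }
}

// In a generation loop: if is_eos(next_token, config.eos_token_id.as_ref()) { break; }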


@@ -358,7 +358,7 @@ impl SpatialTransformer {
let vs_tb = vs.pp("transformer_blocks");
for index in 0..config.depth {
let tb = BasicTransformerBlock::new(
vs_tb.pp(&index.to_string()),
vs_tb.pp(index.to_string()),
inner_dim,
n_heads,
d_head,


@@ -322,7 +322,7 @@ impl ClipEncoder {
let vs = vs.pp("layers");
let mut layers: Vec<ClipEncoderLayer> = Vec::new();
for index in 0..c.num_hidden_layers {
let layer = ClipEncoderLayer::new(vs.pp(&index.to_string()), c)?;
let layer = ClipEncoderLayer::new(vs.pp(index.to_string()), c)?;
layers.push(layer)
}
Ok(ClipEncoder { layers })


@@ -161,7 +161,7 @@ impl UNet2DConditionModel {
transformer_layers_per_block,
};
let block = CrossAttnDownBlock2D::new(
vs_db.pp(&i.to_string()),
vs_db.pp(i.to_string()),
in_channels,
out_channels,
Some(time_embed_dim),
@@ -171,7 +171,7 @@
Ok(UNetDownBlock::CrossAttn(block))
} else {
let block = DownBlock2D::new(
vs_db.pp(&i.to_string()),
vs_db.pp(i.to_string()),
in_channels,
out_channels,
Some(time_embed_dim),
@@ -251,7 +251,7 @@ impl UNet2DConditionModel {
transformer_layers_per_block,
};
let block = CrossAttnUpBlock2D::new(
vs_ub.pp(&i.to_string()),
vs_ub.pp(i.to_string()),
in_channels,
prev_out_channels,
out_channels,
@@ -262,7 +262,7 @@
Ok(UNetUpBlock::CrossAttn(block))
} else {
let block = UpBlock2D::new(
vs_ub.pp(&i.to_string()),
vs_ub.pp(i.to_string()),
in_channels,
prev_out_channels,
out_channels,


@@ -146,7 +146,7 @@ impl DownEncoderBlock2D {
(0..(config.num_layers))
.map(|i| {
let in_channels = if i == 0 { in_channels } else { out_channels };
ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg)
ResnetBlock2D::new(vs.pp(i.to_string()), in_channels, conv_cfg)
})
.collect::<Result<Vec<_>>>()?
};
@@ -235,7 +235,7 @@ impl UpDecoderBlock2D {
(0..(config.num_layers))
.map(|i| {
let in_channels = if i == 0 { in_channels } else { out_channels };
ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg)
ResnetBlock2D::new(vs.pp(i.to_string()), in_channels, conv_cfg)
})
.collect::<Result<Vec<_>>>()?
};
@@ -328,9 +328,9 @@ impl UNetMidBlock2D {
};
let mut attn_resnets = vec![];
for index in 0..config.num_layers {
let attn = AttentionBlock::new(vs_attns.pp(&index.to_string()), in_channels, attn_cfg)?;
let attn = AttentionBlock::new(vs_attns.pp(index.to_string()), in_channels, attn_cfg)?;
let resnet = ResnetBlock2D::new(
vs_resnets.pp(&(index + 1).to_string()),
vs_resnets.pp((index + 1).to_string()),
in_channels,
resnet_cfg,
)?;
@@ -425,7 +425,7 @@ impl UNetMidBlock2DCrossAttn {
let mut attn_resnets = vec![];
for index in 0..config.num_layers {
let attn = SpatialTransformer::new(
vs_attns.pp(&index.to_string()),
vs_attns.pp(index.to_string()),
in_channels,
n_heads,
in_channels / n_heads,
@@ -433,7 +433,7 @@
attn_cfg,
)?;
let resnet = ResnetBlock2D::new(
vs_resnets.pp(&(index + 1).to_string()),
vs_resnets.pp((index + 1).to_string()),
in_channels,
resnet_cfg,
)?;
@@ -515,7 +515,7 @@ impl DownBlock2D {
let resnets = (0..config.num_layers)
.map(|i| {
let in_channels = if i == 0 { in_channels } else { out_channels };
ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg)
ResnetBlock2D::new(vs_resnets.pp(i.to_string()), in_channels, resnet_cfg)
})
.collect::<Result<Vec<_>>>()?;
let downsampler = if config.add_downsample {
@@ -619,7 +619,7 @@ impl CrossAttnDownBlock2D {
let attentions = (0..config.downblock.num_layers)
.map(|i| {
SpatialTransformer::new(
vs_attn.pp(&i.to_string()),
vs_attn.pp(i.to_string()),
out_channels,
n_heads,
out_channels / n_heads,
@@ -724,7 +724,7 @@ impl UpBlock2D {
out_channels
};
let in_channels = resnet_in_channels + res_skip_channels;
ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg)
ResnetBlock2D::new(vs_resnets.pp(i.to_string()), in_channels, resnet_cfg)
})
.collect::<Result<Vec<_>>>()?;
let upsampler = if config.add_upsample {
@@ -826,7 +826,7 @@ impl CrossAttnUpBlock2D {
let attentions = (0..config.upblock.num_layers)
.map(|i| {
SpatialTransformer::new(
vs_attn.pp(&i.to_string()),
vs_attn.pp(i.to_string()),
out_channels,
n_heads,
out_channels / n_heads,


@@ -80,7 +80,7 @@ impl Encoder {
..Default::default()
};
let down_block = DownEncoderBlock2D::new(
vs_down_blocks.pp(&index.to_string()),
vs_down_blocks.pp(index.to_string()),
in_channels,
out_channels,
cfg,
@@ -222,7 +222,7 @@ impl Decoder {
..Default::default()
};
let up_block = UpDecoderBlock2D::new(
vs_up_blocks.pp(&index.to_string()),
vs_up_blocks.pp(index.to_string()),
in_channels,
out_channels,
cfg,


@@ -601,7 +601,7 @@ impl T5Block {
None
};
let ff_i = if cross_attn.is_some() { 2 } else { 1 };
let ff = T5LayerFF::load(vb.pp(&ff_i.to_string()), cfg)?;
let ff = T5LayerFF::load(vb.pp(ff_i.to_string()), cfg)?;
Ok(Self {
self_attn,
cross_attn,