use std::collections::HashMap;

use candle::{bail, Context, DType, Device, Module, Result, Tensor, D};
use candle_nn::{
    conv1d, embedding, layer_norm, Conv1d, Conv1dConfig, Embedding, LayerNorm, VarBuilder,
};
use serde::{Deserialize, Deserializer};

pub const DTYPE: DType = DType::F32;

// NOTE: HiddenAct and HiddenActLayer are both direct copies from bert.rs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum HiddenAct {
    Gelu,
    GeluApproximate,
    Relu,
}

pub struct HiddenActLayer {
    act: HiddenAct,
    span: tracing::Span,
}

impl HiddenActLayer {
    fn new(act: HiddenAct) -> Self {
        let span = tracing::span!(tracing::Level::TRACE, "hidden-act");
        Self { act, span }
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        match self.act {
            // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
            HiddenAct::Gelu => xs.gelu_erf(),
            HiddenAct::GeluApproximate => xs.gelu(),
            HiddenAct::Relu => xs.relu(),
        }
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
enum PositionEmbeddingType {
    #[default]
    Absolute,
}

pub type Id2Label = HashMap<u32, String>;
pub type Label2Id = HashMap<String, u32>;

#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    pub vocab_size: usize,
    pub hidden_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub intermediate_size: usize,
    pub hidden_act: HiddenAct,
    pub hidden_dropout_prob: f64,
    pub attention_probs_dropout_prob: f64,
    pub max_position_embeddings: usize,
    pub type_vocab_size: usize,
    pub initializer_range: f64,
    pub layer_norm_eps: f64,
    pub relative_attention: bool,
    pub max_relative_positions: isize,
    pub pad_token_id: Option<usize>,
    pub position_biased_input: bool,
    #[serde(deserialize_with = "deserialize_pos_att_type")]
    pub pos_att_type: Vec<String>,
    pub position_buckets: Option<isize>,
    pub share_att_key: Option<bool>,
    pub attention_head_size: Option<usize>,
    pub embedding_size: Option<usize>,
    pub norm_rel_ebd: Option<String>,
    pub conv_kernel_size: Option<usize>,
    pub conv_groups: Option<usize>,
    pub conv_act: Option<String>,
    pub id2label: Option<Id2Label>,
    pub label2id: Option<Label2Id>,
    pub pooler_dropout: Option<f64>,
    pub pooler_hidden_act: Option<HiddenAct>,
    pub pooler_hidden_size: Option<usize>,
    pub cls_dropout: Option<f64>,
}

fn deserialize_pos_att_type<'de, D>(deserializer: D) -> std::result::Result<Vec<String>, D::Error>
where
    D: Deserializer<'de>,
{
    #[derive(Deserialize, Debug)]
    #[serde(untagged)]
    enum StringOrVec {
        String(String),
        Vec(Vec<String>),
    }

    match StringOrVec::deserialize(deserializer)? {
        StringOrVec::String(s) => Ok(s.split('|').map(String::from).collect()),
        StringOrVec::Vec(v) => Ok(v),
    }
}
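// The `pos_att_type` config field is sometimes serialized as a pipe-separated
// string ("c2p|p2c") and sometimes as a JSON array, which is what the custom
// deserializer above handles. The small sanity check below is illustrative and
// not part of the upstream file; it assumes `serde_json` is available as a
// (dev-)dependency of this crate.
#[cfg(test)]
mod pos_att_type_tests {
    use super::deserialize_pos_att_type;

    #[test]
    fn accepts_string_or_array() -> candle::Result<()> {
        // A pipe-separated string is split into its components.
        let mut de = serde_json::Deserializer::from_str(r#""c2p|p2c""#);
        let parsed = deserialize_pos_att_type(&mut de).map_err(candle::Error::wrap)?;
        assert_eq!(parsed, vec!["c2p".to_string(), "p2c".to_string()]);

        // A JSON array is passed through unchanged.
        let mut de = serde_json::Deserializer::from_str(r#"["p2c"]"#);
        let parsed = deserialize_pos_att_type(&mut de).map_err(candle::Error::wrap)?;
        assert_eq!(parsed, vec!["p2c".to_string()]);
        Ok(())
    }
}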
// NOTE: Dropout is probably not needed for now since this will primarily be used
// in inferencing. However, for training/fine-tuning it will be necessary.
pub struct StableDropout {
    _drop_prob: f64,
    _count: usize,
}

impl StableDropout {
    pub fn new(drop_prob: f64) -> Self {
        Self {
            _drop_prob: drop_prob,
            _count: 0,
        }
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        Ok(x.clone())
    }
}

// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L823
pub struct DebertaV2Embeddings {
    device: Device,
    word_embeddings: Embedding,
    position_embeddings: Option<Embedding>,
    token_type_embeddings: Option<Embedding>,
    layer_norm: LayerNorm,
    dropout: StableDropout,
    position_ids: Tensor,
    config: Config,
    embedding_size: usize,
    embed_proj: Option<candle_nn::Linear>,
}

impl DebertaV2Embeddings {
    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let device = vb.device().clone();
        let config = config.clone();

        let embedding_size = config.embedding_size.unwrap_or(config.hidden_size);

        let word_embeddings =
            embedding(config.vocab_size, embedding_size, vb.pp("word_embeddings"))?;

        let position_embeddings = if config.position_biased_input {
            Some(embedding(
                config.max_position_embeddings,
                embedding_size,
                vb.pp("position_embeddings"),
            )?)
        } else {
            None
        };

        let token_type_embeddings: Option<Embedding> = if config.type_vocab_size > 0 {
            Some(candle_nn::embedding(
                config.type_vocab_size,
                config.hidden_size,
                vb.pp("token_type_embeddings"),
            )?)
        } else {
            None
        };

        let embed_proj: Option<candle_nn::Linear> = if embedding_size != config.hidden_size {
            Some(candle_nn::linear_no_bias(
                embedding_size,
                config.hidden_size,
                vb.pp("embed_proj"),
            )?)
        } else {
            None
        };

        let layer_norm = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("LayerNorm"),
        )?;

        let dropout = StableDropout::new(config.hidden_dropout_prob);

        let position_ids =
            Tensor::arange(0, config.max_position_embeddings as u32, &device)?.unsqueeze(0)?;

        Ok(Self {
            word_embeddings,
            position_embeddings,
            token_type_embeddings,
            layer_norm,
            dropout,
            position_ids,
            device,
            config,
            embedding_size,
            embed_proj,
        })
    }
    pub fn forward(
        &self,
        input_ids: Option<&Tensor>,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        mask: Option<&Tensor>,
        inputs_embeds: Option<&Tensor>,
    ) -> Result<Tensor> {
        let (input_shape, input_embeds) = match (input_ids, inputs_embeds) {
            (Some(ids), None) => {
                let embs = self.word_embeddings.forward(ids)?;
                (ids.dims(), embs)
            }
            (None, Some(e)) => (e.dims(), e.clone()),
            (None, None) => {
                bail!("Must specify either input_ids or inputs_embeds")
            }
            (Some(_), Some(_)) => {
                bail!("Can't specify both input_ids and inputs_embeds")
            }
        };

        let seq_length = match input_shape.last() {
            Some(v) => *v,
            None => bail!("DebertaV2Embeddings invalid input shape"),
        };

        let position_ids = match position_ids {
            Some(v) => v.clone(),
            None => self.position_ids.narrow(1, 0, seq_length)?,
        };

        let token_type_ids = match token_type_ids {
            Some(ids) => ids.clone(),
            None => Tensor::zeros(input_shape, DType::U32, &self.device)?,
        };

        let position_embeddings = match &self.position_embeddings {
            Some(emb) => emb.forward(&position_ids)?,
            None => Tensor::zeros_like(&input_embeds)?,
        };

        let mut embeddings = input_embeds;

        if self.config.position_biased_input {
            embeddings = embeddings.add(&position_embeddings)?;
        }

        if self.config.type_vocab_size > 0 {
            embeddings = self.token_type_embeddings.as_ref().map_or_else(
                || bail!("token_type_embeddings must be set when type_vocab_size > 0"),
                |token_type_embeddings| {
                    embeddings.add(&token_type_embeddings.forward(&token_type_ids)?)
                },
            )?;
        }

        if self.embedding_size != self.config.hidden_size {
            embeddings = if let Some(embed_proj) = &self.embed_proj {
                embed_proj.forward(&embeddings)?
            } else {
                bail!("embed_proj must exist if embedding_size != config.hidden_size");
            }
        }

        embeddings = self.layer_norm.forward(&embeddings)?;

        if let Some(mask) = mask {
            let mut mask = mask.clone();
            if mask.dims() != embeddings.dims() {
                if mask.dims().len() == 4 {
                    mask = mask.squeeze(1)?.squeeze(1)?;
                }
                mask = mask.unsqueeze(2)?;
            }

            mask = mask.to_dtype(embeddings.dtype())?;
            embeddings = embeddings.broadcast_mul(&mask)?;
        }

        self.dropout.forward(&embeddings)
    }
}

// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L72
struct XSoftmax {}

impl XSoftmax {
    pub fn apply(input: &Tensor, mask: &Tensor, dim: D, device: &Device) -> Result<Tensor> {
        // NOTE: At the time of this writing, candle does not have a logical-not operator.
        let mut rmask = mask.broadcast_as(input.shape())?.to_dtype(DType::F32)?;

        rmask = rmask
            .broadcast_lt(&Tensor::new(&[1.0_f32], device)?)?
            .to_dtype(DType::U8)?;

        let min_value_tensor = Tensor::new(&[f32::MIN], device)?.broadcast_as(input.shape())?;
        let mut output = rmask.where_cond(&min_value_tensor, input)?;

        output = candle_nn::ops::softmax(&output, dim)?;

        let t_zeroes = Tensor::new(&[0f32], device)?.broadcast_as(input.shape())?;
        output = rmask.where_cond(&t_zeroes, &output)?;

        Ok(output)
    }
}
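// A small, self-contained sanity check for the masked softmax above. It is
// illustrative and not part of the upstream file; it only exercises code from
// this module on the CPU device.
#[cfg(test)]
mod xsoftmax_tests {
    use super::*;

    #[test]
    fn masked_positions_get_zero_probability() -> Result<()> {
        let device = Device::Cpu;
        let input = Tensor::new(&[[1f32, 2.0, 3.0]], &device)?;
        // The last position is masked out (mask == 0).
        let mask = Tensor::new(&[[1u8, 1, 0]], &device)?;

        let probs = XSoftmax::apply(&input, &mask, D::Minus1, &device)?;
        let probs = probs.to_vec2::<f32>()?;

        // The masked position contributes exactly zero probability and the
        // remaining positions still sum to one.
        assert_eq!(probs[0][2], 0.0);
        assert!((probs[0][0] + probs[0][1] - 1.0).abs() < 1e-6);
        Ok(())
    }
}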
// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L605
pub struct DebertaV2DisentangledSelfAttention {
    config: Config,
    num_attention_heads: usize,
    query_proj: candle_nn::Linear,
    key_proj: candle_nn::Linear,
    value_proj: candle_nn::Linear,
    dropout: StableDropout,
    device: Device,
    relative_attention: bool,
    pos_dropout: Option<StableDropout>,
    position_buckets: isize,
    max_relative_positions: isize,
    pos_ebd_size: isize,
    share_att_key: bool,
    pos_key_proj: Option<candle_nn::Linear>,
    pos_query_proj: Option<candle_nn::Linear>,
}

impl DebertaV2DisentangledSelfAttention {
    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let config = config.clone();
        let vb = vb.clone();

        if config.hidden_size % config.num_attention_heads != 0 {
            return Err(candle::Error::Msg(format!(
                "The hidden size {} is not a multiple of the number of attention heads {}",
                config.hidden_size, config.num_attention_heads
            )));
        }

        let num_attention_heads = config.num_attention_heads;

        let attention_head_size = config
            .attention_head_size
            .unwrap_or(config.hidden_size / config.num_attention_heads);

        let all_head_size = num_attention_heads * attention_head_size;

        let query_proj = candle_nn::linear(config.hidden_size, all_head_size, vb.pp("query_proj"))?;
        let key_proj = candle_nn::linear(config.hidden_size, all_head_size, vb.pp("key_proj"))?;
        let value_proj = candle_nn::linear(config.hidden_size, all_head_size, vb.pp("value_proj"))?;

        let share_att_key = config.share_att_key.unwrap_or(false);
        let relative_attention = config.relative_attention;
        let mut max_relative_positions = config.max_relative_positions;

        let mut pos_ebd_size: isize = 0;
        let position_buckets = config.position_buckets.unwrap_or(-1);
        let mut pos_dropout: Option<StableDropout> = None;
        let mut pos_key_proj: Option<candle_nn::Linear> = None;
        let mut pos_query_proj: Option<candle_nn::Linear> = None;

        if relative_attention {
            if max_relative_positions < 1 {
                max_relative_positions = config.max_position_embeddings as isize;
            }

            pos_ebd_size = max_relative_positions;
            if position_buckets > 0 {
                pos_ebd_size = position_buckets
            }

            pos_dropout = Some(StableDropout::new(config.hidden_dropout_prob));

            if !share_att_key {
                if config.pos_att_type.iter().any(|s| s == "c2p") {
                    pos_key_proj = Some(candle_nn::linear(
                        config.hidden_size,
                        all_head_size,
                        vb.pp("pos_key_proj"),
                    )?);
                }
                if config.pos_att_type.iter().any(|s| s == "p2c") {
                    pos_query_proj = Some(candle_nn::linear(
                        config.hidden_size,
                        all_head_size,
                        vb.pp("pos_query_proj"),
                    )?);
                }
            }
        }

        let dropout = StableDropout::new(config.attention_probs_dropout_prob);
        let device = vb.device().clone();

        Ok(Self {
            config,
            num_attention_heads,
            query_proj,
            key_proj,
            value_proj,
            dropout,
            device,
            relative_attention,
            pos_dropout,
            position_buckets,
            max_relative_positions,
            pos_ebd_size,
            share_att_key,
            pos_key_proj,
            pos_query_proj,
        })
    }

    pub fn forward(
        &self,
        hidden_states: &Tensor,
        attention_mask: &Tensor,
        query_states: Option<&Tensor>,
        relative_pos: Option<&Tensor>,
        rel_embeddings: Option<&Tensor>,
    ) -> Result<Tensor> {
        let query_states = match query_states {
            Some(qs) => qs,
            None => hidden_states,
        };

        let query_layer = self.transpose_for_scores(&self.query_proj.forward(query_states)?)?;
        let key_layer = self.transpose_for_scores(&self.key_proj.forward(query_states)?)?;
        let value_layer = self.transpose_for_scores(&self.value_proj.forward(query_states)?)?;

        let mut rel_att: Option<Tensor> = None;

        let mut scale_factor: usize = 1;

        if self.config.pos_att_type.iter().any(|s| s == "c2p") {
            scale_factor += 1;
        }

        if self.config.pos_att_type.iter().any(|s| s == "p2c") {
            scale_factor += 1;
        }

        let scale = {
            let q_size = query_layer.dim(D::Minus1)?;
            Tensor::new(&[(q_size * scale_factor) as f32], &self.device)?.sqrt()?
        };

        let mut attention_scores: Tensor = {
            let key_layer_transposed = key_layer.t()?;
            let div = key_layer_transposed
                .broadcast_div(scale.to_dtype(query_layer.dtype())?.as_ref())?;
            query_layer.matmul(&div)?
        };

        if self.relative_attention {
            if let Some(rel_embeddings) = rel_embeddings {
                let rel_embeddings = self
                    .pos_dropout
                    .as_ref()
                    .context("relative_attention requires pos_dropout")?
                    .forward(rel_embeddings)?;
                rel_att = Some(self.disentangled_attention_bias(
                    query_layer,
                    key_layer,
                    relative_pos,
                    rel_embeddings,
                    scale_factor,
                )?);
            }
        }

        if let Some(rel_att) = rel_att {
            attention_scores = attention_scores.broadcast_add(&rel_att)?;
        }

        attention_scores = attention_scores.reshape((
            (),
            self.num_attention_heads,
            attention_scores.dim(D::Minus2)?,
            attention_scores.dim(D::Minus1)?,
        ))?;

        let mut attention_probs =
            XSoftmax::apply(&attention_scores, attention_mask, D::Minus1, &self.device)?;

        attention_probs = self.dropout.forward(&attention_probs)?;

        let mut context_layer = attention_probs
            .reshape((
                (),
                attention_probs.dim(D::Minus2)?,
                attention_probs.dim(D::Minus1)?,
            ))?
            .matmul(&value_layer)?;

        context_layer = context_layer
            .reshape((
                (),
                self.num_attention_heads,
                context_layer.dim(D::Minus2)?,
                context_layer.dim(D::Minus1)?,
            ))?
            .permute((0, 2, 1, 3))?
            .contiguous()?;

        let dims = context_layer.dims();

        context_layer = match dims.len() {
            2 => context_layer.reshape(())?,
            3 => context_layer.reshape((dims[0], ()))?,
            4 => context_layer.reshape((dims[0], dims[1], ()))?,
            5 => context_layer.reshape((dims[0], dims[1], dims[2], ()))?,
            _ => {
                bail!(
                    "Invalid shape for DisentangledSelfAttention context layer: {:?}",
                    dims
                )
            }
        };

        Ok(context_layer)
    }

    fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
        let dims = xs.dims().to_vec();
        match dims.len() {
            3 => {
                let reshaped = xs.reshape((dims[0], dims[1], self.num_attention_heads, ()))?;

                reshaped.transpose(1, 2)?.contiguous()?.reshape((
                    (),
                    reshaped.dim(1)?,
                    reshaped.dim(D::Minus1)?,
                ))
            }
            shape => {
                bail!("Invalid shape for transpose_for_scores. Expected 3 dimensions, got {shape}")
            }
        }
    }
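    // Descriptive note (added for clarity): DeBERTa's disentangled attention adds
    // up to two extra terms to the usual content-to-content score, depending on
    // `pos_att_type`: a content-to-position ("c2p") term, where each token's
    // content vector attends over relative-position embeddings, and a
    // position-to-content ("p2c") term, where relative-position embeddings attend
    // over token content. `scale_factor` in `forward` counts the enabled terms so
    // that the combined scores are divided by sqrt(head_dim * scale_factor).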
    fn disentangled_attention_bias(
        &self,
        query_layer: Tensor,
        key_layer: Tensor,
        relative_pos: Option<&Tensor>,
        rel_embeddings: Tensor,
        scale_factor: usize,
    ) -> Result<Tensor> {
        let mut relative_pos = relative_pos.map_or(
            build_relative_position(
                query_layer.dim(D::Minus2)?,
                key_layer.dim(D::Minus2)?,
                &self.device,
                Some(self.position_buckets),
                Some(self.max_relative_positions),
            )?,
            |pos| pos.clone(),
        );

        relative_pos = match relative_pos.dims().len() {
            2 => relative_pos.unsqueeze(0)?.unsqueeze(0)?,
            3 => relative_pos.unsqueeze(1)?,
            other => {
                bail!("Relative position ids must be of dim 2 or 3 or 4. Got dim of size {other}")
            }
        };

        let att_span = self.pos_ebd_size;

        let rel_embeddings = rel_embeddings
            .narrow(0, 0, (att_span * 2) as usize)?
            .unsqueeze(0)?;

        let mut pos_query_layer: Option<Tensor> = None;
        let mut pos_key_layer: Option<Tensor> = None;

        let repeat_with = query_layer.dim(0)? / self.num_attention_heads;
        if self.share_att_key {
            pos_query_layer = Some(
                self.transpose_for_scores(&self.query_proj.forward(&rel_embeddings)?)?
                    .repeat(repeat_with)?,
            );
            pos_key_layer = Some(
                self.transpose_for_scores(&self.key_proj.forward(&rel_embeddings)?)?
                    .repeat(repeat_with)?,
            )
        } else {
            if self.config.pos_att_type.iter().any(|s| s == "c2p") {
                pos_key_layer = Some(
                    self.transpose_for_scores(
                        &self
                            .pos_key_proj
                            .as_ref()
                            .context(
                                "Need pos_key_proj when share_att_key is false or not specified",
                            )?
                            .forward(&rel_embeddings)?,
                    )?
                    .repeat(repeat_with)?,
                )
            }
            if self.config.pos_att_type.iter().any(|s| s == "p2c") {
                pos_query_layer = Some(
                    self.transpose_for_scores(
                        &self
                            .pos_query_proj
                            .as_ref()
                            .context(
                                "Need a pos_query_proj when share_att_key is false or not specified",
                            )?
                            .forward(&rel_embeddings)?,
                    )?
                    .repeat(repeat_with)?,
                )
            }
        }

        let mut score = Tensor::new(&[0 as f32], &self.device)?;

        if self.config.pos_att_type.iter().any(|s| s == "c2p") {
            let pos_key_layer = pos_key_layer.context("c2p without pos_key_layer")?;

            let scale = Tensor::new(
                &[(pos_key_layer.dim(D::Minus1)? * scale_factor) as f32],
                &self.device,
            )?
            .sqrt()?;

            let mut c2p_att = query_layer.matmul(&pos_key_layer.t()?)?;

            let c2p_pos = relative_pos
                .broadcast_add(&Tensor::new(&[att_span as i64], &self.device)?)?
                .clamp(0 as f32, (att_span * 2 - 1) as f32)?;

            c2p_att = c2p_att.gather(
                &c2p_pos
                    .squeeze(0)?
                    .expand(&[
                        query_layer.dim(0)?,
                        query_layer.dim(1)?,
                        relative_pos.dim(D::Minus1)?,
                    ])?
                    .contiguous()?,
                D::Minus1,
            )?;

            score = score.broadcast_add(
                &c2p_att.broadcast_div(scale.to_dtype(c2p_att.dtype())?.as_ref())?,
            )?;
        }

        if self.config.pos_att_type.iter().any(|s| s == "p2c") {
            let pos_query_layer = pos_query_layer.context("p2c without pos_query_layer")?;

            let scale = Tensor::new(
                &[(pos_query_layer.dim(D::Minus1)? * scale_factor) as f32],
                &self.device,
            )?
            .sqrt()?;

            let r_pos = {
                if key_layer.dim(D::Minus2)? != query_layer.dim(D::Minus2)? {
                    build_relative_position(
                        key_layer.dim(D::Minus2)?,
                        key_layer.dim(D::Minus2)?,
                        &self.device,
                        Some(self.position_buckets),
                        Some(self.max_relative_positions),
                    )?
                    .unsqueeze(0)?
                } else {
                    relative_pos
                }
            };

            let p2c_pos = r_pos
                .to_dtype(DType::F32)?
                .neg()?
                .broadcast_add(&Tensor::new(&[att_span as f32], &self.device)?)?
                .clamp(0f32, (att_span * 2 - 1) as f32)?;

            let p2c_att = key_layer
                .matmul(&pos_query_layer.t()?)?
                .gather(
                    &p2c_pos
                        .squeeze(0)?
                        .expand(&[
                            query_layer.dim(0)?,
                            key_layer.dim(D::Minus2)?,
                            key_layer.dim(D::Minus2)?,
                        ])?
                        .contiguous()?
                        .to_dtype(DType::U32)?,
                    D::Minus1,
                )?
                .t()?;

            score =
                score.broadcast_add(&p2c_att.broadcast_div(&scale.to_dtype(p2c_att.dtype())?)?)?;
        }

        Ok(score)
    }
}
// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L270
pub struct DebertaV2Attention {
    dsa: DebertaV2DisentangledSelfAttention,
    output: DebertaV2SelfOutput,
}

impl DebertaV2Attention {
    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let dsa = DebertaV2DisentangledSelfAttention::load(vb.pp("attention.self"), config)?;
        let output = DebertaV2SelfOutput::load(vb.pp("attention.output"), config)?;
        Ok(Self { dsa, output })
    }

    fn forward(
        &self,
        hidden_states: &Tensor,
        attention_mask: &Tensor,
        query_states: Option<&Tensor>,
        relative_pos: Option<&Tensor>,
        rel_embeddings: Option<&Tensor>,
    ) -> Result<Tensor> {
        let self_output = self.dsa.forward(
            hidden_states,
            attention_mask,
            query_states,
            relative_pos,
            rel_embeddings,
        )?;

        self.output
            .forward(&self_output, query_states.unwrap_or(hidden_states))
    }
}

// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L255
pub struct DebertaV2SelfOutput {
    dense: candle_nn::Linear,
    layer_norm: LayerNorm,
    dropout: StableDropout,
}

impl DebertaV2SelfOutput {
    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let dense = candle_nn::linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
        let layer_norm = candle_nn::layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("LayerNorm"),
        )?;
        let dropout = StableDropout::new(config.hidden_dropout_prob);
        Ok(Self {
            dense,
            layer_norm,
            dropout,
        })
    }

    pub fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
        let mut hidden_states = self.dense.forward(hidden_states)?;
        hidden_states = self.dropout.forward(&hidden_states)?;
        self.layer_norm
            .forward(&hidden_states.broadcast_add(input_tensor)?)
    }
}

// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L307
pub struct DebertaV2Intermediate {
    dense: candle_nn::Linear,
    intermediate_act: HiddenActLayer,
}

impl DebertaV2Intermediate {
    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let dense = candle_nn::linear(
            config.hidden_size,
            config.intermediate_size,
            vb.pp("intermediate.dense"),
        )?;
        let intermediate_act = HiddenActLayer::new(config.hidden_act);
        Ok(Self {
            dense,
            intermediate_act,
        })
    }

    pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        self.intermediate_act
            .forward(&self.dense.forward(hidden_states)?)
    }
}
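// Descriptive note (added for clarity): `DebertaV2Intermediate` and
// `DebertaV2Output` together implement the standard transformer position-wise
// feed-forward block: dense -> activation, then dense -> dropout -> residual
// add -> LayerNorm.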
// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L323
pub struct DebertaV2Output {
    dense: candle_nn::Linear,
    layer_norm: LayerNorm,
    dropout: StableDropout,
}

impl DebertaV2Output {
    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let dense = candle_nn::linear(
            config.intermediate_size,
            config.hidden_size,
            vb.pp("output.dense"),
        )?;
        let layer_norm = candle_nn::layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("output.LayerNorm"),
        )?;
        let dropout = StableDropout::new(config.hidden_dropout_prob);
        Ok(Self {
            dense,
            layer_norm,
            dropout,
        })
    }

    pub fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
        let mut hidden_states = self.dense.forward(hidden_states)?;
        hidden_states = self.dropout.forward(&hidden_states)?;
        hidden_states = {
            let to_norm = hidden_states.broadcast_add(input_tensor)?;
            self.layer_norm.forward(&to_norm)?
        };
        Ok(hidden_states)
    }
}

// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L339
pub struct DebertaV2Layer {
    attention: DebertaV2Attention,
    intermediate: DebertaV2Intermediate,
    output: DebertaV2Output,
}

impl DebertaV2Layer {
    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let attention = DebertaV2Attention::load(vb.clone(), config)?;
        let intermediate = DebertaV2Intermediate::load(vb.clone(), config)?;
        let output = DebertaV2Output::load(vb.clone(), config)?;
        Ok(Self {
            attention,
            intermediate,
            output,
        })
    }

    fn forward(
        &self,
        hidden_states: &Tensor,
        attention_mask: &Tensor,
        query_states: Option<&Tensor>,
        relative_pos: Option<&Tensor>,
        rel_embeddings: Option<&Tensor>,
    ) -> Result<Tensor> {
        let attention_output = self.attention.forward(
            hidden_states,
            attention_mask,
            query_states,
            relative_pos,
            rel_embeddings,
        )?;

        let intermediate_output = self.intermediate.forward(&attention_output)?;

        let layer_output = self
            .output
            .forward(&intermediate_output, &attention_output)?;

        Ok(layer_output)
    }
}

// TODO: In order to fully test ConvLayer, a model needs to be found that has a configuration where `conv_kernel_size` exists and is > 0.
// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L373
pub struct ConvLayer {
    _conv_act: String,
    _conv: Conv1d,
    _layer_norm: LayerNorm,
    _dropout: StableDropout,
    _config: Config,
}

impl ConvLayer {
    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let config = config.clone();
        let kernel_size = config.conv_kernel_size.unwrap_or(3);
        let groups = config.conv_groups.unwrap_or(1);
        let conv_act: String = config.conv_act.clone().unwrap_or("tanh".to_string());

        let conv_conf = Conv1dConfig {
            padding: (kernel_size - 1) / 2,
            groups,
            ..Default::default()
        };

        let conv = conv1d(
            config.hidden_size,
            config.hidden_size,
            kernel_size,
            conv_conf,
            vb.pp("conv"),
        )?;

        let layer_norm = layer_norm(
            config.hidden_size,
            config.layer_norm_eps,
            vb.pp("LayerNorm"),
        )?;

        let dropout = StableDropout::new(config.hidden_dropout_prob);

        Ok(Self {
            _conv_act: conv_act,
            _conv: conv,
            _layer_norm: layer_norm,
            _dropout: dropout,
            _config: config,
        })
    }

    pub fn forward(
        &self,
        _hidden_states: &Tensor,
        _residual_states: &Tensor,
        _input_mask: &Tensor,
    ) -> Result<Tensor> {
        todo!("Need a model that contains a conv layer to test against.")
    }
}
// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L409
pub struct DebertaV2Encoder {
    layer: Vec<DebertaV2Layer>,
    relative_attention: bool,
    max_relative_positions: isize,
    position_buckets: isize,
    rel_embeddings: Option<Embedding>,
    norm_rel_ebd: String,
    layer_norm: Option<LayerNorm>,
    conv: Option<ConvLayer>,
    device: Device,
}

impl DebertaV2Encoder {
    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let layer = (0..config.num_hidden_layers)
            .map(|index| DebertaV2Layer::load(vb.pp(format!("layer.{index}")), config))
            .collect::<Result<Vec<DebertaV2Layer>>>()?;

        let relative_attention = config.relative_attention;
        let mut max_relative_positions = config.max_relative_positions;

        let position_buckets = config.position_buckets.unwrap_or(-1);

        let mut rel_embeddings: Option<Embedding> = None;

        if relative_attention {
            if max_relative_positions < 1 {
                max_relative_positions = config.max_position_embeddings as isize;
            }

            let mut pos_ebd_size = max_relative_positions * 2;

            if position_buckets > 0 {
                pos_ebd_size = position_buckets * 2;
            }

            rel_embeddings = Some(embedding(
                pos_ebd_size as usize,
                config.hidden_size,
                vb.pp("rel_embeddings"),
            )?);
        }

        // NOTE: The Python code assumes that the config attribute "norm_rel_ebd" is an array of
        // some kind, but most examples have it as a string, so it might need to be updated at
        // some point.
        let norm_rel_ebd = match config.norm_rel_ebd.as_ref() {
            Some(nre) => nre.trim().to_string(),
            None => "none".to_string(),
        };

        let layer_norm: Option<LayerNorm> = if norm_rel_ebd == "layer_norm" {
            Some(layer_norm(
                config.hidden_size,
                config.layer_norm_eps,
                vb.pp("LayerNorm"),
            )?)
        } else {
            None
        };

        let conv: Option<ConvLayer> = if config.conv_kernel_size.unwrap_or(0) > 0 {
            Some(ConvLayer::load(vb.pp("conv"), config)?)
        } else {
            None
        };

        Ok(Self {
            layer,
            relative_attention,
            max_relative_positions,
            position_buckets,
            rel_embeddings,
            norm_rel_ebd,
            layer_norm,
            conv,
            device: vb.device().clone(),
        })
    }

    pub fn forward(
        &self,
        hidden_states: &Tensor,
        attention_mask: &Tensor,
        query_states: Option<&Tensor>,
        relative_pos: Option<&Tensor>,
    ) -> Result<Tensor> {
        let input_mask = if attention_mask.dims().len() <= 2 {
            attention_mask.clone()
        } else {
            attention_mask
                .sum_keepdim(attention_mask.rank() - 2)?
                .gt(0.)?
        };

        let attention_mask = self.get_attention_mask(attention_mask.clone())?;

        let relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)?;

        let mut next_kv: Tensor = hidden_states.clone();
        let rel_embeddings = self.get_rel_embedding()?;
        let mut output_states = next_kv.to_owned();
        let mut query_states: Option<Tensor> = query_states.cloned();

        for (i, layer_module) in self.layer.iter().enumerate() {
            // NOTE: The original python code branches here if this model is being
            // used for training vs. inferencing. For now, we will only handle the
            // inferencing side of things
            output_states = layer_module.forward(
                next_kv.as_ref(),
                &attention_mask,
                query_states.as_ref(),
                relative_pos.as_ref(),
                rel_embeddings.as_ref(),
            )?;

            if i == 0 {
                if let Some(conv) = &self.conv {
                    output_states = conv.forward(hidden_states, &output_states, &input_mask)?;
                }
            }

            if query_states.is_some() {
                query_states = Some(output_states.clone());
            } else {
                next_kv = output_states.clone();
            }
        }

        Ok(output_states)
    }
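    // Descriptive note (added for clarity): a 2D padding mask of shape
    // (batch, seq_len) is expanded below into a 4D mask of shape
    // (batch, 1, seq_len, seq_len) by taking the outer product of the mask with
    // itself, so a score survives masking only when both the query and the key
    // positions are real (non-padding) tokens.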
    fn get_attention_mask(&self, mut attention_mask: Tensor) -> Result<Tensor> {
        match attention_mask.dims().len() {
            0..=2 => {
                let extended_attention_mask = attention_mask.unsqueeze(1)?.unsqueeze(2)?;
                attention_mask = extended_attention_mask.broadcast_mul(
                    &extended_attention_mask
                        .squeeze(D::Minus2)?
                        .unsqueeze(D::Minus1)?,
                )?;
            }
            3 => attention_mask = attention_mask.unsqueeze(1)?,
            len => bail!("Unsupported attention mask size length: {len}"),
        }

        Ok(attention_mask)
    }

    fn get_rel_pos(
        &self,
        hidden_states: &Tensor,
        query_states: Option<&Tensor>,
        relative_pos: Option<&Tensor>,
    ) -> Result<Option<Tensor>> {
        if self.relative_attention && relative_pos.is_none() {
            let q = if let Some(query_states) = query_states {
                query_states.dim(D::Minus2)?
            } else {
                hidden_states.dim(D::Minus2)?
            };

            return Ok(Some(build_relative_position(
                q,
                hidden_states.dim(D::Minus2)?,
                &self.device,
                Some(self.position_buckets),
                Some(self.max_relative_positions),
            )?));
        }

        if relative_pos.is_some() {
            Ok(relative_pos.cloned())
        } else {
            Ok(None)
        }
    }

    fn get_rel_embedding(&self) -> Result<Option<Tensor>> {
        if !self.relative_attention {
            return Ok(None);
        }

        let rel_embeddings = self
            .rel_embeddings
            .as_ref()
            .context("self.rel_embeddings not present when using relative_attention")?
            .embeddings()
            .clone();

        if !self.norm_rel_ebd.contains("layer_norm") {
            return Ok(Some(rel_embeddings));
        }

        let layer_normed_embeddings = self
            .layer_norm
            .as_ref()
            .context("DebertaV2Encoder layer_norm is None when norm_rel_ebd contains layer_norm")?
            .forward(&rel_embeddings)?;

        Ok(Some(layer_normed_embeddings))
    }
}

// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L991
pub struct DebertaV2Model {
    embeddings: DebertaV2Embeddings,
    encoder: DebertaV2Encoder,
    z_steps: usize,
    pub device: Device,
}

impl DebertaV2Model {
    pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
        let vb = vb.clone();
        let embeddings = DebertaV2Embeddings::load(vb.pp("embeddings"), config)?;
        let encoder = DebertaV2Encoder::load(vb.pp("encoder"), config)?;
        let z_steps: usize = 0;

        Ok(Self {
            embeddings,
            encoder,
            z_steps,
            device: vb.device().clone(),
        })
    }

    pub fn forward(
        &self,
        input_ids: &Tensor,
        token_type_ids: Option<Tensor>,
        attention_mask: Option<Tensor>,
    ) -> Result<Tensor> {
        let input_ids_shape = input_ids.shape();

        let attention_mask = match attention_mask {
            Some(mask) => mask,
            None => Tensor::ones(input_ids_shape, DType::I64, &self.device)?,
        };

        let token_type_ids = match token_type_ids {
            Some(ids) => ids,
            None => Tensor::zeros(input_ids_shape, DType::U32, &self.device)?,
        };

        let embedding_output = self.embeddings.forward(
            Some(input_ids),
            Some(&token_type_ids),
            None,
            Some(&attention_mask),
            None,
        )?;

        let encoder_output =
            self.encoder
                .forward(&embedding_output, &attention_mask, None, None)?;

        if self.z_steps > 1 {
            todo!("Complete DebertaV2Model forward() when z_steps > 1 -- Needs a model to test this situation.")
        }

        Ok(encoder_output)
    }
}

#[derive(Debug)]
pub struct NERItem {
    pub entity: String,
    pub word: String,
    pub score: f32,
    pub start: usize,
    pub end: usize,
    pub index: usize,
}

#[derive(Debug)]
pub struct TextClassificationItem {
    pub label: String,
    pub score: f32,
}

pub struct DebertaV2NERModel {
    pub device: Device,
    deberta: DebertaV2Model,
    dropout: candle_nn::Dropout,
    classifier: candle_nn::Linear,
}

fn id2label_len(config: &Config, id2label: Option<HashMap<u32, String>>) -> Result<usize> {
    let id2label_len = match (&config.id2label, id2label) {
        (None, None) => bail!("Id2Label is either not present in the model configuration or not passed into DebertaV2NERModel::load as a parameter"),
        (None, Some(id2label_p)) => id2label_p.len(),
        (Some(id2label_c), None) => id2label_c.len(),
        (Some(id2label_c), Some(id2label_p)) => {
            if *id2label_c == id2label_p {
                id2label_c.len()
            } else {
                bail!("Id2Label is both present in the model configuration and provided as a parameter, and they are different.")
            }
        }
    };
    Ok(id2label_len)
}
they are different.") } } }; Ok(id2label_len) } impl DebertaV2NERModel { pub fn load(vb: VarBuilder, config: &Config, id2label: Option) -> Result { let id2label_len = id2label_len(config, id2label)?; let deberta = DebertaV2Model::load(vb.clone(), config)?; let dropout = candle_nn::Dropout::new(config.hidden_dropout_prob as f32); let classifier: candle_nn::Linear = candle_nn::linear_no_bias( config.hidden_size, id2label_len, vb.root().pp("classifier"), )?; Ok(Self { device: vb.device().clone(), deberta, dropout, classifier, }) } pub fn forward( &self, input_ids: &Tensor, token_type_ids: Option, attention_mask: Option, ) -> Result { let output = self .deberta .forward(input_ids, token_type_ids, attention_mask)?; let output = self.dropout.forward(&output, false)?; self.classifier.forward(&output) } } pub struct DebertaV2SeqClassificationModel { pub device: Device, deberta: DebertaV2Model, dropout: StableDropout, pooler: DebertaV2ContextPooler, classifier: candle_nn::Linear, } impl DebertaV2SeqClassificationModel { pub fn load(vb: VarBuilder, config: &Config, id2label: Option) -> Result { let id2label_len = id2label_len(config, id2label)?; let deberta = DebertaV2Model::load(vb.clone(), config)?; let pooler = DebertaV2ContextPooler::load(vb.clone(), config)?; let output_dim = pooler.output_dim()?; let classifier = candle_nn::linear(output_dim, id2label_len, vb.root().pp("classifier"))?; let dropout = match config.cls_dropout { Some(cls_dropout) => StableDropout::new(cls_dropout), None => StableDropout::new(config.hidden_dropout_prob), }; Ok(Self { device: vb.device().clone(), deberta, dropout, pooler, classifier, }) } pub fn forward( &self, input_ids: &Tensor, token_type_ids: Option, attention_mask: Option, ) -> Result { let encoder_layer = self .deberta .forward(input_ids, token_type_ids, attention_mask)?; let pooled_output = self.pooler.forward(&encoder_layer)?; let pooled_output = self.dropout.forward(&pooled_output)?; self.classifier.forward(&pooled_output) } } pub struct DebertaV2ContextPooler { dense: candle_nn::Linear, dropout: StableDropout, config: Config, } // https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L49 impl DebertaV2ContextPooler { pub fn load(vb: VarBuilder, config: &Config) -> Result { let pooler_hidden_size = config .pooler_hidden_size .context("config.pooler_hidden_size is required for DebertaV2ContextPooler")?; let pooler_dropout = config .pooler_dropout .context("config.pooler_dropout is required for DebertaV2ContextPooler")?; let dense = candle_nn::linear( pooler_hidden_size, pooler_hidden_size, vb.root().pp("pooler.dense"), )?; let dropout = StableDropout::new(pooler_dropout); Ok(Self { dense, dropout, config: config.clone(), }) } pub fn forward(&self, hidden_states: &Tensor) -> Result { let context_token = hidden_states.narrow(1, 0, 1)?.squeeze(1)?; let context_token = self.dropout.forward(&context_token)?; let pooled_output = self.dense.forward(&context_token.contiguous()?)?; let pooler_hidden_act = self .config .pooler_hidden_act .context("Could not obtain pooler hidden act from config")?; HiddenActLayer::new(pooler_hidden_act).forward(&pooled_output) } pub fn output_dim(&self) -> Result { self.config.pooler_hidden_size.context("DebertaV2ContextPooler cannot return output_dim (pooler_hidden_size) since it is not specified in the model config") } } // 
// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L557
pub(crate) fn build_relative_position(
    query_size: usize,
    key_size: usize,
    device: &Device,
    bucket_size: Option<isize>,
    max_position: Option<isize>,
) -> Result<Tensor> {
    let q_ids = Tensor::arange(0, query_size as i64, device)?.unsqueeze(0)?;
    let k_ids: Tensor = Tensor::arange(0, key_size as i64, device)?.unsqueeze(D::Minus1)?;
    let mut rel_pos_ids = k_ids.broadcast_sub(&q_ids)?;

    let bucket_size = bucket_size.unwrap_or(-1);
    let max_position = max_position.unwrap_or(-1);

    if bucket_size > 0 && max_position > 0 {
        rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position, device)?;
    }

    rel_pos_ids = rel_pos_ids.to_dtype(DType::I64)?;
    rel_pos_ids = rel_pos_ids.narrow(0, 0, query_size)?;
    rel_pos_ids.unsqueeze(0)
}

// https://github.com/huggingface/transformers/blob/78b2929c0554b79e0489b451ce4ece14d265ead2/src/transformers/models/deberta_v2/modeling_deberta_v2.py#L542
pub(crate) fn make_log_bucket_position(
    relative_pos: Tensor,
    bucket_size: isize,
    max_position: isize,
    device: &Device,
) -> Result<Tensor> {
    let sign = relative_pos.to_dtype(DType::F32)?.sign()?;

    let mid = bucket_size / 2;

    let lt_mid = relative_pos.lt(mid as i64)?;
    let gt_neg_mid = relative_pos.gt(-mid as i64)?;

    let condition = lt_mid
        .to_dtype(candle::DType::F32)?
        .mul(&gt_neg_mid.to_dtype(candle::DType::F32)?)?
        .to_dtype(DType::U8)?;

    let on_true = Tensor::new(&[(mid - 1) as u32], device)?
        .broadcast_as(relative_pos.shape())?
        .to_dtype(relative_pos.dtype())?;

    let on_false = relative_pos
        .to_dtype(DType::F32)?
        .abs()?
        .to_dtype(DType::I64)?;

    let abs_pos = condition.where_cond(&on_true, &on_false)?;

    let mid_as_tensor = Tensor::from_slice(&[mid as f32], (1,), device)?;

    let log_pos = {
        let first_log = abs_pos
            .to_dtype(DType::F32)?
            .broadcast_div(&mid_as_tensor)?
            .log()?;

        let second_log =
            Tensor::from_slice(&[((max_position as f32 - 1.0) / mid as f32)], (1,), device)?
                .log()?;

        let first_div_second = first_log.broadcast_div(&second_log)?;

        let to_ceil = first_div_second
            .broadcast_mul(Tensor::from_slice(&[(mid - 1) as f32], (1,), device)?.as_ref())?;

        let ceil = to_ceil.ceil()?;

        ceil.broadcast_add(&mid_as_tensor)?
    };

    Ok({
        let abs_pos_lte_mid = abs_pos.to_dtype(DType::F32)?.broadcast_le(&mid_as_tensor)?;
        let relative_pos = relative_pos.to_dtype(relative_pos.dtype())?;
        let log_pos_mul_sign = log_pos.broadcast_mul(&sign.to_dtype(DType::F32)?)?;
        abs_pos_lte_mid.where_cond(&relative_pos.to_dtype(DType::F32)?, &log_pos_mul_sign)?
    })
}
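// A small, self-contained sanity check for `build_relative_position`, added
// for illustration; it is not part of the upstream file and only runs on the
// CPU device.
#[cfg(test)]
mod relative_position_tests {
    use super::*;

    #[test]
    fn relative_positions_are_antisymmetric() -> Result<()> {
        let rel = build_relative_position(2, 2, &Device::Cpu, None, None)?;
        assert_eq!(rel.dims3()?, (1, 2, 2));

        let rel = rel.to_vec3::<i64>()?;
        // A token's distance to itself is zero, and swapping query/key flips
        // the sign of the relative position.
        assert_eq!(rel[0][0][0], 0);
        assert_eq!(rel[0][1][1], 0);
        assert_eq!(rel[0][0][1], -rel[0][1][0]);
        Ok(())
    }
}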