Repository: https://github.com/huggingface/candle
Add more to the forward pass.
@@ -131,12 +131,28 @@ impl LayerNorm {
 // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180
 struct BertEmbeddings {
     word_embeddings: Embedding,
-    position_embeddings: Embedding,
+    position_embeddings: Option<Embedding>,
     token_type_embeddings: Embedding,
     layer_norm: LayerNorm,
     dropout: Dropout,
+    position_ids: Tensor,
+    token_type_ids: Tensor,
 }
 
+impl BertEmbeddings {
+    fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
+        let input_embeddings = self.word_embeddings.forward(input_ids)?;
+        let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
+        let mut embeddings = (input_embeddings + token_type_embeddings)?;
+        if let Some(position_embeddings) = &self.position_embeddings {
+            embeddings = (&embeddings + position_embeddings.forward(&embeddings))?
+        }
+        let embeddings = self.layer_norm.forward(&embeddings)?;
+        let embeddings = self.dropout.forward(&embeddings)?;
+        Ok(embeddings)
+    }
+}
+
 struct BertSelfAttention {
     query: Linear,
     key: Linear,
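Note: in the referenced modeling_bert.py, position embeddings are looked up from position ids rather than from the embedding tensor itself, so the `position_embeddings.forward(&embeddings)` call above most likely wants a slice of the stored `position_ids`. A hedged sketch of that variant, not part of this commit; the `dim`/`narrow` calls and the assumed shapes of `input_ids` (batch, seq_len) and `position_ids` (1, max_position_embeddings) are illustrative assumptions:

    if let Some(position_embeddings) = &self.position_embeddings {
        // Take the first `seq_len` position ids and look up their embeddings,
        // mirroring `self.position_embeddings(position_ids)` in the Python reference.
        let seq_len = input_ids.dim(1)?;
        let position_ids = self.position_ids.narrow(1, 0, seq_len)?;
        embeddings = (&embeddings + position_embeddings.forward(&position_ids)?)?
    }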
@@ -144,6 +160,12 @@ struct BertSelfAttention {
     dropout: Dropout,
 }
 
+impl BertSelfAttention {
+    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
+        todo!()
+    }
+}
+
 struct BertSelfOutput {
     dense: Linear,
     layer_norm: LayerNorm,
@@ -165,8 +187,10 @@ struct BertAttention {
 }
 
 impl BertAttention {
-    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
-        todo!()
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let self_outputs = self.self_attention.forward(hidden_states)?;
+        let attention_output = self.self_output.forward(&self_outputs, hidden_states)?;
+        Ok(attention_output)
     }
 }
 
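The two-argument `self_output.forward(&self_outputs, hidden_states)` call follows BertSelfOutput in the referenced modeling_bert.py: a dense projection and dropout on the attention output, then a residual add with the original hidden states followed by layer norm. A hedged sketch of what that forward could look like against the structs in this file (the `dropout` field is an assumption taken from the Python reference, not shown in this diff):

    impl BertSelfOutput {
        fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
            let hidden_states = self.dense.forward(hidden_states)?;
            let hidden_states = self.dropout.forward(&hidden_states)?;
            // Residual connection followed by layer norm, as in the Python reference.
            let hidden_states = (&hidden_states + input_tensor)?;
            self.layer_norm.forward(&hidden_states)
        }
    }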
@@ -206,8 +230,16 @@ struct BertLayer {
 }
 
 impl BertLayer {
-    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
-        todo!()
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let attention_output = self.attention.forward(hidden_states)?;
+        // TODO: Support cross-attention?
+        // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
+        // TODO: Support something similar to `apply_chunking_to_forward`?
+        let intermediate_output = self.intermediate.forward(&attention_output)?;
+        let layer_output = self
+            .output
+            .forward(&intermediate_output, &attention_output)?;
+        Ok(layer_output)
     }
 }
 
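The second TODO refers to transformers' `apply_chunking_to_forward`, which splits the sequence dimension into chunks, runs the feed-forward part chunk by chunk, and concatenates the results to bound peak memory. A hedged sketch of that idea inside this forward (the chunk size and the `dim`/`narrow`/`cat` calls are illustrative assumptions, not part of this commit):

    // Run the feed-forward part over sequence chunks, then stitch the chunks back together.
    let chunk_size = 64; // assumed stand-in for transformers' chunk_size_feed_forward
    let seq_len = attention_output.dim(1)?;
    let mut chunks = Vec::new();
    let mut start = 0;
    while start < seq_len {
        let len = usize::min(chunk_size, seq_len - start);
        let attn_chunk = attention_output.narrow(1, start, len)?;
        let intermediate_chunk = self.intermediate.forward(&attn_chunk)?;
        chunks.push(self.output.forward(&intermediate_chunk, &attn_chunk)?);
        start += len;
    }
    let layer_output = Tensor::cat(&chunks, 1)?;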
@@ -217,8 +249,13 @@ struct BertEncoder {
 }
 
 impl BertEncoder {
-    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
-        todo!()
+    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+        let mut hidden_states = hidden_states.clone();
+        // Use a loop rather than a fold as it's easier to modify when adding debug/...
+        for layer in self.layers.iter() {
+            hidden_states = layer.forward(&hidden_states)?
+        }
+        Ok(hidden_states)
     }
 }
 
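The comment above keeps the explicit loop for debuggability; for reference, a hedged fold-style equivalent (not part of the commit) would be:

    let hidden_states = self
        .layers
        .iter()
        .try_fold(hidden_states.clone(), |hidden_states, layer| layer.forward(&hidden_states))?;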
@@ -229,8 +266,10 @@ struct BertModel {
 }
 
 impl BertModel {
-    fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
-        todo!()
+    fn forward(&self, input_ids: &Tensor, position_ids: &Tensor) -> Result<Tensor> {
+        let embedding_output = self.embeddings.forward(input_ids, position_ids)?;
+        let sequence_output = self.encoder.forward(&embedding_output)?;
+        Ok(sequence_output)
     }
 }
 
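Note: `BertModel::forward` names its second argument `position_ids` but forwards it into the `token_type_ids` slot of `BertEmbeddings::forward`, so a caller effectively supplies token type ids here. A hedged usage sketch under that reading (the token id values, `Device::Cpu`, and `zeros_like` call are illustrative assumptions, not from this commit):

    // Token ids for one short sequence; token type ids are all zero.
    let device = Device::Cpu;
    let input_ids = Tensor::new(&[[101u32, 7592, 2088, 102]], &device)?;
    let token_type_ids = input_ids.zeros_like()?;
    let sequence_output = model.forward(&input_ids, &token_type_ids)?;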