From f74bddca31a582ad82ad8529c51b823301064847 Mon Sep 17 00:00:00 2001
From: laurent
Date: Mon, 3 Jul 2023 14:09:46 +0100
Subject: [PATCH] Model creation.

---
 candle-examples/examples/bert/main.rs | 140 ++++++++++++++++++++++++--
 1 file changed, 134 insertions(+), 6 deletions(-)

diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index d3b0e45f..52685c91 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -1,9 +1,11 @@
 #![allow(dead_code)]
 
 use anyhow::Result as R;
-use candle::{Result, Tensor};
+use candle::{DType, Device, Result, Tensor};
 
-#[derive(Debug, Clone, PartialEq, Eq)]
+const DTYPE: DType = DType::F32;
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
 enum HiddenAct {
     Gelu,
     Relu,
@@ -18,7 +20,7 @@ impl HiddenAct {
     }
 }
 
-#[derive(Debug, Clone, PartialEq, Eq)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
 enum PositionEmbeddingType {
     Absolute,
 }
@@ -94,11 +96,13 @@ impl Linear {
     }
 }
 
-struct Dropout {}
+struct Dropout {
+    pr: f64,
+}
 
 impl Dropout {
-    fn new() -> Self {
-        Self {}
+    fn new(pr: f64) -> Self {
+        Self { pr }
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
@@ -140,6 +144,31 @@ struct BertEmbeddings {
 }
 
 impl BertEmbeddings {
+    fn load(device: &Device, config: &Config) -> Result<Self> {
+        let word_embeddings =
+            Tensor::zeros((config.vocab_size, config.hidden_size), DTYPE, device)?;
+        let position_embeddings = Tensor::zeros(
+            (config.max_position_embeddings, config.hidden_size),
+            DTYPE,
+            device,
+        )?;
+        let token_type_embeddings =
+            Tensor::zeros((config.type_vocab_size, config.hidden_size), DTYPE, device)?;
+        let layer_norm = Tensor::zeros((), DTYPE, device)?;
+        let position_ids: Vec<_> = (0..config.max_position_embeddings as u32).collect();
+        let position_ids = Tensor::new(&position_ids[..], device)?.unsqueeze(0)?;
+        let token_type_ids = position_ids.zeros_like()?;
+        Ok(Self {
+            word_embeddings: Embedding::new(word_embeddings),
+            position_embeddings: Some(Embedding::new(position_embeddings)),
+            token_type_embeddings: Embedding::new(token_type_embeddings),
+            layer_norm: LayerNorm::new(layer_norm),
+            dropout: Dropout::new(config.hidden_dropout_prob),
+            position_ids,
+            token_type_ids,
+        })
+    }
+
     fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
         let input_embeddings = self.word_embeddings.forward(input_ids)?;
         let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
@@ -163,6 +192,26 @@ struct BertSelfAttention {
 }
 
 impl BertSelfAttention {
+    fn load(device: &Device, config: &Config) -> Result<Self> {
+        let attention_head_size = config.hidden_size / config.num_attention_heads;
+        let all_head_size = config.num_attention_heads * attention_head_size;
+        let dropout = Dropout::new(config.hidden_dropout_prob);
+        let query = Tensor::zeros((config.hidden_size, all_head_size), DTYPE, device)?;
+        let query = Linear::new(query);
+        let value = Tensor::zeros((config.hidden_size, all_head_size), DTYPE, device)?;
+        let value = Linear::new(value);
+        let key = Tensor::zeros((config.hidden_size, all_head_size), DTYPE, device)?;
+        let key = Linear::new(key);
+        Ok(Self {
+            query,
+            key,
+            value,
+            dropout,
+            num_attention_heads: config.num_attention_heads,
+            attention_head_size,
+        })
+    }
+
     fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
         let mut new_x_shape = xs.dims().to_vec();
         new_x_shape.pop();
@@ -199,6 +248,19 @@ struct BertSelfOutput {
 }
 
 impl BertSelfOutput {
+    fn load(device: &Device, config: &Config) -> Result<Self> {
+        let dense = Tensor::zeros((config.hidden_size, config.hidden_size), DTYPE, device)?;
+        let dense = Linear::new(dense);
+        let layer_norm = Tensor::zeros((), DTYPE, device)?;
+        let layer_norm = LayerNorm::new(layer_norm);
+        let dropout = Dropout::new(config.hidden_dropout_prob);
+        Ok(Self {
+            dense,
+            layer_norm,
+            dropout,
+        })
+    }
+
     fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
         let hidden_states = self.dense.forward(hidden_states)?;
         let hidden_states = self.dropout.forward(&hidden_states)?;
@@ -213,6 +275,15 @@ struct BertAttention {
 }
 
 impl BertAttention {
+    fn load(device: &Device, config: &Config) -> Result<Self> {
+        let self_attention = BertSelfAttention::load(device, config)?;
+        let self_output = BertSelfOutput::load(device, config)?;
+        Ok(Self {
+            self_attention,
+            self_output,
+        })
+    }
+
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let self_outputs = self.self_attention.forward(hidden_states)?;
         let attention_output = self.self_output.forward(&self_outputs, hidden_states)?;
@@ -227,6 +298,19 @@ struct BertIntermediate {
 }
 
 impl BertIntermediate {
+    fn load(device: &Device, config: &Config) -> Result<Self> {
+        let dense = Tensor::zeros(
+            (config.hidden_size, config.intermediate_size),
+            DTYPE,
+            device,
+        )?;
+        let dense = Linear::new(dense);
+        Ok(Self {
+            dense,
+            intermediate_act: config.hidden_act,
+        })
+    }
+
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let hidden_states = self.dense.forward(hidden_states)?;
         self.intermediate_act.forward(&hidden_states)
@@ -241,6 +325,23 @@ struct BertOutput {
 }
 
 impl BertOutput {
+    fn load(device: &Device, config: &Config) -> Result<Self> {
+        let dense = Tensor::zeros(
+            (config.intermediate_size, config.hidden_size),
+            DTYPE,
+            device,
+        )?;
+        let dense = Linear::new(dense);
+        let layer_norm = Tensor::zeros((), DTYPE, device)?;
+        let layer_norm = LayerNorm::new(layer_norm);
+        let dropout = Dropout::new(config.hidden_dropout_prob);
+        Ok(Self {
+            dense,
+            layer_norm,
+            dropout,
+        })
+    }
+
     fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
         let hidden_states = self.dense.forward(hidden_states)?;
         let hidden_states = self.dropout.forward(&hidden_states)?;
@@ -256,6 +357,17 @@ struct BertLayer {
 }
 
 impl BertLayer {
+    fn load(device: &Device, config: &Config) -> Result<Self> {
+        let attention = BertAttention::load(device, config)?;
+        let intermediate = BertIntermediate::load(device, config)?;
+        let output = BertOutput::load(device, config)?;
+        Ok(Self {
+            attention,
+            intermediate,
+            output,
+        })
+    }
+
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let attention_output = self.attention.forward(hidden_states)?;
         // TODO: Support cross-attention?
@@ -275,6 +387,13 @@ struct BertEncoder {
 }
 
 impl BertEncoder {
+    fn load(device: &Device, config: &Config) -> Result<Self> {
+        let layers = (0..config.num_hidden_layers)
+            .map(|_index| BertLayer::load(device, config))
+            .collect::<Result<Vec<_>>>()?;
+        Ok(BertEncoder { layers })
+    }
+
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let mut hidden_states = hidden_states.clone();
         // Use a loop rather than a fold as it's easier to modify when adding debug/...
@@ -292,6 +411,15 @@ struct BertModel {
 }
 
 impl BertModel {
+    fn load(device: &Device, config: &Config) -> Result<Self> {
+        let embeddings = BertEmbeddings::load(device, config)?;
+        let encoder = BertEncoder::load(device, config)?;
+        Ok(Self {
+            embeddings,
+            encoder,
+        })
+    }
+
     fn forward(&self, input_ids: &Tensor, position_ids: &Tensor) -> Result<Tensor> {
         let embedding_output = self.embeddings.forward(input_ids, position_ids)?;
         let sequence_output = self.encoder.forward(&embedding_output)?;
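
Note on usage (not part of the commit): the sketch below shows how the new `load` constructors might be exercised end to end, assuming a `Config` value is available. `Config::default()` is hypothetical here, as the example's `Config` is not shown to implement `Default`, and since every weight is a zero-valued placeholder at this stage the run is only a plumbing check, not meaningful inference.

    use anyhow::Result as R;
    use candle::{DType, Device, Tensor};

    fn main() -> R<()> {
        let device = Device::Cpu;
        // Hypothetical: the real example would presumably build the config
        // from a checkpoint's config.json rather than from defaults.
        let config = Config::default();
        let model = BertModel::load(&device, &config)?;
        // A toy batch of seven token ids; zeros stand in for tokenizer output.
        let input_ids = Tensor::zeros((1, 7), DType::U32, &device)?;
        // Note: BertModel::forward names its second argument `position_ids`,
        // but it is forwarded to BertEmbeddings::forward as `token_type_ids`.
        let token_type_ids = input_ids.zeros_like()?;
        let output = model.forward(&input_ids, &token_type_ids)?;
        println!("output shape: {:?}", output.dims());
        Ok(())
    }

The all-zero `Tensor::zeros` weights and the scalar `layer_norm` tensor read as placeholders for model creation only; a follow-up change presumably replaces them with tensors deserialized from real checkpoint files.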