Model creation.

This commit is contained in:
laurent
2023-07-03 14:09:46 +01:00
parent 12ac9e1460
commit f74bddca31

View File

@ -1,9 +1,11 @@
#![allow(dead_code)]
use anyhow::Result as R;
use candle::{Result, Tensor};
use candle::{DType, Device, Result, Tensor};
#[derive(Debug, Clone, PartialEq, Eq)]
const DTYPE: DType = DType::F32;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum HiddenAct {
Gelu,
Relu,
@ -18,7 +20,7 @@ impl HiddenAct {
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PositionEmbeddingType {
Absolute,
}
@ -94,11 +96,13 @@ impl Linear {
}
}
struct Dropout {}
struct Dropout {
pr: f64,
}
impl Dropout {
fn new() -> Self {
Self {}
fn new(pr: f64) -> Self {
Self { pr }
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
@ -140,6 +144,31 @@ struct BertEmbeddings {
}
impl BertEmbeddings {
fn load(device: &Device, config: &Config) -> Result<Self> {
let word_embeddings =
Tensor::zeros((config.vocab_size, config.hidden_size), DTYPE, device)?;
let position_embeddings = Tensor::zeros(
(config.max_position_embeddings, config.hidden_size),
DTYPE,
device,
)?;
let token_type_embeddings =
Tensor::zeros((config.type_vocab_size, config.hidden_size), DTYPE, device)?;
let layer_norm = Tensor::zeros((), DTYPE, device)?;
let position_ids: Vec<_> = (0..config.max_position_embeddings as u32).collect();
let position_ids = Tensor::new(&position_ids[..], device)?.unsqueeze(0)?;
let token_type_ids = position_ids.zeros_like()?;
Ok(Self {
word_embeddings: Embedding::new(word_embeddings),
position_embeddings: Some(Embedding::new(position_embeddings)),
token_type_embeddings: Embedding::new(token_type_embeddings),
layer_norm: LayerNorm::new(layer_norm),
dropout: Dropout::new(config.hidden_dropout_prob),
position_ids,
token_type_ids,
})
}
fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
let input_embeddings = self.word_embeddings.forward(input_ids)?;
let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
@ -163,6 +192,26 @@ struct BertSelfAttention {
}
impl BertSelfAttention {
fn load(device: &Device, config: &Config) -> Result<Self> {
let attention_head_size = config.hidden_size / config.num_attention_heads;
let all_head_size = config.num_attention_heads * attention_head_size;
let dropout = Dropout::new(config.hidden_dropout_prob);
let query = Tensor::zeros((config.hidden_size, all_head_size), DTYPE, device)?;
let query = Linear::new(query);
let value = Tensor::zeros((config.hidden_size, all_head_size), DTYPE, device)?;
let value = Linear::new(value);
let key = Tensor::zeros((config.hidden_size, all_head_size), DTYPE, device)?;
let key = Linear::new(key);
Ok(Self {
query,
key,
value,
dropout,
num_attention_heads: config.num_attention_heads,
attention_head_size,
})
}
fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
let mut new_x_shape = xs.dims().to_vec();
new_x_shape.pop();
@ -199,6 +248,19 @@ struct BertSelfOutput {
}
impl BertSelfOutput {
fn load(device: &Device, config: &Config) -> Result<Self> {
let dense = Tensor::zeros((config.hidden_size, config.hidden_size), DTYPE, device)?;
let dense = Linear::new(dense);
let layer_norm = Tensor::zeros((), DTYPE, device)?;
let layer_norm = LayerNorm::new(layer_norm);
let dropout = Dropout::new(config.hidden_dropout_prob);
Ok(Self {
dense,
layer_norm,
dropout,
})
}
fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
let hidden_states = self.dense.forward(hidden_states)?;
let hidden_states = self.dropout.forward(&hidden_states)?;
@ -213,6 +275,15 @@ struct BertAttention {
}
impl BertAttention {
fn load(device: &Device, config: &Config) -> Result<Self> {
let self_attention = BertSelfAttention::load(device, config)?;
let self_output = BertSelfOutput::load(device, config)?;
Ok(Self {
self_attention,
self_output,
})
}
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let self_outputs = self.self_attention.forward(hidden_states)?;
let attention_output = self.self_output.forward(&self_outputs, hidden_states)?;
@ -227,6 +298,19 @@ struct BertIntermediate {
}
impl BertIntermediate {
fn load(device: &Device, config: &Config) -> Result<Self> {
let dense = Tensor::zeros(
(config.hidden_size, config.intermediate_size),
DTYPE,
device,
)?;
let dense = Linear::new(dense);
Ok(Self {
dense,
intermediate_act: config.hidden_act,
})
}
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let hidden_states = self.dense.forward(hidden_states)?;
self.intermediate_act.forward(&hidden_states)
@ -241,6 +325,23 @@ struct BertOutput {
}
impl BertOutput {
fn load(device: &Device, config: &Config) -> Result<Self> {
let dense = Tensor::zeros(
(config.intermediate_size, config.hidden_size),
DTYPE,
device,
)?;
let dense = Linear::new(dense);
let layer_norm = Tensor::zeros((), DTYPE, device)?;
let layer_norm = LayerNorm::new(layer_norm);
let dropout = Dropout::new(config.hidden_dropout_prob);
Ok(Self {
dense,
layer_norm,
dropout,
})
}
fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
let hidden_states = self.dense.forward(hidden_states)?;
let hidden_states = self.dropout.forward(&hidden_states)?;
@ -256,6 +357,17 @@ struct BertLayer {
}
impl BertLayer {
fn load(device: &Device, config: &Config) -> Result<Self> {
let attention = BertAttention::load(device, config)?;
let intermediate = BertIntermediate::load(device, config)?;
let output = BertOutput::load(device, config)?;
Ok(Self {
attention,
intermediate,
output,
})
}
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let attention_output = self.attention.forward(hidden_states)?;
// TODO: Support cross-attention?
@ -275,6 +387,13 @@ struct BertEncoder {
}
impl BertEncoder {
fn load(device: &Device, config: &Config) -> Result<Self> {
let layers = (0..config.num_hidden_layers)
.map(|_index| BertLayer::load(device, config))
.collect::<Result<Vec<_>>>()?;
Ok(BertEncoder { layers })
}
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let mut hidden_states = hidden_states.clone();
// Use a loop rather than a fold as it's easier to modify when adding debug/...
@ -292,6 +411,15 @@ struct BertModel {
}
impl BertModel {
fn load(device: &Device, config: &Config) -> Result<Self> {
let embeddings = BertEmbeddings::load(device, config)?;
let encoder = BertEncoder::load(device, config)?;
Ok(Self {
embeddings,
encoder,
})
}
fn forward(&self, input_ids: &Tensor, position_ids: &Tensor) -> Result<Tensor> {
let embedding_output = self.embeddings.forward(input_ids, position_ids)?;
let sequence_output = self.encoder.forward(&embedding_output)?;