mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Model creation.
This commit is contained in:
@ -1,9 +1,11 @@
|
||||
#![allow(dead_code)]
|
||||
|
||||
use anyhow::Result as R;
|
||||
use candle::{Result, Tensor};
|
||||
use candle::{DType, Device, Result, Tensor};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
const DTYPE: DType = DType::F32;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum HiddenAct {
|
||||
Gelu,
|
||||
Relu,
|
||||
@ -18,7 +20,7 @@ impl HiddenAct {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum PositionEmbeddingType {
|
||||
Absolute,
|
||||
}
|
||||
@ -94,11 +96,13 @@ impl Linear {
|
||||
}
|
||||
}
|
||||
|
||||
struct Dropout {}
|
||||
struct Dropout {
|
||||
pr: f64,
|
||||
}
|
||||
|
||||
impl Dropout {
|
||||
fn new() -> Self {
|
||||
Self {}
|
||||
fn new(pr: f64) -> Self {
|
||||
Self { pr }
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
@ -140,6 +144,31 @@ struct BertEmbeddings {
|
||||
}
|
||||
|
||||
impl BertEmbeddings {
|
||||
fn load(device: &Device, config: &Config) -> Result<Self> {
|
||||
let word_embeddings =
|
||||
Tensor::zeros((config.vocab_size, config.hidden_size), DTYPE, device)?;
|
||||
let position_embeddings = Tensor::zeros(
|
||||
(config.max_position_embeddings, config.hidden_size),
|
||||
DTYPE,
|
||||
device,
|
||||
)?;
|
||||
let token_type_embeddings =
|
||||
Tensor::zeros((config.type_vocab_size, config.hidden_size), DTYPE, device)?;
|
||||
let layer_norm = Tensor::zeros((), DTYPE, device)?;
|
||||
let position_ids: Vec<_> = (0..config.max_position_embeddings as u32).collect();
|
||||
let position_ids = Tensor::new(&position_ids[..], device)?.unsqueeze(0)?;
|
||||
let token_type_ids = position_ids.zeros_like()?;
|
||||
Ok(Self {
|
||||
word_embeddings: Embedding::new(word_embeddings),
|
||||
position_embeddings: Some(Embedding::new(position_embeddings)),
|
||||
token_type_embeddings: Embedding::new(token_type_embeddings),
|
||||
layer_norm: LayerNorm::new(layer_norm),
|
||||
dropout: Dropout::new(config.hidden_dropout_prob),
|
||||
position_ids,
|
||||
token_type_ids,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
|
||||
let input_embeddings = self.word_embeddings.forward(input_ids)?;
|
||||
let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
|
||||
@ -163,6 +192,26 @@ struct BertSelfAttention {
|
||||
}
|
||||
|
||||
impl BertSelfAttention {
|
||||
fn load(device: &Device, config: &Config) -> Result<Self> {
|
||||
let attention_head_size = config.hidden_size / config.num_attention_heads;
|
||||
let all_head_size = config.num_attention_heads * attention_head_size;
|
||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||
let query = Tensor::zeros((config.hidden_size, all_head_size), DTYPE, device)?;
|
||||
let query = Linear::new(query);
|
||||
let value = Tensor::zeros((config.hidden_size, all_head_size), DTYPE, device)?;
|
||||
let value = Linear::new(value);
|
||||
let key = Tensor::zeros((config.hidden_size, all_head_size), DTYPE, device)?;
|
||||
let key = Linear::new(key);
|
||||
Ok(Self {
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
dropout,
|
||||
num_attention_heads: config.num_attention_heads,
|
||||
attention_head_size,
|
||||
})
|
||||
}
|
||||
|
||||
fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut new_x_shape = xs.dims().to_vec();
|
||||
new_x_shape.pop();
|
||||
@ -199,6 +248,19 @@ struct BertSelfOutput {
|
||||
}
|
||||
|
||||
impl BertSelfOutput {
|
||||
fn load(device: &Device, config: &Config) -> Result<Self> {
|
||||
let dense = Tensor::zeros((config.hidden_size, config.hidden_size), DTYPE, device)?;
|
||||
let dense = Linear::new(dense);
|
||||
let layer_norm = Tensor::zeros((), DTYPE, device)?;
|
||||
let layer_norm = LayerNorm::new(layer_norm);
|
||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||
Ok(Self {
|
||||
dense,
|
||||
layer_norm,
|
||||
dropout,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
|
||||
let hidden_states = self.dense.forward(hidden_states)?;
|
||||
let hidden_states = self.dropout.forward(&hidden_states)?;
|
||||
@ -213,6 +275,15 @@ struct BertAttention {
|
||||
}
|
||||
|
||||
impl BertAttention {
|
||||
fn load(device: &Device, config: &Config) -> Result<Self> {
|
||||
let self_attention = BertSelfAttention::load(device, config)?;
|
||||
let self_output = BertSelfOutput::load(device, config)?;
|
||||
Ok(Self {
|
||||
self_attention,
|
||||
self_output,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
||||
let self_outputs = self.self_attention.forward(hidden_states)?;
|
||||
let attention_output = self.self_output.forward(&self_outputs, hidden_states)?;
|
||||
@ -227,6 +298,19 @@ struct BertIntermediate {
|
||||
}
|
||||
|
||||
impl BertIntermediate {
|
||||
fn load(device: &Device, config: &Config) -> Result<Self> {
|
||||
let dense = Tensor::zeros(
|
||||
(config.hidden_size, config.intermediate_size),
|
||||
DTYPE,
|
||||
device,
|
||||
)?;
|
||||
let dense = Linear::new(dense);
|
||||
Ok(Self {
|
||||
dense,
|
||||
intermediate_act: config.hidden_act,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
||||
let hidden_states = self.dense.forward(hidden_states)?;
|
||||
self.intermediate_act.forward(&hidden_states)
|
||||
@ -241,6 +325,23 @@ struct BertOutput {
|
||||
}
|
||||
|
||||
impl BertOutput {
|
||||
fn load(device: &Device, config: &Config) -> Result<Self> {
|
||||
let dense = Tensor::zeros(
|
||||
(config.intermediate_size, config.hidden_size),
|
||||
DTYPE,
|
||||
device,
|
||||
)?;
|
||||
let dense = Linear::new(dense);
|
||||
let layer_norm = Tensor::zeros((), DTYPE, device)?;
|
||||
let layer_norm = LayerNorm::new(layer_norm);
|
||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||
Ok(Self {
|
||||
dense,
|
||||
layer_norm,
|
||||
dropout,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
|
||||
let hidden_states = self.dense.forward(hidden_states)?;
|
||||
let hidden_states = self.dropout.forward(&hidden_states)?;
|
||||
@ -256,6 +357,17 @@ struct BertLayer {
|
||||
}
|
||||
|
||||
impl BertLayer {
|
||||
fn load(device: &Device, config: &Config) -> Result<Self> {
|
||||
let attention = BertAttention::load(device, config)?;
|
||||
let intermediate = BertIntermediate::load(device, config)?;
|
||||
let output = BertOutput::load(device, config)?;
|
||||
Ok(Self {
|
||||
attention,
|
||||
intermediate,
|
||||
output,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
||||
let attention_output = self.attention.forward(hidden_states)?;
|
||||
// TODO: Support cross-attention?
|
||||
@ -275,6 +387,13 @@ struct BertEncoder {
|
||||
}
|
||||
|
||||
impl BertEncoder {
|
||||
fn load(device: &Device, config: &Config) -> Result<Self> {
|
||||
let layers = (0..config.num_hidden_layers)
|
||||
.map(|_index| BertLayer::load(device, config))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
Ok(BertEncoder { layers })
|
||||
}
|
||||
|
||||
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
||||
let mut hidden_states = hidden_states.clone();
|
||||
// Use a loop rather than a fold as it's easier to modify when adding debug/...
|
||||
@ -292,6 +411,15 @@ struct BertModel {
|
||||
}
|
||||
|
||||
impl BertModel {
|
||||
fn load(device: &Device, config: &Config) -> Result<Self> {
|
||||
let embeddings = BertEmbeddings::load(device, config)?;
|
||||
let encoder = BertEncoder::load(device, config)?;
|
||||
Ok(Self {
|
||||
embeddings,
|
||||
encoder,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, input_ids: &Tensor, position_ids: &Tensor) -> Result<Tensor> {
|
||||
let embedding_output = self.embeddings.forward(input_ids, position_ids)?;
|
||||
let sequence_output = self.encoder.forward(&embedding_output)?;
|
||||
|
Reference in New Issue
Block a user