Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 11:08:52 +00:00)
Commit: Model creation. The diff below adds `load` constructors throughout the BERT port, so every module can be built with zero-initialized placeholder weights on a given `Device`.
```diff
@@ -1,9 +1,11 @@
 #![allow(dead_code)]
 
 use anyhow::Result as R;
-use candle::{Result, Tensor};
+use candle::{DType, Device, Result, Tensor};
 
-#[derive(Debug, Clone, PartialEq, Eq)]
+const DTYPE: DType = DType::F32;
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
 enum HiddenAct {
     Gelu,
     Relu,
```
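The new `Copy` derive lets `HiddenAct` be stored and passed by value (it is copied into `BertIntermediate` below as `config.hidden_act`). For context, a minimal sketch of the `impl HiddenAct` block that the next hunk's header points at, assuming each variant dispatches to the matching candle tensor op:

```rust
// Sketch only: the real body lives outside this diff. Assumes candle's
// built-in gelu/relu tensor ops back the two variants.
impl HiddenAct {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            HiddenAct::Gelu => xs.gelu(),
            HiddenAct::Relu => xs.relu(),
        }
    }
}
```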
```diff
@@ -18,7 +20,7 @@ impl HiddenAct {
     }
 }
 
-#[derive(Debug, Clone, PartialEq, Eq)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
 enum PositionEmbeddingType {
     Absolute,
 }
```
```diff
@@ -94,11 +96,13 @@ impl Linear {
     }
 }
 
-struct Dropout {}
+struct Dropout {
+    pr: f64,
+}
 
 impl Dropout {
-    fn new() -> Self {
-        Self {}
+    fn new(pr: f64) -> Self {
+        Self { pr }
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
```
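`Dropout` now carries its probability so it can be built from `config.hidden_dropout_prob`. Since this code path only runs inference, `forward` can stay the identity; a sketch under that assumption (the actual body is truncated out of the hunk):

```rust
// Sketch: inference-time dropout is a no-op; `pr` is stored only to mirror
// the Python configuration. The real body is not shown in this diff.
fn forward(&self, x: &Tensor) -> Result<Tensor> {
    Ok(x.clone())
}
```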
```diff
@@ -140,6 +144,31 @@ struct BertEmbeddings {
 }
 
 impl BertEmbeddings {
+    fn load(device: &Device, config: &Config) -> Result<Self> {
+        let word_embeddings =
+            Tensor::zeros((config.vocab_size, config.hidden_size), DTYPE, device)?;
+        let position_embeddings = Tensor::zeros(
+            (config.max_position_embeddings, config.hidden_size),
+            DTYPE,
+            device,
+        )?;
+        let token_type_embeddings =
+            Tensor::zeros((config.type_vocab_size, config.hidden_size), DTYPE, device)?;
+        let layer_norm = Tensor::zeros((), DTYPE, device)?;
+        let position_ids: Vec<_> = (0..config.max_position_embeddings as u32).collect();
+        let position_ids = Tensor::new(&position_ids[..], device)?.unsqueeze(0)?;
+        let token_type_ids = position_ids.zeros_like()?;
+        Ok(Self {
+            word_embeddings: Embedding::new(word_embeddings),
+            position_embeddings: Some(Embedding::new(position_embeddings)),
+            token_type_embeddings: Embedding::new(token_type_embeddings),
+            layer_norm: LayerNorm::new(layer_norm),
+            dropout: Dropout::new(config.hidden_dropout_prob),
+            position_ids,
+            token_type_ids,
+        })
+    }
+
     fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
         let input_embeddings = self.word_embeddings.forward(input_ids)?;
         let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
```
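The hunk cuts `forward` off after the two lookups. In the reference BERT implementation the three embeddings are summed, then layer-normalized and passed through dropout; a hedged sketch of that continuation, reusing the field names from `load`:

```rust
// Hypothetical continuation of BertEmbeddings::forward (not in this diff):
// sum word, token-type and position embeddings, then normalize.
let mut embeddings = (input_embeddings + token_type_embeddings)?;
if let Some(position_embeddings) = &self.position_embeddings {
    let pos = position_embeddings.forward(&self.position_ids)?;
    embeddings = embeddings.broadcast_add(&pos)?;
}
let embeddings = self.layer_norm.forward(&embeddings)?;
self.dropout.forward(&embeddings)
```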
```diff
@@ -163,6 +192,26 @@ struct BertSelfAttention {
 }
 
 impl BertSelfAttention {
+    fn load(device: &Device, config: &Config) -> Result<Self> {
+        let attention_head_size = config.hidden_size / config.num_attention_heads;
+        let all_head_size = config.num_attention_heads * attention_head_size;
+        let dropout = Dropout::new(config.hidden_dropout_prob);
+        let query = Tensor::zeros((config.hidden_size, all_head_size), DTYPE, device)?;
+        let query = Linear::new(query);
+        let value = Tensor::zeros((config.hidden_size, all_head_size), DTYPE, device)?;
+        let value = Linear::new(value);
+        let key = Tensor::zeros((config.hidden_size, all_head_size), DTYPE, device)?;
+        let key = Linear::new(key);
+        Ok(Self {
+            query,
+            key,
+            value,
+            dropout,
+            num_attention_heads: config.num_attention_heads,
+            attention_head_size,
+        })
+    }
+
     fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
         let mut new_x_shape = xs.dims().to_vec();
         new_x_shape.pop();
```
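`transpose_for_scores` is truncated mid-body. The standard head split reshapes `(batch, seq, all_head_size)` into `(batch, heads, seq, head_size)`; a sketch of the presumed full function:

```rust
// Sketch of the complete head split, assuming the usual BERT layout.
fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
    let mut new_x_shape = xs.dims().to_vec();
    new_x_shape.pop();
    new_x_shape.push(self.num_attention_heads);
    new_x_shape.push(self.attention_head_size);
    // (batch, seq, all_head_size) -> (batch, heads, seq, head_size)
    xs.reshape(new_x_shape)?.transpose(1, 2)
}
```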
```diff
@@ -199,6 +248,19 @@ struct BertSelfOutput {
 }
 
 impl BertSelfOutput {
+    fn load(device: &Device, config: &Config) -> Result<Self> {
+        let dense = Tensor::zeros((config.hidden_size, config.hidden_size), DTYPE, device)?;
+        let dense = Linear::new(dense);
+        let layer_norm = Tensor::zeros((), DTYPE, device)?;
+        let layer_norm = LayerNorm::new(layer_norm);
+        let dropout = Dropout::new(config.hidden_dropout_prob);
+        Ok(Self {
+            dense,
+            layer_norm,
+            dropout,
+        })
+    }
+
     fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
         let hidden_states = self.dense.forward(hidden_states)?;
         let hidden_states = self.dropout.forward(&hidden_states)?;
```
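The hunk ends just before the residual step; BERT's self-output adds the block input back before normalizing, so the presumed last line of `forward` is:

```rust
// Presumed tail of BertSelfOutput::forward: residual add, then layer norm.
self.layer_norm.forward(&(hidden_states + input_tensor)?)
```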
```diff
@@ -213,6 +275,15 @@ struct BertAttention {
 }
 
 impl BertAttention {
+    fn load(device: &Device, config: &Config) -> Result<Self> {
+        let self_attention = BertSelfAttention::load(device, config)?;
+        let self_output = BertSelfOutput::load(device, config)?;
+        Ok(Self {
+            self_attention,
+            self_output,
+        })
+    }
+
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let self_outputs = self.self_attention.forward(hidden_states)?;
         let attention_output = self.self_output.forward(&self_outputs, hidden_states)?;
```
```diff
@@ -227,6 +298,19 @@ struct BertIntermediate {
 }
 
 impl BertIntermediate {
+    fn load(device: &Device, config: &Config) -> Result<Self> {
+        let dense = Tensor::zeros(
+            (config.hidden_size, config.intermediate_size),
+            DTYPE,
+            device,
+        )?;
+        let dense = Linear::new(dense);
+        Ok(Self {
+            dense,
+            intermediate_act: config.hidden_act,
+        })
+    }
+
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let hidden_states = self.dense.forward(hidden_states)?;
         self.intermediate_act.forward(&hidden_states)
```
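Note the placeholder shape `(config.hidden_size, config.intermediate_size)`: weights are laid out `(in_features, out_features)`, the transpose of PyTorch's `nn.Linear` convention. That implies a `Linear` along these lines (a sketch; the real `impl Linear` sits earlier in the file and is not part of this diff, and the `weight` field name is assumed):

```rust
// Hypothetical Linear matching the (in, out) weight layout used above.
struct Linear {
    weight: Tensor,
}

impl Linear {
    fn new(weight: Tensor) -> Self {
        Self { weight }
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // (batch, seq, in) @ (in, out) -> (batch, seq, out)
        x.broadcast_matmul(&self.weight)
    }
}
```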
```diff
@@ -241,6 +325,23 @@ struct BertOutput {
 }
 
 impl BertOutput {
+    fn load(device: &Device, config: &Config) -> Result<Self> {
+        let dense = Tensor::zeros(
+            (config.intermediate_size, config.hidden_size),
+            DTYPE,
+            device,
+        )?;
+        let dense = Linear::new(dense);
+        let layer_norm = Tensor::zeros((), DTYPE, device)?;
+        let layer_norm = LayerNorm::new(layer_norm);
+        let dropout = Dropout::new(config.hidden_dropout_prob);
+        Ok(Self {
+            dense,
+            layer_norm,
+            dropout,
+        })
+    }
+
     fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
         let hidden_states = self.dense.forward(hidden_states)?;
         let hidden_states = self.dropout.forward(&hidden_states)?;
```
```diff
@@ -256,6 +357,17 @@ struct BertLayer {
 }
 
 impl BertLayer {
+    fn load(device: &Device, config: &Config) -> Result<Self> {
+        let attention = BertAttention::load(device, config)?;
+        let intermediate = BertIntermediate::load(device, config)?;
+        let output = BertOutput::load(device, config)?;
+        Ok(Self {
+            attention,
+            intermediate,
+            output,
+        })
+    }
+
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let attention_output = self.attention.forward(hidden_states)?;
         // TODO: Support cross-attention?
```
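Past the TODO, `forward` presumably finishes with the feed-forward half of the block:

```rust
// Presumed tail of BertLayer::forward: feed-forward, then BertOutput's
// residual-and-norm over the attention output.
let intermediate_output = self.intermediate.forward(&attention_output)?;
self.output.forward(&intermediate_output, &attention_output)
```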
```diff
@@ -275,6 +387,13 @@ struct BertEncoder {
 }
 
 impl BertEncoder {
+    fn load(device: &Device, config: &Config) -> Result<Self> {
+        let layers = (0..config.num_hidden_layers)
+            .map(|_index| BertLayer::load(device, config))
+            .collect::<Result<Vec<_>>>()?;
+        Ok(BertEncoder { layers })
+    }
+
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let mut hidden_states = hidden_states.clone();
         // Use a loop rather than a fold as it's easier to modify when adding debug/...
```
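The loop announced by that comment is cut off by the hunk; it likely just threads the hidden states through each layer in order:

```rust
// Presumed body of the encoder loop (truncated out of this hunk).
for layer in self.layers.iter() {
    hidden_states = layer.forward(&hidden_states)?;
}
Ok(hidden_states)
```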
```diff
@@ -292,6 +411,15 @@ struct BertModel {
 }
 
 impl BertModel {
+    fn load(device: &Device, config: &Config) -> Result<Self> {
+        let embeddings = BertEmbeddings::load(device, config)?;
+        let encoder = BertEncoder::load(device, config)?;
+        Ok(Self {
+            embeddings,
+            encoder,
+        })
+    }
+
     fn forward(&self, input_ids: &Tensor, position_ids: &Tensor) -> Result<Tensor> {
         let embedding_output = self.embeddings.forward(input_ids, position_ids)?;
         let sequence_output = self.encoder.forward(&embedding_output)?;
```
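With the `load` constructors in place, the whole model can be instantiated end to end with zero-filled placeholder weights. A hypothetical smoke test (note that `Config::default()` is assumed for illustration only, and that `BertModel::forward`'s second argument is named `position_ids` here even though it is forwarded into `BertEmbeddings::forward`'s `token_type_ids` slot):

```rust
// Hypothetical usage sketch: build the zero-initialized model on CPU and
// run one forward pass. Config::default() is an assumed constructor.
fn main() -> R<()> {
    let device = Device::Cpu;
    let config = Config::default();
    let model = BertModel::load(&device, &config)?;
    let input_ids = Tensor::zeros((1, 8), DType::U32, &device)?;
    let token_type_ids = input_ids.zeros_like()?;
    let output = model.forward(&input_ids, &token_type_ids)?;
    println!("output shape: {:?}", output.dims());
    Ok(())
}
```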