Add a batch dimension on the bert example.
@@ -75,6 +75,9 @@ enum HiddenAct {
 impl HiddenAct {
     fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
         match self {
+            // TODO: The all-MiniLM-L6-v2 model uses "gelu" whereas this is "gelu_new", this explains some
+            // small numerical difference.
+            // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
             Self::Gelu => xs.gelu(),
             Self::Relu => xs.relu(),
         }
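For reference, the numerical gap flagged in the new comment comes from the two GELU variants. The definitions below are the standard ones, not part of this diff: all-MiniLM-L6-v2 uses the exact form ("gelu"), while the activation here is the tanh approximation ("gelu_new"):

    gelu(x)     = x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
    gelu_new(x) = 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))

The two agree closely for moderate |x|, hence only a small numerical difference in the resulting embeddings.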
@@ -196,7 +199,9 @@ impl Linear {
     }

     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = x.matmul(&self.weight.t()?)?;
+        let (bsize, _, _) = x.shape().r3()?;
+        let w = self.weight.broadcast_left(bsize)?.t()?;
+        let x = x.matmul(&w)?;
         let x = x.broadcast_add(&self.bias)?;
         Ok(x)
     }
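Why the weight now needs `broadcast_left`: a matmul on a rank-3 tensor is a batched matmul, so the rank-2 weight must first be expanded to a matching batch. Below is a minimal shape check, sketched with the same candle calls the diff uses (sizes and names are made up for illustration; `Tensor::zeros` is assumed to be available as in current candle):

    use candle::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let (b, s, h_in, h_out) = (2, 5, 4, 3);
        let x = Tensor::zeros((b, s, h_in), DType::F32, &dev)?; // batched activations
        let weight = Tensor::zeros((h_out, h_in), DType::F32, &dev)?; // stored transposed
        // (h_out, h_in) -> (b, h_out, h_in) -> (b, h_in, h_out)
        let w = weight.broadcast_left(b)?.t()?;
        let y = x.matmul(&w)?; // one matmul per batch element: (b, s, h_out)
        assert_eq!(y.dims(), &[b, s, h_out]);
        Ok(())
    }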
@@ -236,12 +241,11 @@ impl LayerNorm {
     }

     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let (seq_len, hidden_size) = x.shape().r2()?;
-        let mean_x = (x.sum(&[1])? / hidden_size as f64)?;
+        let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
+        let mean_x = (x.sum(&[2])? / hidden_size as f64)?;
         let x = x.broadcast_sub(&mean_x)?;
-        let norm_x = ((&x * &x)?.sum(&[1])? / hidden_size as f64)?;
-        let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
-        let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
+        let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
+        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
         let x = x_normed
             .broadcast_mul(&self.weight)?
             .broadcast_add(&self.bias)?;
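The reduction axis moves from 1 to 2 because the hidden dimension is now the last of three, and `broadcast_div` replaces the previous explicit `broadcast_as` expansion. Per batch element b and position s, this is the standard layer norm with a biased variance (no Bessel correction), matching the code:

    mean[b, s] = (1/H) * sum_h x[b, s, h]
    var[b, s]  = (1/H) * sum_h (x[b, s, h] - mean[b, s])^2
    y[b, s, h] = (x[b, s, h] - mean[b, s]) / sqrt(var[b, s] + eps) * weight[h] + bias[h]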
@@ -301,7 +305,7 @@ impl BertEmbeddings {
     }

     fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
-        let seq_len = input_ids.shape().r1()?;
+        let (_bsize, seq_len) = input_ids.shape().r2()?;
         let input_embeddings = self.word_embeddings.forward(input_ids)?;
         let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
         let mut embeddings = (&input_embeddings + token_type_embeddings)?;
@@ -309,7 +313,7 @@ impl BertEmbeddings {
             // TODO: Proper absolute positions?
             let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
             let position_ids = Tensor::new(&position_ids[..], &input_ids.device())?;
-            embeddings = (&embeddings + position_embeddings.forward(&position_ids)?)?
+            embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)?
         }
         let embeddings = self.layer_norm.forward(&embeddings)?;
         let embeddings = self.dropout.forward(&embeddings)?;
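The position ids remain a rank-1 tensor of length seq_len (now read with `r2` from the batched input), so the position embeddings come out batch-free and `broadcast_add` stretches them across the batch; a plain `+` would fail on the rank mismatch. A small sketch of that broadcast (sizes illustrative):

    use candle::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let (b, s, h) = (2, 4, 3);
        let embeddings = Tensor::zeros((b, s, h), DType::F32, &dev)?;
        let pos = Tensor::zeros((s, h), DType::F32, &dev)?; // no batch dimension
        // Broadcasting right-aligns the shapes: (s, h) behaves as (1, s, h)
        // and is repeated over the batch.
        let out = embeddings.broadcast_add(&pos)?;
        assert_eq!(out.dims(), &[b, s, h]);
        Ok(())
    }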
@@ -351,7 +355,7 @@ impl BertSelfAttention {
         new_x_shape.push(self.num_attention_heads);
         new_x_shape.push(self.attention_head_size);
         // Be cautious about the transposition if adding a batch dim!
-        let xs = xs.reshape(new_x_shape.as_slice())?.transpose(0, 1)?;
+        let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
         Ok(xs.contiguous()?)
     }

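The switch from `transpose(0, 1)` to `transpose(1, 2)` follows directly from the head axis shifting one slot to the right. A runnable trace of the new path (sizes are illustrative):

    use candle::{DType, Device, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let (b, s, n_heads, head_size) = (2, 5, 4, 8);
        let xs = Tensor::zeros((b, s, n_heads * head_size), DType::F32, &dev)?;
        // (b, s, n_heads, head_size) -> (b, n_heads, s, head_size): the same
        // move as the old (s, n_heads, head_size) -> (n_heads, s, head_size),
        // shifted by the new batch dimension.
        let xs = xs.reshape((b, s, n_heads, head_size))?.transpose(1, 2)?;
        assert_eq!(xs.dims(), &[b, n_heads, s, head_size]);
        Ok(())
    }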
@@ -370,7 +374,7 @@ impl BertSelfAttention {
         let attention_probs = self.dropout.forward(&attention_probs)?;

         let context_layer = attention_probs.matmul(&value_layer)?;
-        let context_layer = context_layer.transpose(0, 1)?.contiguous()?;
+        let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
         let context_layer = context_layer.flatten(Some(context_layer.rank() - 2), None)?;
         Ok(context_layer)
     }
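The attention output then takes the inverse trip, with every index likewise shifted by one (a shape trace, not code from the diff):

    // attention_probs.matmul(&value_layer): (b, n_heads, s, head_size)
    // transpose(1, 2):                      (b, s, n_heads, head_size)
    // flatten over the last two dims:       (b, s, n_heads * head_size)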
@@ -616,7 +620,7 @@ fn main() -> Result<()> {
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
-    let token_ids = Tensor::new(&tokens[..], &device)?;
+    let token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
     println!("{token_ids}");
     let token_type_ids = token_ids.zeros_like()?;
     let ys = model.forward(&token_ids, &token_type_ids)?;
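Finally, `unsqueeze(0)` turns the rank-1 token tensor into a batch of one, which is what the batched model now expects. A self-contained sketch (the token ids are made up for illustration):

    use candle::{Device, Result, Tensor};

    fn main() -> Result<()> {
        let device = Device::Cpu;
        let tokens: Vec<u32> = vec![101, 7592, 2088, 102]; // hypothetical ids
        let token_ids = Tensor::new(&tokens[..], &device)?; // shape (4,)
        let token_ids = token_ids.unsqueeze(0)?; // shape (1, 4): a batch of one
        assert_eq!(token_ids.dims(), &[1, 4]);
        Ok(())
    }

Since `token_type_ids` is built with `zeros_like`, it picks up the same (1, seq_len) shape automatically, so the rest of the pipeline sees the batch dimension for free.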