diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index 58f42f38..bfa4e69a 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -1,7 +1,9 @@
 #![allow(dead_code)]
+// The tokenizer.json and weights should be retrieved from:
+// https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
 
-use anyhow::Error as E;
-use candle::{safetensors::SafeTensors, DType, Device, Result, Shape, Tensor};
+use anyhow::{Error as E, Result};
+use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
 use clap::Parser;
 use std::collections::HashMap;
 
@@ -40,7 +42,7 @@ impl<'a> VarBuilder<'a> {
         }
     }
 
-    pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> Result<Tensor> {
+    pub fn get<S: Into<Shape>>(&self, s: S, tensor_name: &str) -> candle::Result<Tensor> {
         let s: Shape = s.into();
         match &self.safetensors {
             None => Tensor::zeros(s, self.dtype, &self.device),
@@ -71,7 +73,7 @@ enum HiddenAct {
 }
 
 impl HiddenAct {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+    fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
         match self {
             Self::Gelu => xs.gelu(),
             Self::Relu => xs.relu(),
@@ -164,7 +166,8 @@ impl Embedding {
     }
 
     fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
-        Tensor::embedding(indexes, &self.embeddings)
+        let values = Tensor::embedding(indexes, &self.embeddings)?;
+        Ok(values)
     }
 }
 
@@ -281,11 +284,15 @@ impl BertEmbeddings {
     }
 
     fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
+        let seq_len = input_ids.shape().r1()?;
         let input_embeddings = self.word_embeddings.forward(input_ids)?;
         let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
         let mut embeddings = (input_embeddings + token_type_embeddings)?;
         if let Some(position_embeddings) = &self.position_embeddings {
-            embeddings = (&embeddings + position_embeddings.forward(&embeddings))?
+            // TODO: Proper absolute positions?
+            let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
+            let position_ids = Tensor::new(&position_ids[..], &input_ids.device())?;
+            embeddings = (&embeddings + position_embeddings.forward(&position_ids)?)?
         }
         let embeddings = self.layer_norm.forward(&embeddings)?;
         let embeddings = self.dropout.forward(&embeddings)?;
@@ -326,7 +333,8 @@ impl BertSelfAttention {
         new_x_shape.pop();
         new_x_shape.push(self.num_attention_heads);
         new_x_shape.push(self.attention_head_size);
-        xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)
+        let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
+        Ok(xs)
     }
 
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
@@ -425,7 +433,8 @@ impl BertIntermediate {
 
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let hidden_states = self.dense.forward(hidden_states)?;
-        self.intermediate_act.forward(&hidden_states)
+        let ys = self.intermediate_act.forward(&hidden_states)?;
+        Ok(ys)
     }
 }
 
@@ -534,8 +543,8 @@ impl BertModel {
         })
     }
 
-    fn forward(&self, input_ids: &Tensor, position_ids: &Tensor) -> Result<Tensor> {
-        let embedding_output = self.embeddings.forward(input_ids, position_ids)?;
+    fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
+        let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?;
         let sequence_output = self.encoder.forward(&embedding_output)?;
         Ok(sequence_output)
     }
@@ -555,7 +564,7 @@ struct Args {
     weights: String,
 }
 
-fn main() -> anyhow::Result<()> {
+fn main() -> Result<()> {
     use tokenizers::Tokenizer;
     let args = Args::parse();
 
@@ -579,9 +588,9 @@ fn main() -> anyhow::Result<()> {
         .get_ids()
         .to_vec();
     let token_ids = Tensor::new(&tokens[..], &device)?;
-    let position_ids: Vec<_> = (0..tokens.len() as u32).collect();
-    let position_ids = Tensor::new(&position_ids[..], &device)?.unsqueeze(0)?;
-    let ys = model.forward(&token_ids, &position_ids)?;
+    println!("{token_ids}");
+    let token_type_ids = token_ids.zeros_like()?;
+    let ys = model.forward(&token_ids, &token_type_ids)?;
    println!("{ys}");
     Ok(())
 }
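
A note on the recurring `let v = op()?; Ok(v)` rewrites above: once the file imports `anyhow::Result`, the unqualified `Result<Tensor>` returned by methods like `Embedding::forward` and `BertIntermediate::forward` is anyhow's, while candle ops still return `candle::Result<Tensor>`, so returning the op's result directly no longer type-checks; the `?` operator bridges the two error types via `From`. A minimal standalone sketch of the same pattern, using a hypothetical `OpError` in place of `candle::Error` (only the `anyhow` crate is assumed):

use anyhow::Result;

// Hypothetical stand-in for candle::Error, for illustration only.
#[derive(Debug)]
struct OpError;

impl std::fmt::Display for OpError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "op failed")
    }
}

impl std::error::Error for OpError {}

// Mimics a candle op: it returns its own error type, not anyhow's.
fn op() -> std::result::Result<i32, OpError> {
    Ok(42)
}

// Returning `op()` directly would not compile here since the error types
// differ; `?` converts OpError into anyhow::Error, mirroring the diff's
// rewrites in Embedding::forward and BertIntermediate::forward.
fn forward() -> Result<i32> {
    let v = op()?;
    Ok(v)
}

fn main() -> Result<()> {
    println!("{}", forward()?);
    Ok(())
}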