diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index 88e29718..cb80f6eb 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -126,7 +126,7 @@ fn main() -> Result<()> {
         println!("Loaded and encoded {:?}", start.elapsed());
         for idx in 0..args.n {
             let start = std::time::Instant::now();
-            let ys = model.forward(&token_ids, &token_type_ids)?;
+            let ys = model.forward(&token_ids, &token_type_ids, None)?;
             if idx == 0 {
                 println!("{ys}");
             }
@@ -163,11 +163,19 @@ fn main() -> Result<()> {
                 Ok(Tensor::new(tokens.as_slice(), device)?)
             })
             .collect::<Result<Vec<_>>>()?;
+        let attention_mask = tokens
+            .iter()
+            .map(|tokens| {
+                let tokens = tokens.get_attention_mask().to_vec();
+                Ok(Tensor::new(tokens.as_slice(), device)?)
+            })
+            .collect::<Result<Vec<_>>>()?;
 
         let token_ids = Tensor::stack(&token_ids, 0)?;
+        let attention_mask = Tensor::stack(&attention_mask, 0)?;
         let token_type_ids = token_ids.zeros_like()?;
         println!("running inference on batch {:?}", token_ids.shape());
-        let embeddings = model.forward(&token_ids, &token_type_ids)?;
+        let embeddings = model.forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
         println!("generated embeddings {:?}", embeddings.shape());
         // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
         let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs
index 810f2803..42486a2d 100644
--- a/candle-transformers/src/models/bert.rs
+++ b/candle-transformers/src/models/bert.rs
@@ -230,10 +230,8 @@ impl BertSelfAttention {
         let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
         xs.contiguous()
     }
-}
 
-impl Module for BertSelfAttention {
-    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let query_layer = self.query.forward(hidden_states)?;
         let key_layer = self.key.forward(hidden_states)?;
@@ -245,6 +243,7 @@ impl Module for BertSelfAttention {
 
         let attention_scores = query_layer.matmul(&key_layer.t()?)?;
         let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
+        let attention_scores = attention_scores.broadcast_add(attention_mask)?;
         let attention_probs = {
             let _enter_sm = self.span_softmax.enter();
             candle_nn::ops::softmax(&attention_scores, candle::D::Minus1)?
@@ -307,12 +306,10 @@ impl BertAttention {
             span: tracing::span!(tracing::Level::TRACE, "attn"),
         })
     }
-}
 
-impl Module for BertAttention {
-    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
-        let self_outputs = self.self_attention.forward(hidden_states)?;
+        let self_outputs = self.self_attention.forward(hidden_states, attention_mask)?;
         let attention_output = self.self_output.forward(&self_outputs, hidden_states)?;
         Ok(attention_output)
     }
@@ -398,12 +395,10 @@ impl BertLayer {
             span: tracing::span!(tracing::Level::TRACE, "layer"),
         })
     }
-}
 
-impl Module for BertLayer {
-    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
-        let attention_output = self.attention.forward(hidden_states)?;
+        let attention_output = self.attention.forward(hidden_states, attention_mask)?;
         // TODO: Support cross-attention?
         // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
         // TODO: Support something similar to `apply_chunking_to_forward`?
@@ -429,15 +424,13 @@ impl BertEncoder {
         let span = tracing::span!(tracing::Level::TRACE, "encoder");
         Ok(BertEncoder { layers, span })
     }
-}
 
-impl Module for BertEncoder {
-    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
+    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
         let _enter = self.span.enter();
         let mut hidden_states = hidden_states.clone();
         // Use a loop rather than a fold as it's easier to modify when adding debug/...
         for layer in self.layers.iter() {
-            hidden_states = layer.forward(&hidden_states)?
+            hidden_states = layer.forward(&hidden_states, attention_mask)?
         }
         Ok(hidden_states)
     }
@@ -481,10 +474,32 @@ impl BertModel {
         })
     }
 
-    pub fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
+    pub fn forward(
+        &self,
+        input_ids: &Tensor,
+        token_type_ids: &Tensor,
+        attention_mask: Option<&Tensor>,
+    ) -> Result<Tensor> {
         let _enter = self.span.enter();
         let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?;
-        let sequence_output = self.encoder.forward(&embedding_output)?;
+        let attention_mask = match attention_mask {
+            Some(attention_mask) => attention_mask.clone(),
+            None => input_ids.ones_like()?,
+        };
+        // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L995
+        let attention_mask = get_extended_attention_mask(&attention_mask, DType::F32)?;
+        let sequence_output = self.encoder.forward(&embedding_output, &attention_mask)?;
         Ok(sequence_output)
     }
 }
+
+fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<Tensor> {
+    let attention_mask = match attention_mask.rank() {
+        3 => attention_mask.unsqueeze(1)?,
+        2 => attention_mask.unsqueeze(1)?.unsqueeze(1)?,
+        _ => candle::bail!("Wrong shape for input_ids or attention_mask"),
+    };
+    let attention_mask = attention_mask.to_dtype(dtype)?;
+    // torch.finfo(dtype).min
+    (attention_mask.ones_like()? - attention_mask)?.broadcast_mul(&Tensor::try_from(f32::MIN)?)
+}
diff --git a/candle-wasm-examples/bert/src/bin/m.rs b/candle-wasm-examples/bert/src/bin/m.rs
index 92617f15..9e5cf913 100644
--- a/candle-wasm-examples/bert/src/bin/m.rs
+++ b/candle-wasm-examples/bert/src/bin/m.rs
@@ -55,11 +55,21 @@ impl Model {
                 Tensor::new(tokens.as_slice(), device)
             })
             .collect::<Result<Vec<_>, _>>()?;
+        let attention_mask: Vec<Tensor> = tokens
+            .iter()
+            .map(|tokens| {
+                let tokens = tokens.get_attention_mask().to_vec();
+                Tensor::new(tokens.as_slice(), device)
+            })
+            .collect::<Result<Vec<_>, _>>()?;
 
         let token_ids = Tensor::stack(&token_ids, 0)?;
+        let attention_mask = Tensor::stack(&attention_mask, 0)?;
         let token_type_ids = token_ids.zeros_like()?;
         console_log!("running inference on batch {:?}", token_ids.shape());
-        let embeddings = self.bert.forward(&token_ids, &token_type_ids)?;
+        let embeddings = self
+            .bert
+            .forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
         console_log!("generated embeddings {:?}", embeddings.shape());
         // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
         let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
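
Not part of the patch, but for reviewers: a minimal, self-contained sketch of what the new `get_extended_attention_mask` helper computes and how the resulting bias is consumed inside self-attention. It assumes the `candle`/`candle_nn` crate names used inside this workspace (published as `candle-core` and `candle-nn`); the `extended_attention_mask` function name and the toy shapes below are illustrative only, not the patched API.

```rust
use candle::{DType, Device, Result, Tensor};

// Mirrors the patch: a (batch, seq_len) 0/1 padding mask becomes an additive
// bias of shape (batch, 1, 1, seq_len) that broadcasts over the
// (batch, num_heads, seq_len, seq_len) attention scores.
fn extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<Tensor> {
    let mask = attention_mask.unsqueeze(1)?.unsqueeze(1)?.to_dtype(dtype)?;
    // 1 (real token) -> 0.0, 0 (padding) -> f32::MIN, so softmax drives it to ~0.
    (mask.ones_like()? - &mask)?.broadcast_mul(&Tensor::try_from(f32::MIN)?)
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Two sequences of length 4; the second one ends with a padding token.
    let attention_mask = Tensor::new(&[[1u32, 1, 1, 1], [1, 1, 1, 0]], &device)?;
    let bias = extended_attention_mask(&attention_mask, DType::F32)?;

    // Stand-in attention scores for a single head: all zeros.
    let scores = Tensor::zeros((2, 1, 4, 4), DType::F32, &device)?;
    let masked = scores.broadcast_add(&bias)?;
    let probs = candle_nn::ops::softmax(&masked, candle::D::Minus1)?;
    // Rows still sum to 1, and the padded key position receives ~0 probability.
    println!("{probs}");
    Ok(())
}
```

When `attention_mask` is `None`, `BertModel::forward` falls back to `input_ids.ones_like()`, so the bias is all zeros and the previous behaviour is preserved.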