diff --git a/candle-transformers/src/models/xlm_roberta.rs b/candle-transformers/src/models/xlm_roberta.rs
index 96e763e1..6fb1268a 100644
--- a/candle-transformers/src/models/xlm_roberta.rs
+++ b/candle-transformers/src/models/xlm_roberta.rs
@@ -482,8 +482,10 @@ impl XLMRobertaClassificationHead {
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let cls_states = hidden_states.get_on_dim(1, 0)?.contiguous()?;
         let hidden_states = self.dense.forward(&cls_states)?;
-        let hidden_states = candle_nn::Activation::GeluPytorchTanh.forward(&hidden_states)?;
-        let hidden_states = self.out_proj.forward(&hidden_states)?;
+        // The activation used in the classification head is tanh, as per the original
+        // implementation.
+        // https://github.com/huggingface/transformers/blob/6e3063422c4b1c014aa60c32b9254fd2902f0f28/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py#L1454
+        let hidden_states = self.out_proj.forward(&hidden_states.tanh()?)?;
         Ok(hidden_states)
     }
 }
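
For illustration only (not part of the patch): the minimal sketch below contrasts the two activations on a small tensor, assuming the current candle-core / candle-nn crate layout for the imports. GeluPytorchTanh is a tanh-based GELU approximation, whereas the HF reference applies a plain tanh between `dense` and `out_proj`, so the two produce different classification logits.

use candle_core::{Device, Module, Result, Tensor};
use candle_nn::Activation;

fn main() -> Result<()> {
    // Small 1-D tensor to compare the two activations element-wise.
    let x = Tensor::new(&[-2.0f32, -0.5, 0.0, 0.5, 2.0], &Device::Cpu)?;
    // Previous behaviour of the classification head: GELU (tanh approximation).
    let gelu = Activation::GeluPytorchTanh.forward(&x)?;
    // Behaviour after this change, matching the reference implementation: plain tanh.
    let tanh = x.tanh()?;
    println!("gelu_pytorch_tanh: {gelu}");
    println!("tanh:              {tanh}");
    Ok(())
}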