From 61ddb9535ee1d5c0ef2b5bd298f1959d328c02db Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Mon, 26 May 2025 08:54:31 +0200
Subject: [PATCH] Use a tanh activation in the xlm-roberta classification head. (#2968)

---
 candle-transformers/src/models/xlm_roberta.rs | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/candle-transformers/src/models/xlm_roberta.rs b/candle-transformers/src/models/xlm_roberta.rs
index 96e763e1..6fb1268a 100644
--- a/candle-transformers/src/models/xlm_roberta.rs
+++ b/candle-transformers/src/models/xlm_roberta.rs
@@ -482,8 +482,10 @@ impl XLMRobertaClassificationHead {
     fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
         let cls_states = hidden_states.get_on_dim(1, 0)?.contiguous()?;
         let hidden_states = self.dense.forward(&cls_states)?;
-        let hidden_states = candle_nn::Activation::GeluPytorchTanh.forward(&hidden_states)?;
-        let hidden_states = self.out_proj.forward(&hidden_states)?;
+        // The activation used in the classification head is tanh, as per the original
+        // implementation.
+        // https://github.com/huggingface/transformers/blob/6e3063422c4b1c014aa60c32b9254fd2902f0f28/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py#L1454
+        let hidden_states = self.out_proj.forward(&hidden_states.tanh()?)?;
         Ok(hidden_states)
     }
 }
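
For reference, below is a minimal, self-contained sketch of a tanh-based classification head of the kind the patched forward pass implements, assuming candle_core and candle_nn as dependencies. The TanhClassificationHead struct, its constructor, and the hidden_size/num_labels parameters are illustrative only and are not part of the patch or the crate's actual XLMRobertaClassificationHead; only the forward pass mirrors the patched code.

use candle_core::{Result, Tensor};
use candle_nn::{Linear, Module, VarBuilder};

// Illustrative sketch, not the crate's actual `XLMRobertaClassificationHead`.
struct TanhClassificationHead {
    dense: Linear,
    out_proj: Linear,
}

impl TanhClassificationHead {
    // `hidden_size` and `num_labels` are hypothetical parameters for this sketch.
    fn new(hidden_size: usize, num_labels: usize, vb: VarBuilder) -> Result<Self> {
        Ok(Self {
            dense: candle_nn::linear(hidden_size, hidden_size, vb.pp("dense"))?,
            out_proj: candle_nn::linear(hidden_size, num_labels, vb.pp("out_proj"))?,
        })
    }

    // (batch, seq_len, hidden_size) -> (batch, num_labels)
    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        // Pool the hidden state of the first (<s>/[CLS]) token.
        let cls_states = hidden_states.get_on_dim(1, 0)?.contiguous()?;
        // Dense projection followed by tanh, as in the reference implementation,
        // then project to the label logits.
        let hidden_states = self.dense.forward(&cls_states)?;
        self.out_proj.forward(&hidden_states.tanh()?)
    }
}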