From 9c61b0fc9b9062b347c176b5f0f86b97b6804a1b Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sun, 3 Sep 2023 21:33:50 +0200
Subject: [PATCH] Proper log buckets for t5. (#727)

* Proper log buckets for t5.

* Properly pass the position bias.
---
 candle-examples/examples/musicgen/t5_model.rs | 111 ++++++++++++------
 1 file changed, 74 insertions(+), 37 deletions(-)

diff --git a/candle-examples/examples/musicgen/t5_model.rs b/candle-examples/examples/musicgen/t5_model.rs
index 607b5c93..22f0a4f5 100644
--- a/candle-examples/examples/musicgen/t5_model.rs
+++ b/candle-examples/examples/musicgen/t5_model.rs
@@ -202,7 +202,11 @@ impl T5Attention {
         })
     }
 
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+    fn forward(
+        &self,
+        xs: &Tensor,
+        position_bias: Option<&Tensor>,
+    ) -> Result<(Tensor, Option<Tensor>)> {
         // TODO: Apply the mask(s)?
         // TODO: kv caching.
         let (b_sz, seq_len) = (xs.dim(0)?, xs.dim(1)?);
@@ -223,34 +227,57 @@ impl T5Attention {
             .contiguous()?;
 
         let scores = q.matmul(&k.t()?)?;
-        let scores = match &self.relative_attention_bias {
-            None => scores,
-            Some(relative_attention_bias) => {
-                let query_length = seq_len;
-                let key_length = seq_len;
-                // This only handles the bidirectional case.
-                let num_buckets = self.relative_attention_num_buckets / 2;
-                let relative_position = (0..query_length as u32)
-                    .map(|i| {
-                        (0..key_length as u32)
-                            .map(|j| {
-                                if i < j {
-                                    j - i + num_buckets as u32
-                                } else {
-                                    i - j
-                                }
-                            })
-                            .collect::<Vec<u32>>()
-                    })
-                    .collect::<Vec<Vec<_>>>();
-                let relative_buckets = Tensor::new(relative_position, q.device())?;
-                let position_bias = relative_attention_bias
-                    .forward(&relative_buckets)?
-                    .permute((2, 0, 1))?
-                    .unsqueeze(0)?;
-                (scores + position_bias)?
-                // TODO: position_bias_masked?
-            }
+        let (scores, position_bias) = match position_bias {
+            Some(position_bias) => ((scores + position_bias)?, Some(position_bias.clone())),
+            None => match &self.relative_attention_bias {
+                None => (scores, None),
+                Some(relative_attention_bias) => {
+                    let query_length = seq_len;
+                    let key_length = seq_len;
+                    // This only handles the bidirectional case.
+                    let num_buckets = self.relative_attention_num_buckets as u32 / 2;
+                    let max_exact = num_buckets / 2;
+                    let relative_position = (0..query_length as u32)
+                        .map(|i| {
+                            (0..key_length as u32)
+                                .map(|j| {
+                                    if i < j {
+                                        if j - i < max_exact {
+                                            j - i + num_buckets
+                                        } else {
+                                            let b = f32::log(
+                                                (j - i) as f32 / max_exact as f32,
+                                                self.relative_attention_max_distance as f32
+                                                    / max_exact as f32,
+                                            ) * (num_buckets - max_exact) as f32;
+                                            u32::min(
+                                                max_exact + num_buckets + b as u32,
+                                                self.relative_attention_num_buckets as u32 - 1,
+                                            )
+                                        }
+                                    } else if i - j < max_exact {
+                                        i - j
+                                    } else {
+                                        let b = f32::log(
+                                            (i - j) as f32 / max_exact as f32,
+                                            self.relative_attention_max_distance as f32
+                                                / max_exact as f32,
+                                        ) * (num_buckets - max_exact) as f32;
+                                        max_exact + b as u32
+                                    }
+                                })
+                                .collect::<Vec<u32>>()
+                        })
+                        .collect::<Vec<Vec<_>>>();
+                    let relative_buckets = Tensor::new(relative_position, q.device())?;
+                    let position_bias = relative_attention_bias
+                        .forward(&relative_buckets)?
+                        .permute((2, 0, 1))?
+                        .unsqueeze(0)?;
+                    ((scores + &position_bias)?, Some(position_bias))
+                    // TODO: position_bias_masked?
+                }
+            },
         };
 
         let attn_weights = candle_nn::ops::softmax(&scores, D::Minus1)?;
@@ -259,7 +286,7 @@ impl T5Attention {
             .transpose(1, 2)?
             .reshape((b_sz, seq_len, self.inner_dim))?;
         let attn_output = self.o.forward(&attn_output)?;
-        Ok(attn_output)
+        Ok((attn_output, position_bias))
     }
 }
 
@@ -280,11 +307,15 @@ impl T5LayerSelfAttention {
         })
     }
 
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+    fn forward(
+        &self,
+        xs: &Tensor,
+        position_bias: Option<&Tensor>,
+    ) -> Result<(Tensor, Option<Tensor>)> {
         let normed_xs = self.layer_norm.forward(xs)?;
-        let ys = self.self_attention.forward(&normed_xs)?;
+        let (ys, position_bias) = self.self_attention.forward(&normed_xs, position_bias)?;
         let ys = (xs + ys)?;
-        Ok(ys)
+        Ok((ys, position_bias))
     }
 }
 
@@ -326,8 +357,12 @@ impl T5Block {
         })
     }
 
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let mut xs = self.self_attn.forward(xs)?;
+    fn forward(
+        &self,
+        xs: &Tensor,
+        position_bias: Option<&Tensor>,
+    ) -> Result<(Tensor, Option<Tensor>)> {
+        let (mut xs, position_bias) = self.self_attn.forward(xs, position_bias)?;
         // TODO: clamp for f16?
         if let Some(cross_attn) = &self.cross_attn {
             xs = cross_attn.forward(&xs)?;
@@ -335,7 +370,7 @@ impl T5Block {
         }
         let xs = self.ff.forward(&xs)?;
         // TODO: clamp for f16?
-        Ok(xs)
+        Ok((xs, position_bias))
     }
 }
 
@@ -368,8 +403,10 @@ impl T5Stack {
         let (_b_sz, _seq_len) = (input_embeds.dim(0)?, input_embeds.dim(1)?);
 
         let mut hidden_states = input_embeds;
+        let mut position_bias = None;
        for block in self.block.iter() {
-            hidden_states = block.forward(&hidden_states)?
+            (hidden_states, position_bias) =
+                block.forward(&hidden_states, position_bias.as_ref())?
         }
         let hidden_states = self.final_layer_norm.forward(&hidden_states)?;
         Ok(hidden_states)
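
Not part of the patch itself: for readers following the new bucketing hunk above, here is a minimal standalone sketch of the bidirectional relative-position bucketing it inlines, written as a free function so it can be compiled and inspected on its own. The function name, the `main` driver, and the defaults passed to it (32 buckets, max distance 128, T5's usual configuration) are illustrative assumptions, not identifiers from the patch; the branch structure and the `f32::log` scaling mirror the diff, including the fact that only the `i < j` branch is clamped.

    // Illustrative sketch only: mirrors the bucketing logic added in the patch.
    // `num_buckets` / `max_distance` play the role of
    // `relative_attention_num_buckets` / `relative_attention_max_distance`.
    fn relative_position_bucket(i: u32, j: u32, num_buckets: u32, max_distance: u32) -> u32 {
        // Bidirectional case: half the buckets are reserved for keys after the query.
        let num_buckets = num_buckets / 2;
        // Offsets below `max_exact` each get their own bucket ...
        let max_exact = num_buckets / 2;
        if i < j {
            if j - i < max_exact {
                j - i + num_buckets
            } else {
                // ... larger offsets share logarithmically growing buckets.
                let b = f32::log(
                    (j - i) as f32 / max_exact as f32,
                    max_distance as f32 / max_exact as f32,
                ) * (num_buckets - max_exact) as f32;
                // Clamp to the last bucket index; 2 * num_buckets - 1 is the
                // original, un-halved bucket count minus one.
                u32::min(max_exact + num_buckets + b as u32, 2 * num_buckets - 1)
            }
        } else if i - j < max_exact {
            i - j
        } else {
            let b = f32::log(
                (i - j) as f32 / max_exact as f32,
                max_distance as f32 / max_exact as f32,
            ) * (num_buckets - max_exact) as f32;
            max_exact + b as u32
        }
    }

    fn main() {
        // With the usual T5 defaults, small offsets map to exact buckets and
        // large ones fall into coarse log-spaced buckets.
        for d in [0u32, 1, 7, 8, 16, 64, 200] {
            println!("offset {:3} -> bucket {}", d, relative_position_bucket(0, d, 32, 128));
        }
    }

Running the sketch shows why the second commit matters: the bias table lookup only depends on these buckets, so the first attention layer can compute `position_bias` once and every later layer can reuse the tensor that is now threaded through the `forward` signatures.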