Add a simpler way to specify the dim index for some ops.

commit 2c3d871b2e
parent b7388bbf71
Author: laurent
Date:   2023-07-05 20:22:43 +01:00

7 changed files with 93 additions and 34 deletions
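
The change introduces `candle::D`, whose `Minus1` and `Minus2` variants name the
last and second-to-last dimensions, so ops that take a dimension index no longer
need to compute `rank() - 1` by hand. A minimal sketch of the equivalence,
assuming the `Tensor::softmax` method present at this commit; the tensor values
are illustrative:

use candle::{Device, Tensor, D};

fn main() -> candle::Result<()> {
    let t = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
    // Old style: index the last dimension via the tensor's rank.
    let a = t.softmax(t.rank() - 1)?;
    // New style: D::Minus1 resolves to the same index.
    let b = t.softmax(D::Minus1)?;
    assert_eq!(a.to_vec2::<f32>()?, b.to_vec2::<f32>()?);
    Ok(())
}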

@@ -386,12 +386,12 @@ impl BertSelfAttention {
         let attention_scores = query_layer.matmul(&key_layer.t()?)?;
         let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
-        let attention_probs = attention_scores.softmax(attention_scores.rank() - 1)?;
+        let attention_probs = attention_scores.softmax(candle::D::Minus1)?;
         let attention_probs = self.dropout.forward(&attention_probs)?;
         let context_layer = attention_probs.matmul(&value_layer)?;
         let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
-        let context_layer = context_layer.flatten(Some(context_layer.rank() - 2), None)?;
+        let context_layer = context_layer.flatten_from(candle::D::Minus2)?;
         Ok(context_layer)
     }
 }
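
The second change in this hunk swaps the Option-based `flatten` call for
`flatten_from`, which collapses every dimension from the given index through
the last into one. A hedged sketch, assuming a hypothetical
(batch, seq, heads, head_dim) context layout:

use candle::{Tensor, D};

// (batch, seq, heads, head_dim) -> (batch, seq, heads * head_dim)
fn merge_last_two(context: &Tensor) -> candle::Result<Tensor> {
    context.flatten_from(D::Minus2)
}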

@@ -283,19 +283,18 @@ impl CausalSelfAttention {
         dims.push(v / 2);
         dims.push(2);
         let x = x.reshape(dims)?;
-        let rank = x.rank();
-        let re_x = x.narrow(rank - 1, 0, 1)?;
-        let im_x = x.narrow(rank - 1, 1, 1)?;
+        let re_x = x.narrow(candle::D::Minus1, 0, 1)?;
+        let im_x = x.narrow(candle::D::Minus1, 1, 1)?;
         let re_f = freqs_cis
-            .narrow(rank - 1, 0, 1)?
+            .narrow(candle::D::Minus1, 0, 1)?
             .broadcast_as(re_x.shape())?;
         let im_f = freqs_cis
-            .narrow(rank - 1, 1, 1)?
+            .narrow(candle::D::Minus1, 1, 1)?
             .broadcast_as(im_x.shape())?;
         let re = ((&re_x * &re_f)? - (&im_x * &im_f)?)?;
         let im = ((&re_x * &im_f)? + (&im_x * &re_f)?)?;
-        let rope = Tensor::cat(&[&re, &im], rank - 1)?;
-        let rope = rope.flatten(Some(rope.rank() - 2), None)?;
+        let rope = Tensor::cat(&[&re, &im], re.rank() - 1)?;
+        let rope = rope.flatten_from(candle::D::Minus2)?;
         Ok(rope)
     }
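
The `narrow(dim, start, len)` calls above slice width-1 bands along the last
dimension to separate the real and imaginary halves of the interleaved pairs.
A minimal sketch, assuming a tensor whose last dimension was reshaped to size
2 as in the hunk:

use candle::{Tensor, D};

// x: (..., 2), with (re, im) pairs along the last dimension.
fn split_complex(x: &Tensor) -> candle::Result<(Tensor, Tensor)> {
    let re = x.narrow(D::Minus1, 0, 1)?; // keep index 0, width 1
    let im = x.narrow(D::Minus1, 1, 1)?; // keep index 1, width 1
    Ok((re, im))
}
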
@@ -339,7 +338,7 @@ impl CausalSelfAttention {
         let att = (q.matmul(&k.t()?)? / (*k_shape.dims().last().unwrap() as f64).sqrt())?;
         let mask = self.cache.mask(t)?.broadcast_as(att.shape())?;
         let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
-        let att = att.softmax(att.rank() - 1)?;
+        let att = att.softmax(candle::D::Minus1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = att.matmul(&v.contiguous()?)?;
         let y = y.transpose(0, 1)?.reshape(&[t, c])?;
@@ -537,7 +536,7 @@ async fn main() -> Result<()> {
     let next_token = if let Some(temperature) = args.temperature {
         println!("Sampling with temperature {temperature:?}");
-        let prs = (&logits / temperature)?.softmax(logits.rank() - 1)?;
+        let prs = (&logits / temperature)?.softmax(candle::D::Minus1)?;
         let logits_v: Vec<f32> = prs.to_vec1()?;
         let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
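
For the sampling path above, a hedged sketch of the whole temperature step,
assuming a 1-D `logits` tensor and the `rand` crate this example already
depends on; the fixed seed is illustrative:

use candle::{Tensor, D};
use rand::distributions::{Distribution, WeightedIndex};
use rand::SeedableRng;

fn sample_with_temperature(logits: &Tensor, temperature: f64) -> anyhow::Result<u32> {
    // Scale the logits, then normalize over the last dimension.
    let prs: Vec<f32> = (logits / temperature)?.softmax(D::Minus1)?.to_vec1()?;
    // Draw a token index proportionally to its probability.
    let distr = WeightedIndex::new(&prs)?;
    let mut rng = rand::rngs::StdRng::seed_from_u64(299792458);
    Ok(distr.sample(&mut rng) as u32)
}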

@@ -109,7 +109,7 @@ impl Decode {
         };
         tokens.push(next_token);
         let prob = logits
-            .softmax(logits.rank() - 1)?
+            .softmax(candle::D::Minus1)?
             .get(next_token as usize)?
             .to_scalar::<f32>()? as f64;
         if next_token == EOT_TOKEN || tokens.len() > model.config.n_text_ctx {
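
The probability readout above chains three calls; a compact restatement,
assuming a 1-D `logits` tensor over the vocabulary:

use candle::{Tensor, D};

// Probability the model assigns to `token` after softmax over the vocab.
fn token_prob(logits: &Tensor, token: u32) -> candle::Result<f64> {
    let p = logits
        .softmax(D::Minus1)?
        .get(token as usize)? // index along the first (and only) dim
        .to_scalar::<f32>()?;
    Ok(p as f64)
}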

@@ -342,8 +342,8 @@ impl MultiHeadAttention {
             let mask = mask.narrow(0, 0, n_ctx)?.narrow(1, 0, n_ctx)?;
             qk = qk.broadcast_add(&mask)?
         }
-        let w = qk.softmax(qk.rank() - 1)?;
-        let wv = w.matmul(&v)?.transpose(1, 2)?.flatten(Some(2), None)?;
+        let w = qk.softmax(candle::D::Minus1)?;
+        let wv = w.matmul(&v)?.transpose(1, 2)?.flatten_from(2)?;
         Ok(wv)
     }
 }
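
Note that `flatten_from` also accepts a plain `usize`, as in the hunk above:
both absolute indices and the relative `candle::D` variants are usable
wherever a dimension index is expected. A sketch of the merge step, assuming
a hypothetical (batch, heads, seq, head_dim) attention output:

use candle::Tensor;

// (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
fn merge_heads(wv: &Tensor) -> candle::Result<Tensor> {
    let wv = wv.transpose(1, 2)?; // swap heads and seq
    wv.flatten_from(2) // collapse heads and head_dim into one
}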