diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 3dfc5c8f..0871175f 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -227,14 +227,22 @@ impl Map2 for MatMul {
         let rhs_rs = rhs_stride[rank - 2];
 
         let a_skip: usize = match lhs_stride[..rank - 2] {
+            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
             [stride] => stride,
             [] => m * k,
-            _ => Err(Error::UnexpectedStriding)?,
+            _ => Err(Error::UnexpectedStriding {
+                lhs_stride: lhs_stride.to_vec(),
+                rhs_stride: rhs_stride.to_vec(),
+            })?,
         };
         let b_skip: usize = match rhs_stride[..rank - 2] {
+            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
             [stride] => stride,
             [] => n * k,
-            _ => Err(Error::UnexpectedStriding)?,
+            _ => Err(Error::UnexpectedStriding {
+                lhs_stride: lhs_stride.to_vec(),
+                rhs_stride: rhs_stride.to_vec(),
+            })?,
         };
         let c_skip: usize = m * n;
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 0b58787e..0c87004b 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -599,6 +599,7 @@ fn gemm_config(
     };
     let stride_b: usize = match lhs_stride[..lhs_stride.len() - 2] {
+        [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
         [stride] => stride,
         [] => m * k,
         _ => Err(CudaError::MatMulNonContiguous {
@@ -608,6 +609,7 @@
         })?,
     };
     let stride_a: usize = match rhs_stride[..rhs_stride.len() - 2] {
+        [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
         [stride] => stride,
         [] => n * k,
         _ => Err(CudaError::MatMulNonContiguous {
diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs
index 2e82ab38..7a2d2984 100644
--- a/candle-core/src/error.rs
+++ b/candle-core/src/error.rs
@@ -93,8 +93,11 @@ pub enum Error {
     },
 
     // TODO this is temporary when we support arbitrary matmul
-    #[error("temporary error where matmul doesn't support arbitrary striding")]
-    UnexpectedStriding,
+    #[error("temporary error where matmul doesn't support arbitrary striding {lhs_stride:?} x {rhs_stride:?}")]
+    UnexpectedStriding {
+        lhs_stride: Vec<usize>,
+        rhs_stride: Vec<usize>,
+    },
 
     #[error(transparent)]
     Cuda(#[from] crate::CudaError),
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index 2bd1fb1d..e5801314 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -75,6 +75,9 @@ enum HiddenAct {
 impl HiddenAct {
     fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
         match self {
+            // TODO: The all-MiniLM-L6-v2 model uses "gelu" whereas this is "gelu_new", this explains some
+            // small numerical difference.
+            // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
             Self::Gelu => xs.gelu(),
             Self::Relu => xs.relu(),
         }
@@ -196,7 +199,9 @@ impl Linear {
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = x.matmul(&self.weight.t()?)?;
+        let (bsize, _, _) = x.shape().r3()?;
+        let w = self.weight.broadcast_left(bsize)?.t()?;
+        let x = x.matmul(&w)?;
         let x = x.broadcast_add(&self.bias)?;
         Ok(x)
     }
@@ -236,12 +241,11 @@ impl LayerNorm {
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let (seq_len, hidden_size) = x.shape().r2()?;
-        let mean_x = (x.sum(&[1])? / hidden_size as f64)?;
+        let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
+        let mean_x = (x.sum(&[2])? / hidden_size as f64)?;
         let x = x.broadcast_sub(&mean_x)?;
-        let norm_x = ((&x * &x)?.sum(&[1])? / hidden_size as f64)?;
-        let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
-        let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
+        let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
+        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
         let x = x_normed
             .broadcast_mul(&self.weight)?
             .broadcast_add(&self.bias)?;
@@ -301,7 +305,7 @@ impl BertEmbeddings {
     }
 
     fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
-        let seq_len = input_ids.shape().r1()?;
+        let (_bsize, seq_len) = input_ids.shape().r2()?;
         let input_embeddings = self.word_embeddings.forward(input_ids)?;
         let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
         let mut embeddings = (&input_embeddings + token_type_embeddings)?;
@@ -309,7 +313,7 @@
             // TODO: Proper absolute positions?
             let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
             let position_ids = Tensor::new(&position_ids[..], &input_ids.device())?;
-            embeddings = (&embeddings + position_embeddings.forward(&position_ids)?)?
+            embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)?
         }
         let embeddings = self.layer_norm.forward(&embeddings)?;
         let embeddings = self.dropout.forward(&embeddings)?;
@@ -351,7 +355,7 @@ impl BertSelfAttention {
         new_x_shape.push(self.num_attention_heads);
         new_x_shape.push(self.attention_head_size);
         // Be cautious about the transposition if adding a batch dim!
-        let xs = xs.reshape(new_x_shape.as_slice())?.transpose(0, 1)?;
+        let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
         Ok(xs.contiguous()?)
     }
@@ -370,7 +374,7 @@
         let attention_probs = self.dropout.forward(&attention_probs)?;
 
         let context_layer = attention_probs.matmul(&value_layer)?;
-        let context_layer = context_layer.transpose(0, 1)?.contiguous()?;
+        let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
         let context_layer = context_layer.flatten(Some(context_layer.rank() - 2), None)?;
         Ok(context_layer)
     }
@@ -616,7 +620,7 @@ fn main() -> Result<()> {
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
-    let token_ids = Tensor::new(&tokens[..], &device)?;
+    let token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
     println!("{token_ids}");
     let token_type_ids = token_ids.zeros_like()?;
     let ys = model.forward(&token_ids, &token_type_ids)?;
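Reviewer note (not part of the patch): the new `[s1, stride]` guard arms accept rank-4 inputs whose two leading batch dimensions are themselves laid out contiguously, which is what the per-head attention tensors look like in `BertSelfAttention` after `transpose(1, 2)?.contiguous()?`. Below is a minimal standalone sketch of the skip computation these arms implement; the helper name `batch_skip` and the dims/strides in `main` are illustrative only.

```rust
/// Hypothetical helper mirroring the match arms in the patch: returns the
/// per-matrix skip for the strided-batched gemm, or None for striding that
/// is still rejected (UnexpectedStriding / MatMulNonContiguous).
/// Assumes dims.len() >= 2, as the backends do.
fn batch_skip(dims: &[usize], stride: &[usize], m: usize, k: usize) -> Option<usize> {
    let rank = dims.len();
    match stride[..rank - 2] {
        // Rank 4, e.g. [batch, heads, seq, head_dim]: the two leading dims
        // collapse into one batch of size dims[0] * dims[1] whenever
        // stride[0] == stride[1] * dims[1], i.e. they are contiguous with
        // respect to each other; the skip is then the inner batch stride.
        [s1, s2] if s1 == s2 * dims[1] => Some(s2),
        // Rank 3: a single batch dim, the skip is its stride.
        [s] => Some(s),
        // Rank 2: a single matrix, skip over the whole m x k block.
        [] => Some(m * k),
        _ => None,
    }
}

fn main() {
    // A contiguous [2, 12, 9, 64] attention tensor has strides
    // [12 * 9 * 64, 9 * 64, 64, 1]; the two batch dims collapse cleanly.
    let dims = [2, 12, 9, 64];
    let stride = [12 * 9 * 64, 9 * 64, 64, 1];
    assert_eq!(batch_skip(&dims, &stride, 9, 64), Some(9 * 64));
}
```

As far as I can tell, this is what lets the rank-4 `attention_probs.matmul(&value_layer)?` call go through a single strided-batched gemm instead of erroring, and why `Linear::forward` can stay rank-3 via `broadcast_left(bsize)`.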