Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 02:58:50 +00:00)
Add a batch dimension on the bert example.
@@ -227,14 +227,22 @@ impl Map2 for MatMul {
         let rhs_rs = rhs_stride[rank - 2];

         let a_skip: usize = match lhs_stride[..rank - 2] {
+            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
             [stride] => stride,
             [] => m * k,
-            _ => Err(Error::UnexpectedStriding)?,
+            _ => Err(Error::UnexpectedStriding {
+                lhs_stride: lhs_stride.to_vec(),
+                rhs_stride: rhs_stride.to_vec(),
+            })?,
         };
         let b_skip: usize = match rhs_stride[..rank - 2] {
+            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
             [stride] => stride,
             [] => n * k,
-            _ => Err(Error::UnexpectedStriding)?,
+            _ => Err(Error::UnexpectedStriding {
+                lhs_stride: lhs_stride.to_vec(),
+                rhs_stride: rhs_stride.to_vec(),
+            })?,
         };
         let c_skip: usize = m * n;

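The `a_skip`/`b_skip` values are the element offsets between consecutive matrices in a batched matmul. The new guard arm covers rank-4 inputs (e.g. per-head attention tensors of shape (batch, heads, seq, head_dim)): when the leading stride `s1` equals `stride * dims[1]`, the two leading dims are laid out contiguously and can be folded into one batch whose skip is `stride`. A standalone sketch of that selection logic (plain Rust, hypothetical `batch_skip` helper, not candle's API):

```rust
/// Pick the per-matrix skip for a batched matmul operand, mirroring the match
/// in the hunk above. `stride` covers all dims, `dims` is the full shape, and
/// `mat_elems` is the element count of one matrix (m*k or n*k).
fn batch_skip(stride: &[usize], dims: &[usize], mat_elems: usize) -> Option<usize> {
    let rank = stride.len();
    match stride[..rank - 2] {
        // Rank 4: the two leading dims are mutually contiguous, fold them into one batch.
        [s1, s2] if s1 == s2 * dims[1] => Some(s2),
        // Rank 3: the leading dim's stride is the skip between matrices.
        [s] => Some(s),
        // Rank 2: a single matrix, skip by a full matrix.
        [] => Some(mat_elems),
        // Anything else is not supported yet (UnexpectedStriding in the diff).
        _ => None,
    }
}

fn main() {
    // Contiguous (2, 3, 4, 5) tensor: strides are (60, 20, 5, 1).
    assert_eq!(batch_skip(&[60, 20, 5, 1], &[2, 3, 4, 5], 4 * 5), Some(20));
    // Contiguous (3, 4, 5) tensor: strides are (20, 5, 1).
    assert_eq!(batch_skip(&[20, 5, 1], &[3, 4, 5], 4 * 5), Some(20));
    // Plain (4, 5) matrix.
    assert_eq!(batch_skip(&[5, 1], &[4, 5], 4 * 5), Some(20));
    println!("all stride cases matched");
}
```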
@@ -599,6 +599,7 @@ fn gemm_config<T>(
     };

     let stride_b: usize = match lhs_stride[..lhs_stride.len() - 2] {
+        [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
         [stride] => stride,
         [] => m * k,
         _ => Err(CudaError::MatMulNonContiguous {
@@ -608,6 +609,7 @@ fn gemm_config<T>(
         })?,
     };
     let stride_a: usize = match rhs_stride[..rhs_stride.len() - 2] {
+        [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
         [stride] => stride,
         [] => n * k,
         _ => Err(CudaError::MatMulNonContiguous {
@@ -93,8 +93,11 @@ pub enum Error {
     },

     // TODO this is temporary when we support arbitrary matmul
-    #[error("temporary error where matmul doesn't support arbitrary striding")]
-    UnexpectedStriding,
+    #[error("temporary error where matmul doesn't support arbitrary striding {lhs_stride:?} x {rhs_stride:?}")]
+    UnexpectedStriding {
+        lhs_stride: Vec<usize>,
+        rhs_stride: Vec<usize>,
+    },

     #[error(transparent)]
     Cuda(#[from] crate::CudaError),
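`UnexpectedStriding` goes from a unit variant to a struct variant so the offending strides show up in the error message. This enum is derived with `thiserror`, whose `#[error(...)]` attribute interpolates the variant's named fields. A minimal standalone sketch of the same pattern (assuming the `thiserror` crate; the enum name here is made up):

```rust
use thiserror::Error;

#[derive(Debug, Error)]
enum MatMulError {
    // Named fields on a struct variant can be interpolated directly into the message.
    #[error("matmul doesn't support arbitrary striding {lhs_stride:?} x {rhs_stride:?}")]
    UnexpectedStriding {
        lhs_stride: Vec<usize>,
        rhs_stride: Vec<usize>,
    },
}

fn main() {
    let err = MatMulError::UnexpectedStriding {
        lhs_stride: vec![20, 5, 1],
        rhs_stride: vec![1, 5, 20],
    };
    // Prints: matmul doesn't support arbitrary striding [20, 5, 1] x [1, 5, 20]
    println!("{err}");
}
```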
@@ -75,6 +75,9 @@ enum HiddenAct {
 impl HiddenAct {
     fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
         match self {
+            // TODO: The all-MiniLM-L6-v2 model uses "gelu" whereas this is "gelu_new", this explains some
+            // small numerical difference.
+            // https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
             Self::Gelu => xs.gelu(),
             Self::Relu => xs.relu(),
         }
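Per the new comment, `xs.gelu()` here corresponds to transformers' "gelu_new" (the tanh approximation), while all-MiniLM-L6-v2 was trained with the exact erf-based "gelu"; the two curves differ only slightly, hence the small numerical gap. A scalar sketch of both formulas (plain Rust; std has no erf, so an Abramowitz-Stegun approximation is used):

```rust
use std::f64::consts::PI;

/// Exact GELU ("gelu" in transformers): 0.5 * x * (1 + erf(x / sqrt(2))).
fn gelu_exact(x: f64) -> f64 {
    0.5 * x * (1.0 + erf(x / 2f64.sqrt()))
}

/// Tanh-approximated GELU ("gelu_new" in transformers):
/// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
fn gelu_new(x: f64) -> f64 {
    0.5 * x * (1.0 + ((2.0 / PI).sqrt() * (x + 0.044715 * x.powi(3))).tanh())
}

/// Abramowitz & Stegun 7.1.26 polynomial approximation of erf (|error| < 1.5e-7).
fn erf(x: f64) -> f64 {
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();
    let t = 1.0 / (1.0 + 0.3275911 * x);
    let poly = ((((1.061405429 * t - 1.453152027) * t + 1.421413741) * t - 0.284496736) * t
        + 0.254829592)
        * t;
    sign * (1.0 - poly * (-x * x).exp())
}

fn main() {
    for x in [-2.0, -0.5, 0.0, 0.5, 2.0] {
        // The two activations stay close over typical input ranges.
        println!("x={x:+.1}  gelu={:+.6}  gelu_new={:+.6}", gelu_exact(x), gelu_new(x));
    }
}
```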
@@ -196,7 +199,9 @@ impl Linear {
     }

     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = x.matmul(&self.weight.t()?)?;
+        let (bsize, _, _) = x.shape().r3()?;
+        let w = self.weight.broadcast_left(bsize)?.t()?;
+        let x = x.matmul(&w)?;
         let x = x.broadcast_add(&self.bias)?;
         Ok(x)
     }
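With a batch dimension, `x` is now `(bsize, seq_len, in_features)` while the weight is still `(out_features, in_features)`, so the weight is expanded with `broadcast_left(bsize)` before the transpose and the batched matmul. A shape-only sketch of that forward pass (plain Rust; the sizes are arbitrary illustrations, not taken from the model):

```rust
// Track only the shapes through Linear::forward to see why broadcast_left is needed.
fn main() {
    let (bsize, seq_len, in_f, out_f) = (2usize, 7usize, 384usize, 384usize);

    let x = (bsize, seq_len, in_f); // input activations
    let weight = (out_f, in_f);     // Linear weight, no batch dim
    let bias = (out_f,);            // Linear bias

    // broadcast_left(bsize): (out_f, in_f) -> (bsize, out_f, in_f)
    let w = (bsize, weight.0, weight.1);
    // t(): swap the last two dims -> (bsize, in_f, out_f)
    let w_t = (w.0, w.2, w.1);
    // batched matmul: (bsize, seq_len, in_f) x (bsize, in_f, out_f) -> (bsize, seq_len, out_f)
    assert_eq!(x.2, w_t.1);
    let y = (x.0, x.1, w_t.2);
    // broadcast_add of the (out_f,) bias leaves the shape unchanged.
    let _ = bias;

    println!("output shape: {y:?}"); // (2, 7, 384)
}
```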
@@ -236,12 +241,11 @@ impl LayerNorm {
     }

     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let (seq_len, hidden_size) = x.shape().r2()?;
-        let mean_x = (x.sum(&[1])? / hidden_size as f64)?;
+        let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
+        let mean_x = (x.sum(&[2])? / hidden_size as f64)?;
         let x = x.broadcast_sub(&mean_x)?;
-        let norm_x = ((&x * &x)?.sum(&[1])? / hidden_size as f64)?;
-        let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
-        let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
+        let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?;
+        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
         let x = x_normed
             .broadcast_mul(&self.weight)?
             .broadcast_add(&self.bias)?;
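The reductions move from axis 1 to axis 2 because the normalized (hidden) dimension is now the last of three, and the explicit `broadcast_as` is replaced by `broadcast_div`. For comparison, a plain-Rust reference of layer norm over the last dimension of a flattened `(rows, hidden)` buffer (hypothetical helper, not candle code):

```rust
/// Layer norm over the last dimension: for each row, subtract the mean,
/// divide by sqrt(mean squared deviation + eps), then scale and shift.
fn layer_norm(x: &[f32], hidden: usize, weight: &[f32], bias: &[f32], eps: f32) -> Vec<f32> {
    assert!(x.len() % hidden == 0 && weight.len() == hidden && bias.len() == hidden);
    let mut out = Vec::with_capacity(x.len());
    for row in x.chunks(hidden) {
        let mean = row.iter().sum::<f32>() / hidden as f32;
        let var = row.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>() / hidden as f32;
        let denom = (var + eps).sqrt();
        out.extend(
            row.iter()
                .zip(weight.iter().zip(bias))
                .map(|(v, (w, b))| (v - mean) / denom * w + b),
        );
    }
    out
}

fn main() {
    // Two "rows" (e.g. bsize * seq_len flattened) with hidden size 4.
    let x = [1.0f32, 2.0, 3.0, 4.0, -1.0, 0.0, 1.0, 2.0];
    let weight = [1.0f32; 4];
    let bias = [0.0f32; 4];
    let y = layer_norm(&x, 4, &weight, &bias, 1e-12);
    println!("{y:?}"); // each row now has mean ~0 and unit variance
}
```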
@@ -301,7 +305,7 @@ impl BertEmbeddings {
     }

     fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
-        let seq_len = input_ids.shape().r1()?;
+        let (_bsize, seq_len) = input_ids.shape().r2()?;
         let input_embeddings = self.word_embeddings.forward(input_ids)?;
         let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
         let mut embeddings = (&input_embeddings + token_type_embeddings)?;
@@ -309,7 +313,7 @@ impl BertEmbeddings {
             // TODO: Proper absolute positions?
             let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
             let position_ids = Tensor::new(&position_ids[..], &input_ids.device())?;
-            embeddings = (&embeddings + position_embeddings.forward(&position_ids)?)?
+            embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)?
         }
         let embeddings = self.layer_norm.forward(&embeddings)?;
         let embeddings = self.dropout.forward(&embeddings)?;
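The switch to `broadcast_add` is needed because `embeddings` now carries a batch dimension while the positional lookup does not: a `(bsize, seq_len, hidden)` tensor plus a `(seq_len, hidden)` tensor only lines up if the smaller operand is broadcast across the leading batch axis. A tiny check of that right-aligned broadcasting rule (plain Rust, illustrative only, not candle's broadcasting code):

```rust
/// Right-aligned broadcasting: trailing dims must match (or be 1); missing
/// leading dims on the smaller shape are treated as 1 and expanded.
fn broadcast_shape(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    let rank = a.len().max(b.len());
    let mut out = vec![0; rank];
    for i in 0..rank {
        let da = if i < rank - a.len() { 1 } else { a[i - (rank - a.len())] };
        let db = if i < rank - b.len() { 1 } else { b[i - (rank - b.len())] };
        out[i] = match (da, db) {
            (x, y) if x == y => x,
            (1, y) => y,
            (x, 1) => x,
            _ => return None,
        };
    }
    Some(out)
}

fn main() {
    // (bsize, seq_len, hidden) + (seq_len, hidden) -> (bsize, seq_len, hidden)
    assert_eq!(broadcast_shape(&[2, 7, 384], &[7, 384]), Some(vec![2, 7, 384]));
    // Mismatched trailing dims cannot broadcast.
    assert_eq!(broadcast_shape(&[2, 7, 384], &[6, 384]), None);
    println!("broadcast shapes check out");
}
```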
@@ -351,7 +355,7 @@ impl BertSelfAttention {
         new_x_shape.push(self.num_attention_heads);
         new_x_shape.push(self.attention_head_size);
         // Be cautious about the transposition if adding a batch dim!
-        let xs = xs.reshape(new_x_shape.as_slice())?.transpose(0, 1)?;
+        let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
         Ok(xs.contiguous()?)
     }

@@ -370,7 +374,7 @@ impl BertSelfAttention {
         let attention_probs = self.dropout.forward(&attention_probs)?;

         let context_layer = attention_probs.matmul(&value_layer)?;
-        let context_layer = context_layer.transpose(0, 1)?.contiguous()?;
+        let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
         let context_layer = context_layer.flatten(Some(context_layer.rank() - 2), None)?;
         Ok(context_layer)
     }
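Both transpose calls shift by one axis because of the new batch dimension: head splitting now goes `(bsize, seq, hidden)` -> reshape -> `(bsize, seq, heads, head_dim)` -> `transpose(1, 2)` -> `(bsize, heads, seq, head_dim)`, and after the attention matmul the inverse `transpose(1, 2)` plus a flatten of the last two dims restores `(bsize, seq, hidden)`. A shape-only sketch of that round trip (plain Rust; the sizes are illustrative, not the model's):

```rust
fn main() {
    let (bsize, seq, heads, head_dim) = (2usize, 7, 12, 32);
    let hidden = heads * head_dim;

    // Split: reshape (bsize, seq, hidden) -> (bsize, seq, heads, head_dim),
    // then transpose(1, 2) -> (bsize, heads, seq, head_dim).
    let split = (bsize, heads, seq, head_dim);

    // Attention: probs (b, h, seq, seq) matmul value (b, h, seq, head_dim)
    // keeps the per-head layout (b, h, seq, head_dim).
    let context = split;

    // Merge: transpose(1, 2) -> (bsize, seq, heads, head_dim),
    // then flatten the last two dims -> (bsize, seq, hidden).
    let merged = (context.0, context.2, context.1 * context.3);
    assert_eq!(merged, (bsize, seq, hidden));
    println!("round trip back to (bsize, seq, hidden) = {merged:?}");
}
```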
@@ -616,7 +620,7 @@ fn main() -> Result<()> {
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
-    let token_ids = Tensor::new(&tokens[..], &device)?;
+    let token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
     println!("{token_ids}");
     let token_type_ids = token_ids.zeros_like()?;
    let ys = model.forward(&token_ids, &token_type_ids)?;
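`unsqueeze(0)` is what actually introduces the batch dimension the rest of the commit prepares for: the tokenized ids go from shape `(seq_len,)` to `(1, seq_len)`, so the model sees a batch of one. A small sketch of that call (assuming the current `candle_core` crate name; the example in the diff uses `candle`, and the token values below are made up):

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Pretend these came from the tokenizer's .get_ids().to_vec().
    let tokens: Vec<u32> = vec![101, 7592, 2088, 102];

    let token_ids = Tensor::new(&tokens[..], &device)?; // shape (4,)
    let batched = token_ids.unsqueeze(0)?;              // shape (1, 4)

    println!("{:?} -> {:?}", token_ids.dims(), batched.dims());
    Ok(())
}
```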