From 9784d1ed9fbb7693d53d4e795435630a2ebcace7 Mon Sep 17 00:00:00 2001 From: laurent Date: Mon, 3 Jul 2023 18:31:55 +0100 Subject: [PATCH 1/2] Minor tweaks. --- candle-examples/examples/bert/main.rs | 3 ++- candle-kernels/src/unary.cu | 3 +++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs index 5fed800d..6b990901 100644 --- a/candle-examples/examples/bert/main.rs +++ b/candle-examples/examples/bert/main.rs @@ -594,7 +594,8 @@ fn main() -> Result<()> { Device::new_cuda(0)? }; - let tokenizer = Tokenizer::from_file(args.tokenizer_config).map_err(E::msg)?; + let mut tokenizer = Tokenizer::from_file(args.tokenizer_config).map_err(E::msg)?; + let tokenizer = tokenizer.with_padding(None).with_truncation(None); let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? }; let weights = weights.deserialize()?; diff --git a/candle-kernels/src/unary.cu b/candle-kernels/src/unary.cu index c4df7893..726db339 100644 --- a/candle-kernels/src/unary.cu +++ b/candle-kernels/src/unary.cu @@ -1,4 +1,5 @@ #include "cuda_utils.cuh" +#include #define UNARY_OP(TYPENAME, FN_NAME, FUNC) \ extern "C" __global__ void FN_NAME( \ @@ -68,6 +69,8 @@ UNARY_OP(__half, ugelu_f16, gelu_fwd(x)) UNARY_OP(__half, urelu_f16, relu_fwd(x)) #endif +UNARY_OP(uint8_t, ucopy_u8, x) +UNARY_OP(uint32_t, ucopy_u32, x) UNARY_OP(float, ucopy_f32, x) UNARY_OP(double, ucopy_f64, x) UNARY_OP(float, uneg_f32, -x) From b6d179cc1c91e05fd2bbb39c9f426505a242bad7 Mon Sep 17 00:00:00 2001 From: laurent Date: Mon, 3 Jul 2023 18:37:40 +0100 Subject: [PATCH 2/2] Allow for batch dimensions in the embedding layer. --- candle-examples/examples/bert/main.rs | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs index 6b990901..2bd1fb1d 100644 --- a/candle-examples/examples/bert/main.rs +++ b/candle-examples/examples/bert/main.rs @@ -153,20 +153,28 @@ impl Config { struct Embedding { embeddings: Tensor, + hidden_size: usize, } impl Embedding { - fn new(embeddings: Tensor) -> Self { - Self { embeddings } + fn new(embeddings: Tensor, hidden_size: usize) -> Self { + Self { + embeddings, + hidden_size, + } } - fn load(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result { - let embeddings = vb.get((size1, size2), &format!("{p}.weight"))?; - Ok(Self::new(embeddings)) + fn load(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result { + let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?; + Ok(Self::new(embeddings, hidden_size)) } fn forward(&self, indexes: &Tensor) -> Result { - let values = Tensor::embedding(indexes, &self.embeddings)?; + let mut final_dims = indexes.dims().to_vec(); + final_dims.push(self.hidden_size); + let indexes = indexes.flatten_all()?; + let values = Tensor::embedding(&indexes, &self.embeddings)?; + let values = values.reshape(final_dims)?; Ok(values) } }