mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
@ -153,20 +153,28 @@ impl Config {
|
||||
|
||||
struct Embedding {
|
||||
embeddings: Tensor,
|
||||
hidden_size: usize,
|
||||
}
|
||||
|
||||
impl Embedding {
|
||||
fn new(embeddings: Tensor) -> Self {
|
||||
Self { embeddings }
|
||||
fn new(embeddings: Tensor, hidden_size: usize) -> Self {
|
||||
Self {
|
||||
embeddings,
|
||||
hidden_size,
|
||||
}
|
||||
}
|
||||
|
||||
fn load(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
|
||||
let embeddings = vb.get((size1, size2), &format!("{p}.weight"))?;
|
||||
Ok(Self::new(embeddings))
|
||||
fn load(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
|
||||
let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
|
||||
Ok(Self::new(embeddings, hidden_size))
|
||||
}
|
||||
|
||||
fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
|
||||
let values = Tensor::embedding(indexes, &self.embeddings)?;
|
||||
let mut final_dims = indexes.dims().to_vec();
|
||||
final_dims.push(self.hidden_size);
|
||||
let indexes = indexes.flatten_all()?;
|
||||
let values = Tensor::embedding(&indexes, &self.embeddings)?;
|
||||
let values = values.reshape(final_dims)?;
|
||||
Ok(values)
|
||||
}
|
||||
}
|
||||
@ -594,7 +602,8 @@ fn main() -> Result<()> {
|
||||
Device::new_cuda(0)?
|
||||
};
|
||||
|
||||
let tokenizer = Tokenizer::from_file(args.tokenizer_config).map_err(E::msg)?;
|
||||
let mut tokenizer = Tokenizer::from_file(args.tokenizer_config).map_err(E::msg)?;
|
||||
let tokenizer = tokenizer.with_padding(None).with_truncation(None);
|
||||
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };
|
||||
let weights = weights.deserialize()?;
|
||||
|
@ -1,4 +1,5 @@
|
||||
#include "cuda_utils.cuh"
|
||||
#include<stdint.h>
|
||||
|
||||
#define UNARY_OP(TYPENAME, FN_NAME, FUNC) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
@ -68,6 +69,8 @@ UNARY_OP(__half, ugelu_f16, gelu_fwd(x))
|
||||
UNARY_OP(__half, urelu_f16, relu_fwd(x))
|
||||
#endif
|
||||
|
||||
UNARY_OP(uint8_t, ucopy_u8, x)
|
||||
UNARY_OP(uint32_t, ucopy_u32, x)
|
||||
UNARY_OP(float, ucopy_f32, x)
|
||||
UNARY_OP(double, ucopy_f64, x)
|
||||
UNARY_OP(float, uneg_f32, -x)
|
||||
|
Reference in New Issue
Block a user