mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Minor tweaks.
This commit is contained in:
@ -594,7 +594,8 @@ fn main() -> Result<()> {
|
|||||||
Device::new_cuda(0)?
|
Device::new_cuda(0)?
|
||||||
};
|
};
|
||||||
|
|
||||||
let tokenizer = Tokenizer::from_file(args.tokenizer_config).map_err(E::msg)?;
|
let mut tokenizer = Tokenizer::from_file(args.tokenizer_config).map_err(E::msg)?;
|
||||||
|
let tokenizer = tokenizer.with_padding(None).with_truncation(None);
|
||||||
|
|
||||||
let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };
|
let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };
|
||||||
let weights = weights.deserialize()?;
|
let weights = weights.deserialize()?;
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
#include "cuda_utils.cuh"
|
#include "cuda_utils.cuh"
|
||||||
|
#include<stdint.h>
|
||||||
|
|
||||||
#define UNARY_OP(TYPENAME, FN_NAME, FUNC) \
|
#define UNARY_OP(TYPENAME, FN_NAME, FUNC) \
|
||||||
extern "C" __global__ void FN_NAME( \
|
extern "C" __global__ void FN_NAME( \
|
||||||
@ -68,6 +69,8 @@ UNARY_OP(__half, ugelu_f16, gelu_fwd(x))
|
|||||||
UNARY_OP(__half, urelu_f16, relu_fwd(x))
|
UNARY_OP(__half, urelu_f16, relu_fwd(x))
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
UNARY_OP(uint8_t, ucopy_u8, x)
|
||||||
|
UNARY_OP(uint32_t, ucopy_u32, x)
|
||||||
UNARY_OP(float, ucopy_f32, x)
|
UNARY_OP(float, ucopy_f32, x)
|
||||||
UNARY_OP(double, ucopy_f64, x)
|
UNARY_OP(double, ucopy_f64, x)
|
||||||
UNARY_OP(float, uneg_f32, -x)
|
UNARY_OP(float, uneg_f32, -x)
|
||||||
|
Reference in New Issue
Block a user