Compare commits

...

4 Commits

4 changed files with 115 additions and 26 deletions

View File

@ -43,6 +43,18 @@ enum Which {
Solar10_7B,
#[value(name = "tiny-llama-1.1b-chat")]
TinyLlama1_1BChat,
#[value(name = "SmoLM2-1B")]
SmolLM2_1B,
#[value(name = "SmoLM2-1B-Instruct")]
SmolLM2_1BInstruct,
#[value(name = "SmoLM2-360M")]
SmolLM2_360M,
#[value(name = "SmoLM2-360M-Instruct")]
SmolLM2_360MInstruct,
#[value(name = "SmoLM2-135M")]
SmolLM2_135M,
#[value(name = "SmoLM2-135M-Instruct")]
SmolLM2_135MInstruct,
}
#[derive(Parser, Debug)]
@ -134,19 +146,28 @@ fn main() -> Result<()> {
};
let (llama, tokenizer_filename, mut cache, config) = {
let api = Api::new()?;
let model_id = args.model_id.unwrap_or_else(|| match args.which {
Which::V1 => "Narsil/amall-7b".to_string(),
Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
Which::V3 => "meta-llama/Meta-Llama-3-8B".to_string(),
Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct".to_string(),
Which::V31 => "meta-llama/Llama-3.1-8B".to_string(),
Which::V31Instruct => "meta-llama/Llama-3.1-8B-Instruct".to_string(),
Which::V32_1b => "meta-llama/Llama-3.2-1B".to_string(),
Which::V32_1bInstruct => "meta-llama/Llama-3.2-1B-Instruct".to_string(),
Which::V32_3b => "meta-llama/Llama-3.2-3B".to_string(),
Which::V32_3bInstruct => "meta-llama/Llama-3.2-3B-Instruct".to_string(),
Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
let model_id = args.model_id.unwrap_or_else(|| {
let str = match args.which {
Which::V1 => "Narsil/amall-7b",
Which::V2 => "meta-llama/Llama-2-7b-hf",
Which::V3 => "meta-llama/Meta-Llama-3-8B",
Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct",
Which::V31 => "meta-llama/Llama-3.1-8B",
Which::V31Instruct => "meta-llama/Llama-3.1-8B-Instruct",
Which::V32_1b => "meta-llama/Llama-3.2-1B",
Which::V32_1bInstruct => "meta-llama/Llama-3.2-1B-Instruct",
Which::V32_3b => "meta-llama/Llama-3.2-3B",
Which::V32_3bInstruct => "meta-llama/Llama-3.2-3B-Instruct",
Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0",
Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
Which::SmolLM2_135M => "HuggingFaceTB/SmolLM2-135M",
Which::SmolLM2_135MInstruct => "HuggingFaceTB/SmolLM2-135M-Instruct",
Which::SmolLM2_360M => "HuggingFaceTB/SmolLM2-360M",
Which::SmolLM2_360MInstruct => "HuggingFaceTB/SmolLM2-360M-Instruct",
Which::SmolLM2_1B => "HuggingFaceTB/SmolLM2-1.7B",
Which::SmolLM2_1BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct",
};
str.to_string()
});
println!("loading the model weights from {model_id}");
let revision = args.revision.unwrap_or("main".to_string());
@ -169,7 +190,15 @@ fn main() -> Result<()> {
| Which::Solar10_7B => {
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
}
Which::V32_1b | Which::V32_1bInstruct | Which::TinyLlama1_1BChat => {
Which::SmolLM2_360M
| Which::SmolLM2_360MInstruct
| Which::SmolLM2_135M
| Which::SmolLM2_135MInstruct
| Which::SmolLM2_1B
| Which::SmolLM2_1BInstruct
| Which::V32_1b
| Which::V32_1bInstruct
| Which::TinyLlama1_1BChat => {
vec![api.get("model.safetensors")?]
}
};

View File

@ -70,10 +70,9 @@ static __device__ __forceinline__ float warp_reduce_sum(float x) {
// LayerNorm implementation adapted from ggml, accumulation is made using f32.
// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L477
template <typename T>
__device__ void layernorm(const T * x, T * dst, const T * alpha, const T * beta, const int ncols, const float eps) {
__device__ void layernorm(const T * x, T * dst, const T * alpha, const T * beta, const int ncols, const int block_size, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
const int block_size = blockDim.x;
float2 mean_var = make_float2(0.f, 0.f);
@ -134,10 +133,9 @@ __device__ void layernorm(const T * x, T * dst, const T * alpha, const T * beta,
// RmsNorm implementation adapted from ggml, accumulation is made using f32.
// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L523
template <typename T>
__device__ void rmsnorm(const T * x, T * dst, const T * alpha, const int ncols, const float eps) {
__device__ void rmsnorm(const T * x, T * dst, const T * alpha, const int ncols, const int block_size, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
const int block_size = blockDim.x;
float tmp = 0.0f; // partial sum for thread in warp
@ -530,15 +528,15 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
#define RMSNORM_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const TYPENAME *src, TYPENAME *dst, const TYPENAME *alpha, \
const int n_cols, const float eps) { \
rmsnorm<TYPENAME>(src, dst, alpha, n_cols, eps); \
const int n_cols, const int block_size, const float eps) { \
rmsnorm<TYPENAME>(src, dst, alpha, n_cols, block_size, eps); \
} \
#define LAYERNORM_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const TYPENAME *src, TYPENAME *dst, const TYPENAME *alpha, \
const TYPENAME *beta, const int n_cols, const float eps) { \
layernorm<TYPENAME>(src, dst, alpha, beta, n_cols, eps); \
const TYPENAME *beta, const int n_cols, const int block_size, const float eps) { \
layernorm<TYPENAME>(src, dst, alpha, beta, n_cols, block_size, eps); \
} \
#define ROPE_OP(TYPENAME, FN_NAME, FN_NAME_I, FN_NAME_THD) \

View File

@ -543,15 +543,23 @@ impl candle::CustomOp2 for RmsNorm {
let dim_m1 = dims[dims.len() - 1];
let (n_rows, n_cols) = (el / dim_m1, dim_m1);
let block_size = if n_cols < 1024 { 32 } else { 1024 };
let cfg = LaunchConfig {
grid_dim: (n_rows as u32, 1, 1),
block_dim: (1024, 1, 1),
block_dim: (block_size, 1, 1),
shared_mem_bytes: 0,
};
let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (&src, &dst, &alpha, n_cols as i32, self.eps);
let params = (
&src,
&dst,
&alpha,
n_cols as i32,
block_size as i32,
self.eps,
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(dst)
@ -776,15 +784,24 @@ impl candle::CustomOp3 for LayerNorm {
let dim_m1 = dims[dims.len() - 1];
let (n_rows, n_cols) = (el / dim_m1, dim_m1);
let block_size = if n_cols < 1024 { 32 } else { 1024 };
let cfg = LaunchConfig {
grid_dim: (n_rows as u32, 1, 1),
block_dim: (1024, 1, 1),
block_dim: (block_size, 1, 1),
shared_mem_bytes: 0,
};
let func = dev.get_or_load_func(&kernel_name::<T>("layernorm"), kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (&src, &dst, &alpha, &beta, n_cols as i32, self.eps);
let params = (
&src,
&dst,
&alpha,
&beta,
n_cols as i32,
block_size as i32,
self.eps,
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(dst)

View File

@ -77,6 +77,27 @@ fn rms_norm(device: &Device) -> Result<()> {
Ok(())
}
fn rms_norml(device: &Device) -> Result<()> {
use rand::{rngs::StdRng, Rng, SeedableRng};
let (b_size, seq_len, head_dim) = (24, 70, 64);
let el_count = b_size * seq_len * head_dim;
let mut rng = StdRng::seed_from_u64(299792458);
let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
let tensor = Tensor::new(src, device)?.reshape((b_size, seq_len, head_dim))?;
let alpha = Tensor::ones(head_dim, candle::DType::F32, device)?;
let t = candle_nn::ops::rms_norm(&tensor, &alpha, 1e-5)?;
let t2 = candle_nn::ops::rms_norm_slow(&tensor, &alpha, 1e-5)?;
let diff = (t - t2)?
.abs()?
.flatten_all()?
.max(0)?
.reshape(())?
.to_vec0::<f32>()?;
assert!(diff < 1e-5);
Ok(())
}
fn layer_norm(device: &Device) -> Result<()> {
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
let tensor = Tensor::new(data, device)?;
@ -103,6 +124,28 @@ fn layer_norm(device: &Device) -> Result<()> {
Ok(())
}
fn layer_norml(device: &Device) -> Result<()> {
use rand::{rngs::StdRng, Rng, SeedableRng};
let (b_size, seq_len, head_dim) = (24, 70, 64);
let el_count = b_size * seq_len * head_dim;
let mut rng = StdRng::seed_from_u64(299792458);
let src: Vec<f32> = (0..el_count).map(|_| rng.gen::<f32>()).collect();
let tensor = Tensor::new(src, device)?.reshape((b_size, seq_len, head_dim))?;
let alpha = Tensor::ones(head_dim, candle::DType::F32, device)?;
let beta = Tensor::zeros(head_dim, candle::DType::F32, device)?;
let t = candle_nn::ops::layer_norm(&tensor, &alpha, &beta, 1e-5)?;
let t2 = candle_nn::ops::layer_norm_slow(&tensor, &alpha, &beta, 1e-5)?;
let diff = (t - t2)?
.abs()?
.flatten_all()?
.max(0)?
.reshape(())?
.to_vec0::<f32>()?;
assert!(diff < 1e-5);
Ok(())
}
#[test]
fn softmax_numerical_stability() -> Result<()> {
let dev = &Device::Cpu;
@ -211,5 +254,7 @@ test_device!(rope, rope_cpu, rope_gpu, rope_metal);
test_device!(rope_thd, rope_thd_cpu, rope_thd_gpu, rope_thd_metal);
test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal);
test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal);
test_device!(rms_norml, rms_norml_cpu, rms_norml_gpu, rms_norml_metal);
test_device!(layer_norm, ln_cpu, ln_gpu, ln_metal);
test_device!(layer_norml, lnl_cpu, lnl_gpu, lnl_metal);
test_device!(sigmoid, sigmoid_cpu, sigmoid_gpu, sigmoid_metal);