Compare commits


14 Commits

Author SHA1 Message Date
84cd5158ad Update gemm requirement from 0.17.0 to 0.18.0
Updates the requirements on [gemm](https://github.com/sarah-ek/gemm) to permit the latest version.
- [Commits](https://github.com/sarah-ek/gemm/compare/gemm@0.17.0...gemm@0.17.1)

---
updated-dependencies:
- dependency-name: gemm
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-06-01 06:19:34 +00:00
7abc3b8cd7 Bump cudarc version to 0.11.4 (#2230) 2024-06-01 08:18:35 +02:00
46012ed31f Another cudarc update. (#2229) 2024-05-30 22:27:06 +02:00
f3fade3b03 Update cudarc to 0.11.2. (#2227) 2024-05-29 18:50:52 +02:00
ea260aeffd Add Debug, Clone, Deserialize to moondream config (#2222) 2024-05-28 06:08:00 +02:00
0814dfd148 Add a metal kernel for col2im1d. (#2214)
* Add a metal kernel for col2im1d.

* Enable the col2im variant.

* Bugfix.

* Revert the quantized tweak.
2024-05-25 11:03:23 +02:00
3ceca9901a Enable the new layer-norm. (#2213)
* Enable the new layer-norm.

* Shape fixes.
2024-05-24 16:48:21 +02:00
1df2bddccf Add the layernorm specialized op. (#2212)
* Add the layernorm cuda kernels.

* Dedicated layer norm op.

* Add the slower variant.

* Plug the cuda implementation.

* Add the metal variant.

* Add a dedicated test.

* Bugfix.
2024-05-24 15:58:01 +02:00
6f0b807ffd More efficient cuda implementation for ConvTranspose1d. (#2211)
* More efficient cuda implementation for ConvTranspose1d.

* Small tweak.
2024-05-24 11:05:43 +02:00
d54e02d73d Avoid a contiguous call in the quantized phi 3 model. (#2209)
* Simplify the KvCache api.

* Avoid a contiguous call in the quantized phi3 model.
2024-05-23 21:24:55 +02:00
45e235a747 Simplify the KvCache api. (#2207) 2024-05-23 17:07:21 +02:00
31cf64147b Add a couple kv-cache helper functions. (#2206) 2024-05-23 16:21:47 +02:00
77ea479a18 Add Phi-3 Medium (#2205) 2024-05-23 13:33:17 +02:00
72e7ca529a Add some missing where-cond kernels for metal. (#2203) 2024-05-22 09:44:52 +02:00
21 changed files with 1015 additions and 124 deletions

View File

@@ -43,9 +43,9 @@ candle-onnx = { path = "./candle-onnx", version = "0.5.1" }
candle-transformers = { path = "./candle-transformers", version = "0.5.1" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
-cudarc = { version = "0.11.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
cudarc = { version = "0.11.4", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
fancy-regex = "0.13.0"
-gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
gemm = { version = "0.18.0", features = ["wasm-simd128-enable"] }
hf-hub = "0.3.0"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
hound = "3.5.1"

View File

@@ -10,7 +10,7 @@ pub use utils::{
};
const USE_IM2COL_CONV1D: bool = true;
-const USE_IM2COL_CONV1D_TR: bool = true;
const USE_COL2IM_CONV1D_TR: bool = true;
const USE_IM2COL_CONV2D: bool = true;
// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
@@ -2249,7 +2249,7 @@ impl BackendStorage for CpuStorage {
&& params.dilation == 1
&& params.padding == 0
&& params.output_padding == 0;
-if USE_IM2COL_CONV1D_TR && can_use_col2im {
if USE_COL2IM_CONV1D_TR && can_use_col2im {
let (b_size, c_in, l_in) = l.shape().dims3()?;
let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
if !kernel_l.is_contiguous() {

View File

@@ -16,7 +16,7 @@ mod error;
mod utils;
pub use device::{CudaDevice, DeviceId};
pub use error::{CudaError, WrapErr};
-pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, S};
pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, Map3, S};
pub enum SlicePtrOrNull<T> {
Ptr(CudaSlice<T>),
@@ -630,6 +630,31 @@ impl<'a> Map2 for Conv2D<'a> {
}
}
struct Col2Im1D {
stride: usize,
}
impl Map1 for Col2Im1D {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
col: &CudaSlice<T>,
dev: &CudaDevice,
l: &Layout,
) -> Result<CudaSlice<T>> {
let (b_size, l_in, c_out, k_size) = l.shape().dims4()?;
let stride = self.stride;
let l_out = (l_in - 1) * stride + k_size;
let dst_el = b_size * c_out * l_out;
let mut im = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let params = (dst_el, l_out, l_in, c_out, k_size, stride, col, &mut im);
let func = dev.get_or_load_func(&kernel_name::<T>("col2im1d"), kernels::CONV)?;
unsafe { func.launch(cfg, params) }.w()?;
Ok(im)
}
}
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
impl<'a> Map2 for ConvTranspose1D<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@@ -1366,9 +1391,55 @@ impl BackendStorage for CudaStorage {
kernel_l: &Layout,
params: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> {
const USE_COL2IM_CONV1D_TR: bool = true;
let device = self.device().clone();
-let slice =
-ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
let can_use_col2im = kernel_l.is_contiguous()
&& params.dilation == 1
&& params.padding == 0
&& params.output_padding == 0;
let slice = if USE_COL2IM_CONV1D_TR && can_use_col2im {
let (b_size, c_in, l_in) = l.shape().dims3()?;
let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
if !kernel_l.is_contiguous() {
crate::bail!(
"convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}"
)
}
if c_in != c_in2 {
crate::bail!(
"convtr1d: shape mismatch on c_in {:?} {:?}",
l.shape(),
kernel_l.shape()
)
}
let col = {
// This merges the last two dimensions of the kernel together.
let kernel_l_mm = Layout::new(
(b_size, c_in, k_size * c_out).into(),
vec![0, k_size * c_out, 1],
kernel_l.start_offset(),
);
self.matmul(
kernel,
(
b_size,
/* m */ l_in,
/* n */ c_out * k_size,
/* k */ c_in,
),
&l.transpose(1, 2)?,
&kernel_l_mm,
)?
};
let col_l = Layout::contiguous((b_size, l_in, c_out, k_size));
Col2Im1D {
stride: params.stride,
}
.map(&col.slice, &device, &col_l)?
} else {
ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?
};
Ok(Self { slice, device })
}

View File

@@ -54,6 +54,44 @@ pub trait Map2 {
}
}
pub trait Map3 {
#[allow(clippy::too_many_arguments)]
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src1: &CudaSlice<T>,
layout1: &Layout,
src2: &CudaSlice<T>,
layout2: &Layout,
src3: &CudaSlice<T>,
layout3: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>>;
#[allow(clippy::too_many_arguments)]
fn map(
&self,
s1: &S,
l1: &Layout,
s2: &S,
l2: &Layout,
s3: &S,
l3: &Layout,
d: &CudaDevice,
) -> Result<S> {
let out = match (s1, s2, s3) {
(S::U8(s1), S::U8(s2), S::U8(s3)) => S::U8(self.f(s1, l1, s2, l2, s3, l3, d)?),
(S::U32(s1), S::U32(s2), S::U32(s3)) => S::U32(self.f(s1, l1, s2, l2, s3, l3, d)?),
(S::I64(s1), S::I64(s2), S::I64(s3)) => S::I64(self.f(s1, l1, s2, l2, s3, l3, d)?),
(S::BF16(s1), S::BF16(s2), S::BF16(s3)) => S::BF16(self.f(s1, l1, s2, l2, s3, l3, d)?),
(S::F16(s1), S::F16(s2), S::F16(s3)) => S::F16(self.f(s1, l1, s2, l2, s3, l3, d)?),
(S::F32(s1), S::F32(s2), S::F32(s3)) => S::F32(self.f(s1, l1, s2, l2, s3, l3, d)?),
(S::F64(s1), S::F64(s2), S::F64(s3)) => S::F64(self.f(s1, l1, s2, l2, s3, l3, d)?),
_ => Err(CudaError::InternalError("dtype mismatch in ternary op"))?,
};
Ok(out)
}
}
pub trait Map2InPlace {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,

View File

@@ -824,44 +824,102 @@ impl BackendStorage for MetalStorage {
k_layout: &Layout,
params: &ParamsConvTranspose1D,
) -> Result<Self> {
const USE_COL2IM_CONV1D_TR: bool = true;
let can_use_col2im = k_layout.is_contiguous()
&& params.dilation == 1
&& params.padding == 0
&& params.output_padding == 0;
let l_out = params.l_out();
let dst_el = params.c_out * l_out * params.b_size;
-let buffer = self
-.device
-.new_buffer(dst_el, self.dtype, "conv_transpose1d")?;
-let command_buffer = self.device.command_buffer()?;
-let name = match self.dtype {
-DType::F32 => "conv_transpose1d_f32",
-DType::F16 => "conv_transpose1d_f16",
-DType::BF16 => "conv_transpose1d_bf16",
-DType::U32 => "conv_transpose1d_u32",
-DType::U8 => "conv_transpose1d_u8",
-dtype => crate::bail!("Metal conv_transpose1d {dtype:?} not implemented"),
let buffer = if USE_COL2IM_CONV1D_TR && can_use_col2im {
let (b_size, c_in, l_in) = layout.shape().dims3()?;
let (c_in2, c_out, k_size) = k_layout.shape().dims3()?;
if c_in != c_in2 {
crate::bail!(
"convtr1d: shape mismatch on c_in {:?} {:?}",
layout.shape(),
k_layout.shape()
)
}
let buffer = self
.device
.new_buffer(dst_el, self.dtype, "conv_transpose1d")?;
let command_buffer = self.device.command_buffer()?;
let name = match self.dtype {
DType::F32 => "col2im1d_f32",
DType::U32 => "col2im1d_u32",
DType::U8 => "col2im1d_u8",
dtype => crate::bail!("metal col2im1d {dtype:?} not implemented"),
};
let col = {
// This merges the last two dimensions of the kernel together.
let kernel_l_mm = Layout::new(
(b_size, c_in, k_size * c_out).into(),
vec![0, k_size * c_out, 1],
k_layout.start_offset(),
);
self.matmul(
k,
(b_size, l_in, c_out * k_size, c_in),
&layout.transpose(1, 2)?,
&kernel_l_mm,
)?
};
candle_metal_kernels::call_col2im1d(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
&[b_size, l_in, c_out, k_size],
params.k_size,
params.stride,
BufferOffset::zero_offset(&col.buffer),
&buffer,
)
.map_err(MetalError::from)?;
buffer
} else {
let buffer = self
.device
.new_buffer(dst_el, self.dtype, "conv_transpose1d")?;
let command_buffer = self.device.command_buffer()?;
let name = match self.dtype {
DType::F32 => "conv_transpose1d_f32",
DType::F16 => "conv_transpose1d_f16",
DType::BF16 => "conv_transpose1d_bf16",
DType::U32 => "conv_transpose1d_u32",
DType::U8 => "conv_transpose1d_u8",
dtype => crate::bail!("Metal conv_transpose1d {dtype:?} not implemented"),
};
candle_metal_kernels::call_conv_transpose1d(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
params.dilation,
params.stride,
params.padding,
params.output_padding,
params.c_out,
l_out,
params.b_size,
layout.dims(),
layout.stride(),
k_layout.dims(),
k_layout.stride(),
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),
&k.buffer,
k_layout.start_offset() * k.dtype.size_in_bytes(),
&buffer,
)
.map_err(MetalError::from)?;
buffer
};
-candle_metal_kernels::call_conv_transpose1d(
-&self.device.device,
-&command_buffer,
-&self.device.kernels,
-name,
-params.dilation,
-params.stride,
-params.padding,
-params.output_padding,
-params.c_out,
-l_out,
-params.b_size,
-layout.dims(),
-layout.stride(),
-k_layout.dims(),
-k_layout.stride(),
-&self.buffer,
-layout.start_offset() * self.dtype.size_in_bytes(),
-&k.buffer,
-k_layout.start_offset() * k.dtype.size_in_bytes(),
-&buffer,
-)
-.map_err(MetalError::from)?;
Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype))
}

View File

@@ -141,6 +141,8 @@ enum WhichModel {
V2,
#[value(name = "3")]
V3,
#[value(name = "3-medium")]
V3Medium,
#[value(name = "2-old")]
V2Old,
PuffinPhiV2,
@@ -254,6 +256,7 @@ fn main() -> Result<()> {
WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
WhichModel::V2 | WhichModel::V2Old => "microsoft/phi-2".to_string(),
WhichModel::V3 => "microsoft/Phi-3-mini-4k-instruct".to_string(),
WhichModel::V3Medium => "microsoft/Phi-3-medium-4k-instruct".to_string(),
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
"lmz/candle-quantized-phi".to_string()
}
@@ -273,6 +276,7 @@ fn main() -> Result<()> {
WhichModel::V2Old => "834565c23f9b28b96ccbeabe614dd906b6db551a".to_string(),
WhichModel::V2
| WhichModel::V3
| WhichModel::V3Medium
| WhichModel::PuffinPhiV2
| WhichModel::PhiHermes => "main".to_string(),
}
@@ -287,7 +291,8 @@ fn main() -> Result<()> {
| WhichModel::V1_5
| WhichModel::V2
| WhichModel::V2Old
-| WhichModel::V3 => repo.get("tokenizer.json")?,
| WhichModel::V3
| WhichModel::V3Medium => repo.get("tokenizer.json")?,
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
repo.get("tokenizer-puffin-phi-v2.json")?
}
@@ -303,14 +308,14 @@ fn main() -> Result<()> {
WhichModel::V2 | WhichModel::V2Old => vec![repo.get("model-v2-q4k.gguf")?],
WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?],
WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?],
-WhichModel::V3 => anyhow::bail!(
WhichModel::V3 | WhichModel::V3Medium => anyhow::bail!(
"use the quantized or quantized-phi examples for quantized phi-v3"
),
}
} else {
match args.model {
WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
-WhichModel::V2 | WhichModel::V2Old | WhichModel::V3 => {
WhichModel::V2 | WhichModel::V2Old | WhichModel::V3 | WhichModel::V3Medium => {
candle_examples::hub_load_safetensors(
&repo,
"model.safetensors.index.json",
@@ -332,7 +337,7 @@ fn main() -> Result<()> {
WhichModel::V2 | WhichModel::V2Old => Config::v2(),
WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
-WhichModel::V3 => {
WhichModel::V3 | WhichModel::V3Medium => {
panic!("use the quantized or quantized-phi examples for quantized phi-v3")
}
};
@@ -352,7 +357,9 @@ fn main() -> Result<()> {
let dtype = match args.dtype {
Some(dtype) => std::str::FromStr::from_str(&dtype)?,
None => {
-if args.model == WhichModel::V3 && device.is_cuda() {
if (args.model == WhichModel::V3 || args.model == WhichModel::V3Medium)
&& device.is_cuda()
{
DType::BF16
} else {
DType::F32
@@ -368,7 +375,7 @@ fn main() -> Result<()> {
let phi = Phi::new(&config, vb)?;
Model::Phi(phi)
}
-WhichModel::V3 => {
WhichModel::V3 | WhichModel::V3Medium => {
let config_filename = repo.get("config.json")?;
let config = std::fs::read_to_string(config_filename)?;
let config: Phi3Config = serde_json::from_str(&config)?;

View File

@@ -217,7 +217,6 @@ fn main() -> anyhow::Result<()> {
match args.which {
Which::Phi2 => Model::Phi2(Phi2::from_gguf(model, &mut file, &device)?),
Which::Phi3 => Model::Phi3(Phi3::from_gguf(
-1,
args.use_flash_attn,
model,
&mut file,

View File

@@ -97,6 +97,50 @@ __device__ void im2col1d(
}
}
template <typename T>
__device__ void col2im1d(
const size_t dst_el,
const size_t l_out,
const size_t l_in,
const size_t c_out,
const size_t k_size,
const size_t stride,
const T *src,
T *dst
) {
const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
// src: (b_size, l_in, c_out, l_k)
// dst: (b_size, c_out, l_out)
if (dst_i >= dst_el) {
return;
}
const size_t dst_s0 = c_out * l_out;
const size_t dst_s1 = l_out;
const size_t src_s0 = c_out * k_size * l_in;
const size_t src_s1 = c_out * k_size;
const size_t src_s2 = k_size;
size_t tmp_dst_i = dst_i;
const size_t b_idx = tmp_dst_i / dst_s0;
tmp_dst_i -= b_idx * dst_s0;
const size_t c_idx = tmp_dst_i / dst_s1;
tmp_dst_i -= c_idx * dst_s1;
const int l_out_idx = tmp_dst_i;
dst[dst_i] = static_cast<T>(0);
int l_in_idx = l_out_idx / stride;
int k0 = l_out_idx - l_in_idx * stride;
// l_out_idx = l_in_idx * stride + k0
for (; k0 < k_size && l_in_idx >= 0; k0 += stride, --l_in_idx) {
if (l_in_idx < l_in) {
const size_t src_i = b_idx * src_s0 + l_in_idx * src_s1 + c_idx * src_s2 + k0;
dst[dst_i] += src[src_i];
}
}
}
template <typename T>
__device__ void im2col(
const size_t dst_numel,
@@ -542,6 +586,20 @@ extern "C" __global__ void FN_NAME( \
im2col1d<TYPENAME>(dst_numel, l_out, l_k, stride, padding, dilation, info, src, dst); \
} \
#define COL2IM1D_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t dst_el, \
const size_t l_out, \
const size_t l_in, \
const size_t c_out, \
const size_t k_size, \
const size_t stride, \
const TYPENAME *src, \
TYPENAME *dst \
) { \
col2im1d<TYPENAME>(dst_el, l_out, l_in, c_out, k_size, stride, src, dst); \
} \
#define IM2COL_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t dst_numel, \
@@ -643,6 +701,7 @@ MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16)
IM2COL_OP(__nv_bfloat16, im2col_bf16)
IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16)
COL2IM1D_OP(__nv_bfloat16, col2im1d_bf16)
#endif
#if __CUDA_ARCH__ >= 530
@@ -655,6 +714,7 @@ MAX_POOL2D_OP(__half, max_pool2d_f16)
UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16)
IM2COL_OP(__half, im2col_f16)
IM2COL1D_OP(__half, im2col1d_f16)
COL2IM1D_OP(__half, col2im1d_f16)
#endif
CONV1D_OP(float, float, conv1d_f32)
@@ -701,3 +761,8 @@ IM2COL1D_OP(float, im2col1d_f32)
IM2COL1D_OP(double, im2col1d_f64)
IM2COL1D_OP(uint8_t, im2col1d_u8)
IM2COL1D_OP(uint32_t, im2col1d_u32)
COL2IM1D_OP(float, col2im1d_f32)
COL2IM1D_OP(double, col2im1d_f64)
COL2IM1D_OP(uint8_t, col2im1d_u8)
COL2IM1D_OP(uint32_t, col2im1d_u32)
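For reference, the col2im1d kernels above scatter a (b_size, l_in, c_out, k_size) buffer back into a (b_size, c_out, l_out) output with l_out = (l_in - 1) * stride + k_size. A rough CPU sketch of that accumulation in Rust; the helper name and the plain f32 slices are illustrative only, not part of this change:

// Sketch: every (l, k) source pair lands on output position l * stride + k,
// which is the same relation the CUDA/Metal kernels invert per output element.
fn col2im1d_ref(src: &[f32], b_size: usize, l_in: usize, c_out: usize, k_size: usize, stride: usize) -> Vec<f32> {
    let l_out = (l_in - 1) * stride + k_size;
    let mut dst = vec![0f32; b_size * c_out * l_out];
    for b in 0..b_size {
        for l in 0..l_in {
            for c in 0..c_out {
                for k in 0..k_size {
                    // src is row-major (b_size, l_in, c_out, k_size).
                    let src_i = ((b * l_in + l) * c_out + c) * k_size + k;
                    // dst is row-major (b_size, c_out, l_out).
                    let dst_i = (b * c_out + c) * l_out + l * stride + k;
                    dst[dst_i] += src[src_i];
                }
            }
        }
    }
    dst
}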

View File

@@ -50,6 +50,15 @@ fast_sum(const size_t src_numel, const size_t el_to_sum_per_block,
dst[dst_id] = shr[0];
}
static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
}
return a;
}
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
@@ -58,6 +67,70 @@ static __device__ __forceinline__ float warp_reduce_sum(float x) {
return x;
}
// LayerNorm implementation adapted from ggml, accumulation is made using f32.
// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L477
template <typename T>
__device__ void layernorm(const T * x, T * dst, const T * alpha, const T * beta, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
const int block_size = blockDim.x;
float2 mean_var = make_float2(0.f, 0.f);
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[row*ncols + col];
mean_var.x += xi;
mean_var.y += xi * xi;
}
// sum up partial sums
mean_var = warp_reduce_sum(mean_var);
if (block_size > WARP_SIZE) {
__shared__ float2 s_sum[32];
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = mean_var;
}
__syncthreads();
mean_var = s_sum[lane_id];
mean_var = warp_reduce_sum(mean_var);
}
const float mean = mean_var.x / ncols;
const float var = mean_var.y / ncols - mean * mean;
const float inv_std = rsqrtf(var + eps);
if (alpha == nullptr && beta == nullptr) {
for (int col = tid; col < ncols; col += block_size) {
float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
dst[row*ncols + col] = static_cast<T>(lhs);
}
}
else if (alpha == nullptr && beta != nullptr) {
for (int col = tid; col < ncols; col += block_size) {
float b = static_cast<float>(beta[col]);
float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
dst[row*ncols + col] = static_cast<T>(lhs + b);
}
}
else if (alpha != nullptr && beta == nullptr) {
for (int col = tid; col < ncols; col += block_size) {
float a = static_cast<float>(alpha[col]);
float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
dst[row*ncols + col] = static_cast<T>(lhs * a);
}
}
else {
for (int col = tid; col < ncols; col += block_size) {
float a = static_cast<float>(alpha[col]);
float b = static_cast<float>(beta[col]);
float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
dst[row*ncols + col] = static_cast<T>(lhs * a + b);
}
}
}
// RmsNorm implementation adapted from ggml, accumulation is made using f32.
// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L523
template <typename T>
@@ -461,6 +534,13 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
rmsnorm<TYPENAME>(src, dst, alpha, n_cols, eps); \
} \
#define LAYERNORM_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const TYPENAME *src, TYPENAME *dst, const TYPENAME *alpha, \
const TYPENAME *beta, const int n_cols, const float eps) { \
layernorm<TYPENAME>(src, dst, alpha, beta, n_cols, eps); \
} \
#define ROPE_OP(TYPENAME, FN_NAME, FN_NAME_I, FN_NAME_THD) \
extern "C" __global__ void FN_NAME_I( \
const TYPENAME *src, \
@@ -496,6 +576,7 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
#if __CUDA_ARCH__ >= 800
SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)
LAYERNORM_OP(__nv_bfloat16, layernorm_bf16)
ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16, rope_thd_bf16)
SUM_OP(__nv_bfloat16, sum_bf16)
FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16)
@@ -504,6 +585,7 @@ FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argm
#if __CUDA_ARCH__ >= 530
SOFTMAX_OP(__half, float, softmax_f16)
RMSNORM_OP(__half, rmsnorm_f16)
LAYERNORM_OP(__half, layernorm_f16)
ROPE_OP(__half, rope_f16, rope_i_f16, rope_thd_f16)
SUM_OP(__half, sum_f16)
FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16)
@@ -516,6 +598,8 @@ SOFTMAX_OP(float, float, softmax_f32)
SOFTMAX_OP(double, double, softmax_f64)
RMSNORM_OP(float, rmsnorm_f32)
RMSNORM_OP(double, rmsnorm_f64)
LAYERNORM_OP(float, layernorm_f32)
LAYERNORM_OP(double, layernorm_f64)
ROPE_OP(float, rope_f32, rope_i_f32, rope_thd_f32)
ROPE_OP(double, rope_f64, rope_i_f64, rope_thd_f64)

View File

@@ -68,6 +68,50 @@ METAL_FUNC void im2col(
}
}
template <typename T>
METAL_FUNC void col2im1d(
constant size_t &dst_el,
constant size_t &l_out,
constant size_t &l_in,
constant size_t &c_out,
constant size_t &k_size,
constant size_t &stride,
device const T *src,
device T *dst,
uint dst_i [[ thread_position_in_grid ]]
) {
// src: (b_size, l_in, c_out, l_k)
// dst: (b_size, c_out, l_out)
if (dst_i >= dst_el) {
return;
}
const size_t dst_s0 = c_out * l_out;
const size_t dst_s1 = l_out;
const size_t src_s0 = c_out * k_size * l_in;
const size_t src_s1 = c_out * k_size;
const size_t src_s2 = k_size;
size_t tmp_dst_i = dst_i;
const size_t b_idx = tmp_dst_i / dst_s0;
tmp_dst_i -= b_idx * dst_s0;
const size_t c_idx = tmp_dst_i / dst_s1;
tmp_dst_i -= c_idx * dst_s1;
const int l_out_idx = tmp_dst_i;
dst[dst_i] = static_cast<T>(0);
int l_in_idx = l_out_idx / stride;
int k0 = l_out_idx - l_in_idx * stride;
// l_out_idx = l_in_idx * stride + k0
for (; k0 < k_size && l_in_idx >= 0; k0 += stride, --l_in_idx) {
if (l_in_idx < l_in) {
const size_t src_i = b_idx * src_s0 + l_in_idx * src_s1 + c_idx * src_s2 + k0;
dst[dst_i] += src[src_i];
}
}
}
template <typename T>
METAL_FUNC void im2col1d(
constant size_t &dst_numel,
@@ -191,6 +235,21 @@ kernel void FN_NAME( \
im2col1d<T>(dst_numel, l_out, l_k, stride, padding, dilation, src_dims, src_strides, src, dst, tid); \
} \
#define COL2IM1D_OP(T, FN_NAME) \
kernel void FN_NAME( \
constant size_t &dst_el, \
constant size_t &l_out, \
constant size_t &l_in, \
constant size_t &c_out, \
constant size_t &k_size, \
constant size_t &stride, \
device const T *src, \
device T *dst, \
uint tid [[ thread_position_in_grid ]] \
) { \
col2im1d<T>(dst_el, l_out, l_in, c_out, k_size, stride, src, dst, tid); \
} \
#define UPSAMPLE_NEAREST2D_OP(TYPENAME, FN_NAME) \
kernel void FN_NAME( \
constant size_t &w_out, \
@@ -493,6 +552,10 @@ IM2COL_OP(uint32_t, im2col_u32)
IM2COL_OP(bfloat, im2col_bf16)
#endif
COL2IM1D_OP(float, col2im1d_f32)
COL2IM1D_OP(uint8_t, col2im1d_u8)
COL2IM1D_OP(uint32_t, col2im1d_u32)
IM2COL1D_OP(float, im2col1d_f32)
IM2COL1D_OP(uint8_t, im2col1d_u8)
IM2COL1D_OP(uint32_t, im2col1d_u32)

View File

@@ -739,6 +739,69 @@ pub fn call_rms_norm(
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.set_threadgroup_memory_length(0, (width * 4).max(16) as u64);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_layer_norm(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
kernel_name: &'static str,
length: usize,
elements_to_sum: usize,
eps: f32,
input: &Buffer,
input_offset: usize,
alpha: &Buffer,
alpha_offset: usize,
beta: &Buffer,
beta_offset: usize,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
length,
elements_to_sum,
(input, input_offset),
output,
(alpha, alpha_offset),
(beta, beta_offset),
eps
)
);
let out_length = length / elements_to_sum;
let thread_group_count = MTLSize {
width: out_length as u64,
height: 1,
depth: 1,
};
let width = std::cmp::min(
pipeline.max_total_threads_per_threadgroup(),
elements_to_sum as u64,
)
.next_power_of_two();
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.set_threadgroup_memory_length(0, (width * 8).max(32) as u64);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
@@ -1588,6 +1651,39 @@ pub fn call_im2col1d_strided(
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_col2im1d(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
shape: &[usize],
k_size: usize,
stride: usize,
input: BufferOffset,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
let l_in = shape[1];
let c_out = shape[2];
let l_out = (l_in - 1) * stride + k_size;
let dst_el = shape[0] * c_out * l_out;
let encoder = command_buffer.new_compute_command_encoder();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(dst_el, l_out, l_in, c_out, k_size, stride, &input, output)
);
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_im2col_strided(
device: &Device,

View File

@@ -353,6 +353,65 @@ METAL_FUNC void rmsnorm(
}
}
template<typename T>
METAL_FUNC void layernorm(
constant size_t & src_numel,
constant size_t & el_to_sum_per_block,
device const T * src,
device T * dst,
device const T * alpha,
device const T * beta,
constant float & eps,
uint id,
uint tid,
uint dst_id,
uint block_dim,
threadgroup float * shared_memory
) {
size_t start_idx = dst_id * el_to_sum_per_block;
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
size_t idx = start_idx + tid;
float tmp1 = 0;
float tmp2 = 0;
while (idx < stop_idx) {
tmp1 += float(src[idx]);
tmp2 += float(src[idx]) * float(src[idx]);
idx += block_dim;
}
shared_memory[tid] = tmp1;
shared_memory[tid + block_dim] = tmp2;
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint s = block_dim / 2; s > 0; s >>= 1) {
if (tid < s) {
shared_memory[tid] = shared_memory[tid] + shared_memory[tid + s];
shared_memory[block_dim + tid] = shared_memory[block_dim + tid] + shared_memory[block_dim + tid + s];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
/* wait for shared_memory[0] to be filled */
threadgroup_barrier(mem_flags::mem_threadgroup);
float mean = shared_memory[0] / float(el_to_sum_per_block);
float var = shared_memory[block_dim] / float(el_to_sum_per_block) - mean * mean;
float inv_norm = 1.0f / sqrt(var + eps);
idx = start_idx + tid;
while (idx < stop_idx) {
float val = (float(src[idx]) - mean) * inv_norm;
if (alpha != nullptr) {
val *= float(alpha[idx - start_idx]);
}
if (beta != nullptr) {
val += float(beta[idx - start_idx]);
}
dst[idx] = T(val);
idx += block_dim;
}
}
#define RMSNORM(NAME, T) \
kernel void NAME( \
constant size_t &src_numel, \
@@ -371,6 +430,25 @@ kernel void NAME( \
rmsnorm<T>(src_numel, el_to_sum_per_block, src, dst, alpha, eps, id, tid, dst_id, block_dim, shared_memory); \
} \
#define LAYERNORM(NAME, T) \
kernel void NAME( \
constant size_t &src_numel, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device T *dst, \
device const T *alpha, \
device const T *beta, \
constant float &eps, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
threadgroup float shared_memory[THREADGROUP_SIZE]; \
shared_memory[tid] = 0; \
layernorm<T>(src_numel, el_to_sum_per_block, src, dst, alpha, beta, eps, id, tid, dst_id, block_dim, shared_memory); \
} \
template<typename T>
METAL_FUNC void ropei(
constant size_t &bh,
@@ -511,6 +589,8 @@ SOFTMAX(softmax_f32, float)
SOFTMAX(softmax_f16, half)
RMSNORM(rmsnorm_f32, float)
RMSNORM(rmsnorm_f16, half)
LAYERNORM(layernorm_f32, float)
LAYERNORM(layernorm_f16, half)
ROPE(rope_f32, rope_i_f32, rope_thd_f32, float)
ROPE(rope_f16, rope_i_f16, rope_thd_f16, half)
@@ -535,5 +615,6 @@ ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
SOFTMAX(softmax_bf16, bfloat)
RMSNORM(rmsnorm_bf16, bfloat)
LAYERNORM(layernorm_bf16, bfloat)
ROPE(rope_bf16, rope_i_bf16, rope_thd_bf16, bfloat)
#endif

View File

@@ -1,5 +1,4 @@
#include <metal_stdlib>
-#
using namespace metal;
METAL_FUNC uint get_strided_index(
@@ -57,27 +56,31 @@ kernel void FN_NAME(
where_cond<T, ID>(numel, num_dims, dims, strides, strides_t, strides_f, ids, t, f, out, i); \
} \
-// WHERE_OP(float, int64_t, where_i64_f32)
-// WHERE_OP(double, int64_t, where_i64_f64)
-// WHERE_OP(uint8_t, int64_t, where_i64_u8)
-// WHERE_OP(uint32_t, int64_t, where_i64_u32)
-// WHERE_OP(int64_t, int64_t, where_i64_i64)
-//
-// WHERE_OP(float, uint32_t, where_u32_f32)
-// WHERE_OP(double, uint32_t, where_u32_f64)
-// WHERE_OP(uint8_t, uint32_t, where_u32_u8)
-// WHERE_OP(uint32_t, uint32_t, where_u32_u32)
-// WHERE_OP(int64_t, uint32_t, where_u32_i64)
WHERE_OP(half, uint32_t, where_u32_f16)
WHERE_OP(float, uint32_t, where_u32_f32)
WHERE_OP(uint8_t, uint32_t, where_u32_u8)
WHERE_OP(uint32_t, uint32_t, where_u32_u32)
-WHERE_OP(float, uint8_t, where_u8_f32)
WHERE_OP(half, uint8_t, where_u8_f16)
WHERE_OP(float, uint8_t, where_u8_f32)
WHERE_OP(uint8_t, uint8_t, where_u8_u8)
WHERE_OP(uint32_t, uint8_t, where_u8_u32)
#if __METAL_VERSION__ >= 220
WHERE_OP(int64_t, uint8_t, where_u8_i64)
WHERE_OP(int64_t, uint32_t, where_u32_i64)
WHERE_OP(half, int64_t, where_i64_f16)
WHERE_OP(float, int64_t, where_i64_f32)
WHERE_OP(uint8_t, int64_t, where_i64_u8)
WHERE_OP(uint32_t, int64_t, where_i64_u32)
WHERE_OP(int64_t, int64_t, where_i64_i64)
#if defined(__HAVE_BFLOAT__)
WHERE_OP(bfloat, int64_t, where_i64_bf16)
#endif
#endif
#if defined(__HAVE_BFLOAT__)
WHERE_OP(bfloat, uint8_t, where_u8_bf16)
WHERE_OP(bfloat, uint32_t, where_u32_bf16)
#endif

View File

@@ -1,30 +1,25 @@
-use candle::{DType, Device, Result, Shape, Tensor};
use candle::{Result, Tensor};
#[derive(Debug, Clone)]
pub struct Cache {
-all_data: Tensor,
// all_data is an option on a Tensor, this makes it possible to only create the actual tensor
// on the first call where the batch size is easily known.
// Also this makes it safe to clone a KvCache that has been reseted (as in it will not share
// its internal state with the cloned instance).
all_data: Option<Tensor>,
dim: usize,
current_seq_len: usize,
max_seq_len: usize,
}
impl Cache {
-pub fn new<S: Into<Shape>, D: candle::shape::Dim>(
-dim: D,
-shape: S,
-dtype: DType,
-dev: &Device,
-) -> Result<Self> {
-let shape = shape.into();
-let dim = dim.to_index(&shape, "kv-cache")?;
-let max_seq_len = shape.dims()[dim];
-let all_data = Tensor::zeros(shape, dtype, dev)?;
-Ok(Self {
-all_data,
pub fn new(dim: usize, max_seq_len: usize) -> Self {
Self {
all_data: None,
dim,
current_seq_len: 0,
max_seq_len,
-})
}
}
pub fn dim(&self) -> usize {
@@ -39,16 +34,34 @@ impl Cache {
self.max_seq_len
}
-pub fn all_data(&self) -> &Tensor {
pub fn all_data(&self) -> &Option<Tensor> {
&self.all_data
}
-pub fn current_data(&self) -> Result<Tensor> {
-self.all_data.narrow(self.dim, 0, self.current_seq_len)
pub fn current_data(&self) -> Result<Option<Tensor>> {
let data = match self.all_data.as_ref() {
None => None,
Some(d) => Some(d.narrow(self.dim, 0, self.current_seq_len)?),
};
Ok(data)
}
pub fn reset(&mut self) {
self.current_seq_len = 0;
self.all_data = None;
} }
pub fn append(&mut self, src: &Tensor) -> Result<()> {
let seq_len = src.dim(self.dim)?;
// This doesn't seem very idiomatic but because the creation can fail, it's tricky to use
// self.all_data.get_or_insert_with.
if self.all_data.is_none() {
let mut shape = src.dims().to_vec();
shape[self.dim] = self.max_seq_len;
let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
self.all_data = Some(ad)
};
let ad = self.all_data.as_mut().unwrap();
if self.current_seq_len + seq_len > self.max_seq_len {
candle::bail!(
"kv-cache: above max-seq-len {}+{seq_len}>{}",
@@ -56,8 +69,7 @@ impl Cache {
self.max_seq_len
)
}
-self.all_data
-.slice_set(src, self.dim, self.current_seq_len)?;
ad.slice_set(src, self.dim, self.current_seq_len)?;
self.current_seq_len += seq_len;
Ok(())
}
@@ -70,32 +82,66 @@ pub struct KvCache {
}
impl KvCache {
-pub fn new<S: Into<Shape>, D: candle::shape::Dim>(
-dim: D,
-shape: S,
-dtype: DType,
-dev: &Device,
-) -> Result<Self> {
-let shape = shape.into();
-let dim = dim.to_index(&shape, "kv-cache")?;
-let k = Cache::new(dim, &shape, dtype, dev)?;
-let v = Cache::new(dim, &shape, dtype, dev)?;
-Ok(Self { k, v })
pub fn new(dim: usize, max_seq_len: usize) -> Self {
let k = Cache::new(dim, max_seq_len);
let v = Cache::new(dim, max_seq_len);
Self { k, v }
}
-pub fn k(&self) -> Result<Tensor> {
pub fn k_cache(&self) -> &Cache {
&self.k
}
pub fn v_cache(&self) -> &Cache {
&self.v
}
pub fn k_cache_mut(&mut self) -> &mut Cache {
&mut self.k
}
pub fn v_cache_mut(&mut self) -> &mut Cache {
&mut self.v
}
pub fn k(&self) -> Result<Option<Tensor>> {
self.k.current_data()
}
-pub fn v(&self) -> Result<Tensor> {
pub fn v(&self) -> Result<Option<Tensor>> {
self.v.current_data()
}
pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
self.k.append(k)?;
self.v.append(v)?;
-let k = self.k.current_data()?;
-let v = self.v.current_data()?;
let out_k = self.k.current_data()?;
let out_v = self.v.current_data()?;
let k = match out_k {
None => {
let mut shape = k.dims().to_vec();
shape[self.k.dim] = 0;
Tensor::zeros(shape, k.dtype(), k.device())?
}
Some(k) => k,
};
let v = match out_v {
None => {
let mut shape = v.dims().to_vec();
shape[self.k.dim] = 0;
Tensor::zeros(shape, v.dtype(), v.device())?
}
Some(v) => v,
};
Ok((k, v))
}
pub fn current_seq_len(&self) -> usize {
self.k.current_seq_len()
}
pub fn reset(&mut self) {
self.k.reset();
self.v.reset();
}
}
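For context, a minimal usage sketch of the reworked cache API shown in this diff; the candle_nn::kv_cache module path, the shapes, and the demo function name are illustrative assumptions, not part of the change:

// Sketch only: exercises KvCache::new(dim, max_seq_len), append, and reset.
use candle::{DType, Device, Result, Tensor};
use candle_nn::kv_cache::KvCache; // assumed module path

fn kv_cache_demo() -> Result<()> {
    // Concatenate along dim 2 (the sequence axis), up to 4096 positions.
    let mut cache = KvCache::new(2, 4096);
    let dev = Device::Cpu;
    // (batch, kv_heads, seq_len, head_dim) -- an arbitrary illustrative shape.
    let k = Tensor::zeros((1, 8, 16, 64), DType::F32, &dev)?;
    let v = Tensor::zeros((1, 8, 16, 64), DType::F32, &dev)?;
    // The backing tensors are only allocated on the first append, so the batch
    // size and dtype no longer need to be known when the cache is created.
    let (k_all, v_all) = cache.append(&k, &v)?;
    assert_eq!(cache.current_seq_len(), 16);
    assert_eq!(k_all.dims()[2], 16);
    assert_eq!(v_all.dims()[2], 16);
    // reset() drops the backing tensors so the same cache can be reused.
    cache.reset();
    assert_eq!(cache.current_seq_len(), 0);
    Ok(())
}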

View File

@@ -11,8 +11,8 @@
//! use candle_nn::{LayerNorm, Module};
//! # fn main() -> candle::Result<()> {
//!
-//! let w = Tensor::new(1f32, &Cpu)?;
-//! let b = Tensor::new(0f32, &Cpu)?;
//! let w = Tensor::new(&[1f32, 1f32, 1f32], &Cpu)?;
//! let b = Tensor::new(&[0f32, 0f32, 0f32], &Cpu)?;
//! let layer = LayerNorm::new(w, b, 1e-5);
//!
//! let xs = Tensor::new(
@@ -107,6 +107,11 @@ impl LayerNorm {
impl Module for LayerNorm {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
if x.is_contiguous() && self.remove_mean {
if let Some(bias) = self.bias.as_ref() {
return crate::ops::layer_norm(x, &self.weight, bias, self.eps as f32);
}
}
let x_dtype = x.dtype();
let internal_dtype = match x_dtype {
DType::F16 | DType::BF16 => DType::F32,

View File

@@ -1,4 +1,4 @@
-use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor};
use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor, D};
use rayon::prelude::*;
/// Applies the softmax function to the input tensor, rescaling the element so that elements on
@@ -39,7 +39,7 @@ pub fn silu(xs: &Tensor) -> Result<Tensor> {
}
pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
-let xs = xs.chunk(2, candle::D::Minus1)?;
let xs = xs.chunk(2, D::Minus1)?;
&xs[0].silu()? * &xs[1]
}
@@ -620,15 +620,15 @@ pub fn rms_norm_slow(x: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
DType::F16 | DType::BF16 => DType::F32,
d => d,
};
-let hidden_size = x.dim(candle::D::Minus1)?;
let hidden_size = x.dim(D::Minus1)?;
let x = x.to_dtype(internal_dtype)?;
-let norm_x = (x.sqr()?.sum_keepdim(candle::D::Minus1)? / hidden_size as f64)?;
let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?;
x_normed.to_dtype(x_dtype)?.broadcast_mul(alpha)
}
pub fn rms_norm(xs: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
-let hidden_size_xs = xs.dim(candle::D::Minus1)?;
let hidden_size_xs = xs.dim(D::Minus1)?;
let hidden_size_alpha = alpha.dims1()?;
if hidden_size_xs != hidden_size_alpha {
candle::bail!(
@@ -640,6 +640,254 @@ pub fn rms_norm(xs: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
xs.apply_op2_no_bwd(alpha, &RmsNorm { eps })
}
#[derive(Debug, Clone)]
struct LayerNorm {
eps: f32,
}
impl candle::CustomOp3 for LayerNorm {
fn name(&self) -> &'static str {
"layer-norm"
}
fn cpu_fwd(
&self,
s1: &CpuStorage,
l1: &Layout,
s2: &CpuStorage,
l2: &Layout,
s3: &CpuStorage,
l3: &Layout,
) -> Result<(CpuStorage, Shape)> {
use candle::backend::BackendStorage;
let eps = self.eps;
fn inner<
T: candle::WithDType
+ num_traits::Float
+ num_traits::AsPrimitive<f32>
+ num_traits::FromPrimitive,
>(
src: &[T],
layout: &Layout,
alpha: &[T],
alpha_layout: &Layout,
beta: &[T],
beta_layout: &Layout,
eps: f32,
) -> Result<(CpuStorage, Shape)> {
let src = match layout.contiguous_offsets() {
None => candle::bail!("input has to be contiguous"),
Some((o1, o2)) => &src[o1..o2],
};
let alpha = match alpha_layout.contiguous_offsets() {
None => candle::bail!("alpha has to be contiguous"),
Some((o1, o2)) => &alpha[o1..o2],
};
let beta = match beta_layout.contiguous_offsets() {
None => candle::bail!("beta has to be contiguous"),
Some((o1, o2)) => &beta[o1..o2],
};
let el_count = layout.shape().elem_count();
let dims = layout.shape().dims();
let dim_m1 = dims[dims.len() - 1];
let mut dst = vec![T::zero(); el_count];
src.par_chunks(dim_m1)
.zip(dst.par_chunks_mut(dim_m1))
.for_each(|(src, dst)| {
let mut sum = 0f32;
let mut sum2 = 0f32;
for v in src {
let v = v.as_();
sum += v;
sum2 += v * v;
}
let mean = sum / dim_m1 as f32;
let var = sum2 / dim_m1 as f32 - mean * mean;
let inv_std = (var + eps).sqrt().recip();
for ((d, s), (alpha, beta)) in
dst.iter_mut().zip(src.iter()).zip(alpha.iter().zip(beta))
{
let alpha = alpha.as_();
let beta = beta.as_();
let d_ = (s.as_() - mean) * inv_std * alpha + beta;
*d = T::from_f32(d_).unwrap_or_else(T::nan);
}
});
let storage = candle::WithDType::to_cpu_storage_owned(dst);
Ok((storage, Shape::from_dims(dims)))
}
use CpuStorage as C;
match (s1, s2, s3) {
(C::BF16(s1), C::BF16(s2), C::BF16(s3)) => {
inner::<half::bf16>(s1, l1, s2, l2, s3, l3, eps)
}
(C::F16(s1), C::F16(s2), C::F16(s3)) => inner::<half::f16>(s1, l1, s2, l2, s3, l3, eps),
(C::F32(s1), C::F32(s2), C::F32(s3)) => inner::<f32>(s1, l1, s2, l2, s3, l3, eps),
_ => candle::bail!("unsupported dtype for rmsnorm {:?}", s1.dtype()),
}
}
#[cfg(feature = "cuda")]
fn cuda_fwd(
&self,
s1: &candle::CudaStorage,
l1: &Layout,
s2: &candle::CudaStorage,
l2: &Layout,
s3: &candle::CudaStorage,
l3: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
};
use candle::cuda_backend::{kernel_name, kernels, Map3, WrapErr};
use candle::{CudaDevice, WithDType};
struct S {
eps: f32,
}
impl Map3 for S {
fn f<T: DeviceRepr + WithDType>(
&self,
src: &CudaSlice<T>,
layout: &Layout,
alpha: &CudaSlice<T>,
alpha_layout: &Layout,
beta: &CudaSlice<T>,
beta_layout: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>> {
let src = match layout.contiguous_offsets() {
None => candle::bail!("input has to be contiguous"),
Some((o1, o2)) => src.slice(o1..o2),
};
let alpha = match alpha_layout.contiguous_offsets() {
None => candle::bail!("alpha has to be contiguous"),
Some((o1, o2)) => alpha.slice(o1..o2),
};
let beta = match beta_layout.contiguous_offsets() {
None => candle::bail!("beta has to be contiguous"),
Some((o1, o2)) => beta.slice(o1..o2),
};
let el = layout.shape().elem_count();
let dims = layout.shape().dims();
let dim_m1 = dims[dims.len() - 1];
let (n_rows, n_cols) = (el / dim_m1, dim_m1);
let cfg = LaunchConfig {
grid_dim: (n_rows as u32, 1, 1),
block_dim: (1024, 1, 1),
shared_mem_bytes: 0,
};
let func = dev.get_or_load_func(&kernel_name::<T>("layernorm"), kernels::REDUCE)?;
// SAFETY: Set later by running the kernel.
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
let params = (&src, &dst, &alpha, &beta, n_cols as i32, self.eps);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(dst)
}
}
use candle::backend::BackendStorage;
let dev = s1.device();
let slice = S { eps: self.eps }.map(&s1.slice, l1, &s2.slice, l2, &s3.slice, l3, dev)?;
let dst = candle::cuda_backend::CudaStorage {
slice,
device: dev.clone(),
};
Ok((dst, l1.shape().clone()))
}
#[cfg(feature = "metal")]
fn metal_fwd(
&self,
s1: &candle::MetalStorage,
l1: &Layout,
s2: &candle::MetalStorage,
l2: &Layout,
s3: &candle::MetalStorage,
l3: &Layout,
) -> Result<(candle::MetalStorage, Shape)> {
use candle::backend::BackendStorage;
let device = s1.device();
let command_buffer = device.command_buffer()?;
let kernels = device.kernels();
let name = match (s1.dtype(), s2.dtype(), s3.dtype()) {
(DType::F32, DType::F32, DType::F32) => "layernorm_f32",
(DType::F16, DType::F16, DType::F16) => "layernorm_f16",
(DType::BF16, DType::BF16, DType::BF16) => "layernorm_bf16",
(dt1, dt2, dt3) => {
candle::bail!("layernorm is not implemented for {dt1:?} {dt2:?} {dt3:?}")
}
};
if !(l1.is_contiguous() && l2.is_contiguous() && l3.is_contiguous()) {
candle::bail!("Non contiguous layernorm is not implemented");
}
let last_dim = l1.dims()[l1.shape().rank() - 1];
let elem_count = l1.shape().elem_count();
let output = device.new_buffer(elem_count, s1.dtype(), "layernorm")?;
candle_metal_kernels::call_layer_norm(
device.metal_device(),
&command_buffer,
kernels,
name,
elem_count,
last_dim,
self.eps,
s1.buffer(),
l1.start_offset() * s1.dtype().size_in_bytes(),
s2.buffer(),
l2.start_offset() * s2.dtype().size_in_bytes(),
s3.buffer(),
l3.start_offset() * s3.dtype().size_in_bytes(),
&output,
)
.map_err(candle::Error::wrap)?;
let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
Ok((newstorage, l1.shape().clone()))
}
}
pub fn layer_norm_slow(x: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result<Tensor> {
let x_dtype = x.dtype();
let internal_dtype = match x_dtype {
DType::F16 | DType::BF16 => DType::F32,
d => d,
};
let hidden_size = x.dim(D::Minus1)?;
let x = x.to_dtype(internal_dtype)?;
let x = {
let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
x.broadcast_sub(&mean_x)?
};
let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?;
x_normed
.to_dtype(x_dtype)?
.broadcast_mul(alpha)?
.broadcast_add(beta)
}
pub fn layer_norm(xs: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result<Tensor> {
let hidden_size_xs = xs.dim(D::Minus1)?;
let hidden_size_alpha = alpha.dims1()?;
let hidden_size_beta = beta.dims1()?;
if hidden_size_xs != hidden_size_alpha || hidden_size_xs != hidden_size_beta {
candle::bail!(
"shape mismatch in layer-norm src: {:?} alpha: {:?} beta: {:?}",
xs.shape(),
alpha.shape(),
beta.shape()
)
}
xs.apply_op3_no_bwd(alpha, beta, &LayerNorm { eps })
}
// https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html
pub fn pixel_shuffle(xs: &Tensor, upscale_factor: usize) -> Result<Tensor> {
let (b_size, c, h, w) = xs.dims4()?;
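To relate the new fused op to the slow path defined above: both compute (x - mean) / sqrt(var + eps) * alpha + beta over the last dimension. A small consistency sketch; the function name and tolerance are illustrative, not part of the diff:

// Sketch only: compares the fused layer_norm custom op with the tensor-op reference.
use candle::{Device, Result, Tensor};

fn layer_norm_check() -> Result<()> {
    let dev = Device::Cpu;
    let xs = Tensor::new(&[[1f32, 2., 3.], [4., 6., 8.]], &dev)?;
    let alpha = Tensor::new(&[1f32, 1., 1.], &dev)?;
    let beta = Tensor::new(&[0f32, 0., 0.], &dev)?;
    // Fused CustomOp3 path (CPU here; the CUDA/Metal kernels are defined above).
    let fast = candle_nn::ops::layer_norm(&xs, &alpha, &beta, 1e-5)?;
    // Plain tensor-op reference: (x - mean) / sqrt(var + eps) * alpha + beta.
    let slow = candle_nn::ops::layer_norm_slow(&xs, &alpha, &beta, 1e-5)?;
    let diff = (fast - slow)?.abs()?.sum_all()?.to_vec0::<f32>()?;
    assert!(diff < 1e-4);
    Ok(())
}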

View File

@@ -13,6 +13,12 @@ fn layer_norm() -> Result<()> {
let device = &Device::Cpu;
let w = Tensor::new(&[3f32], device)?;
let b = Tensor::new(&[0.5f32], device)?;
let ln2 = LayerNorm::new(Tensor::cat(&[&w, &w], 0)?, Tensor::cat(&[&b, &b], 0)?, 1e-8);
let ln3 = LayerNorm::new(
Tensor::cat(&[&w, &w, &w], 0)?,
Tensor::cat(&[&b, &b, &b], 0)?,
1e-8,
);
let ln = LayerNorm::new(w, b, 1e-8);
let two = Tensor::new(&[[[2f32]]], device)?;
@@ -20,11 +26,11 @@ fn layer_norm() -> Result<()> {
assert_eq!(res.to_vec1::<f32>()?, [0.5f32]);
let inp = Tensor::new(&[[[4f32, 0f32]]], device)?;
-let res = ln.forward(&inp)?;
let res = ln2.forward(&inp)?;
assert_eq!(res.to_vec3::<f32>()?, [[[3.5f32, -2.5]]]);
let inp = Tensor::new(&[[[1f32, 2., 3.], [4., 5., 6.], [9., 8., 7.]]], device)?;
-let res = ln.forward(&inp)?;
let res = ln3.forward(&inp)?;
assert_eq!(
test_utils::to_vec3_round(&res, 4)?,
[[
@@ -35,7 +41,10 @@ fn layer_norm() -> Result<()> {
);
let mean = (res.sum_keepdim(2)? / 3.0)?;
// The average value should be `b`.
-assert_eq!(mean.to_vec3::<f32>()?, [[[0.5], [0.5], [0.5]]]);
assert_eq!(
test_utils::to_vec3_round(&mean, 4)?,
[[[0.5], [0.5], [0.5]]]
);
let std = (res.broadcast_sub(&mean)?.sqr()?.sum_keepdim(2)?.sqrt()? / 3.0)?;
// The standard deviation should be sqrt(`w`).
assert_eq!(

View File

@@ -77,6 +77,32 @@ fn rms_norm(device: &Device) -> Result<()> {
Ok(())
}
fn layer_norm(device: &Device) -> Result<()> {
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
let tensor = Tensor::new(data, device)?;
let alpha = Tensor::new(&[1f32, 2f32, 3f32], device)?;
let beta = Tensor::new(&[0.5f32, 0f32, -0.2f32], device)?;
let t = candle_nn::ops::layer_norm(&tensor, &alpha, &beta, 1e-5)?;
assert_eq!(
to_vec3_round(&t, 4)?,
&[
[[0.7673, -2.6726, 3.0071], [-0.7247, 0.0, 3.4742]],
[[-0.008, -1.778, 3.991], [1.2071, -2.8284, 1.9213]]
]
);
let t2 = candle_nn::ops::layer_norm_slow(&tensor, &alpha, &beta, 1e-5)?;
assert_eq!(
to_vec3_round(&t2, 4)?,
&[
[[0.7673, -2.6726, 3.0071], [-0.7247, 0.0, 3.4742]],
[[-0.008, -1.778, 3.991], [1.2071, -2.8284, 1.9213]]
]
);
let diff = (t - t2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
assert!(diff < 1e-5);
Ok(())
}
#[test]
fn softmax_numerical_stability() -> Result<()> {
let dev = &Device::Cpu;
@@ -185,4 +211,5 @@ test_device!(rope, rope_cpu, rope_gpu, rope_metal);
test_device!(rope_thd, rope_thd_cpu, rope_thd_gpu, rope_thd_metal);
test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal);
test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal);
test_device!(layer_norm, ln_cpu, ln_gpu, ln_metal);
test_device!(sigmoid, sigmoid_cpu, sigmoid_gpu, sigmoid_metal);

View File

@@ -3,6 +3,7 @@ use crate::models::with_tracing::{layer_norm, linear_b, LayerNorm, Linear};
use candle::{IndexOp, Module, Result, Tensor, D};
use candle_nn::VarBuilder;
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
pub phi_config: PhiConfig,
pub vision_config: VisionConfig,

View File

@@ -56,24 +56,20 @@ impl RotaryEmbedding {
.to_dtype(DType::F32)?
.reshape((cfg.max_position_embeddings, 1))?;
let freqs = t.matmul(&inv_freq)?;
-let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
Ok(Self {
dim,
-sin: emb.sin()?,
-cos: emb.cos()?,
sin: freqs.sin()?,
cos: freqs.cos()?,
})
}
fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let (_b_size, _num_heads, seq_len, _headdim) = xs.dims4()?;
-let xs_rot = xs.i((.., .., .., ..self.dim))?;
let xs_rot = xs.i((.., .., .., ..self.dim))?.contiguous()?;
let xs_pass = xs.i((.., .., .., self.dim..))?;
-let xs12 = xs_rot.chunk(2, D::Minus1)?;
-let (xs1, xs2) = (&xs12[0], &xs12[1]);
let c = self.cos.narrow(0, seqlen_offset, seq_len)?;
let s = self.sin.narrow(0, seqlen_offset, seq_len)?;
-let rotate_half = Tensor::cat(&[&xs2.neg()?, xs1], D::Minus1)?;
-let xs_rot = (xs_rot.broadcast_mul(&c)? + rotate_half.broadcast_mul(&s)?)?;
let xs_rot = candle_nn::rotary_emb::rope(&xs_rot, &c, &s)?;
Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)
}
}

View File

@@ -146,7 +146,7 @@ impl LayerWeights {
};
let att = candle_nn::ops::softmax_last_dim(&att)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
-att.matmul(&v.contiguous()?)?
att.matmul(&v)?
};
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
let y = self.attn_output.forward(&y)?;
@@ -203,7 +203,6 @@ fn precomput_freqs_cis(
impl ModelWeights {
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
-batch_size: usize,
use_flash_attn: bool,
ct: gguf_file::Content,
reader: &mut R,
@@ -252,12 +251,7 @@ impl ModelWeights {
)?;
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
-let kv_cache = KvCache::new(
-2,
-(batch_size, head_count_kv, max_seq_len, head_dim),
-DType::F32,
-device,
-)?;
let kv_cache = KvCache::new(2, max_seq_len);
layers.push(LayerWeights {
attn_qkv: QLinear::new(&ct, reader, &format!("{prefix}.attn_qkv"), device)?,
attn_output: QLinear::new(&ct, reader, &format!("{prefix}.attn_output"), device)?,