mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Rework the buffer offset logic for metal kernels (#2028)
* Move the metal kernels utils into a separate module.
* Use the BufferOffset for unary ops.
* Fix clippy lints.
* Use the new BufferOffset.
* Adapt the binary ops.
* Affine.
* More ops (powf, elu, cast).
This commit is contained in:
@ -2,8 +2,7 @@ use crate::backend::{BackendDevice, BackendStorage};
|
|||||||
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
|
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose1D, ParamsConvTranspose2D};
|
||||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||||
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
||||||
use candle_metal_kernels::CallConvTranspose2dCfg;
|
use candle_metal_kernels::{BufferOffset, CallConvTranspose2dCfg, Kernels};
|
||||||
use candle_metal_kernels::Kernels;
|
|
||||||
use metal::{Buffer, MTLResourceOptions, NSUInteger};
|
use metal::{Buffer, MTLResourceOptions, NSUInteger};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::ffi::c_void;
|
use std::ffi::c_void;
|
||||||
@ -12,6 +11,12 @@ use std::sync::{Arc, Mutex, RwLock, TryLockError};
|
|||||||
mod device;
|
mod device;
|
||||||
pub use device::{DeviceId, MetalDevice};
|
pub use device::{DeviceId, MetalDevice};
|
||||||
|
|
||||||
|
/// Builds a [`BufferOffset`] that pairs `buffer` with the byte offset at which
/// the data described by layout `l` begins.
///
/// The offset is `l.start_offset()` scaled by the size in bytes of one `dtype`
/// element (presumably `start_offset` is counted in elements — confirm against
/// `Layout`).
fn buffer_o<'a>(buffer: &'a Buffer, l: &Layout, dtype: DType) -> BufferOffset<'a> {
|
||||||
|
BufferOffset {
|
||||||
|
buffer,
|
||||||
|
offset_in_bytes: l.start_offset() * dtype.size_in_bytes(),
|
||||||
|
}
|
||||||
|
}
|
||||||
/// Simple way to catch lock error without
|
/// Simple way to catch lock error without
|
||||||
/// depending on T
|
/// depending on T
|
||||||
#[derive(thiserror::Error, Debug)]
|
#[derive(thiserror::Error, Debug)]
|
||||||
@ -102,7 +107,8 @@ impl BackendStorage for MetalStorage {
|
|||||||
|
|
||||||
let buffer = device.new_buffer(el, self.dtype, "affine")?;
|
let buffer = device.new_buffer(el, self.dtype, "affine")?;
|
||||||
let command_buffer = self.device.command_buffer()?;
|
let command_buffer = self.device.command_buffer()?;
|
||||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
let src = buffer_o(&self.buffer, layout, dtype);
|
||||||
|
if layout.is_contiguous() {
|
||||||
let name = match self.dtype {
|
let name = match self.dtype {
|
||||||
DType::F32 => "affine_f32",
|
DType::F32 => "affine_f32",
|
||||||
DType::F16 => "affine_f16",
|
DType::F16 => "affine_f16",
|
||||||
@ -115,7 +121,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
&device.kernels,
|
&device.kernels,
|
||||||
name,
|
name,
|
||||||
el,
|
el,
|
||||||
&self.buffer,
|
src,
|
||||||
&buffer,
|
&buffer,
|
||||||
mul as f32,
|
mul as f32,
|
||||||
add as f32,
|
add as f32,
|
||||||
@ -134,9 +140,8 @@ impl BackendStorage for MetalStorage {
|
|||||||
&device.kernels,
|
&device.kernels,
|
||||||
name,
|
name,
|
||||||
layout.dims(),
|
layout.dims(),
|
||||||
&self.buffer,
|
src,
|
||||||
layout.stride(),
|
layout.stride(),
|
||||||
layout.start_offset() * dtype.size_in_bytes(),
|
|
||||||
&buffer,
|
&buffer,
|
||||||
mul as f32,
|
mul as f32,
|
||||||
add as f32,
|
add as f32,
|
||||||
@ -155,7 +160,8 @@ impl BackendStorage for MetalStorage {
|
|||||||
|
|
||||||
let buffer = device.new_buffer(el, self.dtype, "powf")?;
|
let buffer = device.new_buffer(el, self.dtype, "powf")?;
|
||||||
let command_buffer = self.device.command_buffer()?;
|
let command_buffer = self.device.command_buffer()?;
|
||||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
let src = buffer_o(&self.buffer, layout, dtype);
|
||||||
|
if layout.is_contiguous() {
|
||||||
let name = match self.dtype {
|
let name = match self.dtype {
|
||||||
DType::F32 => "powf_f32",
|
DType::F32 => "powf_f32",
|
||||||
DType::F16 => "powf_f16",
|
DType::F16 => "powf_f16",
|
||||||
@ -168,7 +174,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
&device.kernels,
|
&device.kernels,
|
||||||
name,
|
name,
|
||||||
el,
|
el,
|
||||||
&self.buffer,
|
src,
|
||||||
&buffer,
|
&buffer,
|
||||||
pow as f32,
|
pow as f32,
|
||||||
)
|
)
|
||||||
@ -186,9 +192,8 @@ impl BackendStorage for MetalStorage {
|
|||||||
&device.kernels,
|
&device.kernels,
|
||||||
name,
|
name,
|
||||||
layout.dims(),
|
layout.dims(),
|
||||||
&self.buffer,
|
src,
|
||||||
layout.stride(),
|
layout.stride(),
|
||||||
layout.start_offset() * dtype.size_in_bytes(),
|
|
||||||
&buffer,
|
&buffer,
|
||||||
pow as f32,
|
pow as f32,
|
||||||
)
|
)
|
||||||
@ -206,7 +211,8 @@ impl BackendStorage for MetalStorage {
|
|||||||
|
|
||||||
let buffer = device.new_buffer(el, self.dtype, "elu")?;
|
let buffer = device.new_buffer(el, self.dtype, "elu")?;
|
||||||
let command_buffer = self.device.command_buffer()?;
|
let command_buffer = self.device.command_buffer()?;
|
||||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
let src = buffer_o(&self.buffer, layout, self.dtype);
|
||||||
|
if layout.is_contiguous() {
|
||||||
let name = match self.dtype {
|
let name = match self.dtype {
|
||||||
DType::F32 => "elu_f32",
|
DType::F32 => "elu_f32",
|
||||||
DType::F16 => "elu_f16",
|
DType::F16 => "elu_f16",
|
||||||
@ -219,7 +225,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
&device.kernels,
|
&device.kernels,
|
||||||
name,
|
name,
|
||||||
el,
|
el,
|
||||||
&self.buffer,
|
src,
|
||||||
&buffer,
|
&buffer,
|
||||||
alpha as f32,
|
alpha as f32,
|
||||||
)
|
)
|
||||||
@ -237,9 +243,8 @@ impl BackendStorage for MetalStorage {
|
|||||||
&device.kernels,
|
&device.kernels,
|
||||||
name,
|
name,
|
||||||
layout.dims(),
|
layout.dims(),
|
||||||
&self.buffer,
|
src,
|
||||||
layout.stride(),
|
layout.stride(),
|
||||||
layout.start_offset() * dtype.size_in_bytes(),
|
|
||||||
&buffer,
|
&buffer,
|
||||||
alpha as f32,
|
alpha as f32,
|
||||||
)
|
)
|
||||||
@ -344,7 +349,8 @@ impl BackendStorage for MetalStorage {
|
|||||||
let el_count = shape.elem_count();
|
let el_count = shape.elem_count();
|
||||||
let buffer = device.new_buffer(el_count, dtype, "todtype")?;
|
let buffer = device.new_buffer(el_count, dtype, "todtype")?;
|
||||||
let command_buffer = device.command_buffer()?;
|
let command_buffer = device.command_buffer()?;
|
||||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
let src = buffer_o(&self.buffer, layout, self.dtype);
|
||||||
|
if layout.is_contiguous() {
|
||||||
let kernel_name = match (self.dtype, dtype) {
|
let kernel_name = match (self.dtype, dtype) {
|
||||||
(DType::U32, DType::BF16) => "cast_u32_bf16",
|
(DType::U32, DType::BF16) => "cast_u32_bf16",
|
||||||
(DType::U32, DType::F16) => "cast_u32_f16",
|
(DType::U32, DType::F16) => "cast_u32_f16",
|
||||||
@ -392,8 +398,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
&device.kernels,
|
&device.kernels,
|
||||||
kernel_name,
|
kernel_name,
|
||||||
el_count,
|
el_count,
|
||||||
&self.buffer,
|
src,
|
||||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
|
||||||
&buffer,
|
&buffer,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
@ -420,9 +425,8 @@ impl BackendStorage for MetalStorage {
|
|||||||
&device.kernels,
|
&device.kernels,
|
||||||
kernel_name,
|
kernel_name,
|
||||||
layout.dims(),
|
layout.dims(),
|
||||||
&self.buffer,
|
src,
|
||||||
layout.stride(),
|
layout.stride(),
|
||||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
|
||||||
&buffer,
|
&buffer,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
@ -439,7 +443,8 @@ impl BackendStorage for MetalStorage {
|
|||||||
let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?;
|
let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?;
|
||||||
let command_buffer = device.command_buffer()?;
|
let command_buffer = device.command_buffer()?;
|
||||||
command_buffer.set_label(B::KERNEL);
|
command_buffer.set_label(B::KERNEL);
|
||||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
let src = buffer_o(&self.buffer, layout, self.dtype);
|
||||||
|
if layout.is_contiguous() {
|
||||||
use candle_metal_kernels::unary::contiguous;
|
use candle_metal_kernels::unary::contiguous;
|
||||||
|
|
||||||
let kernel_name = match (B::KERNEL, dtype) {
|
let kernel_name = match (B::KERNEL, dtype) {
|
||||||
@ -511,7 +516,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
&device.kernels,
|
&device.kernels,
|
||||||
kernel_name,
|
kernel_name,
|
||||||
el_count,
|
el_count,
|
||||||
&self.buffer,
|
src,
|
||||||
&buffer,
|
&buffer,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
@ -556,17 +561,16 @@ impl BackendStorage for MetalStorage {
|
|||||||
crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
|
crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
let dst = BufferOffset::zero_offset(&buffer);
|
||||||
candle_metal_kernels::call_unary_strided(
|
candle_metal_kernels::call_unary_strided(
|
||||||
&device.device,
|
&device.device,
|
||||||
&command_buffer,
|
&command_buffer,
|
||||||
&device.kernels,
|
&device.kernels,
|
||||||
kernel_name,
|
kernel_name,
|
||||||
layout.dims(),
|
layout.dims(),
|
||||||
&self.buffer,
|
src,
|
||||||
layout.stride(),
|
layout.stride(),
|
||||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
dst,
|
||||||
&buffer,
|
|
||||||
0,
|
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
}
|
}
|
||||||
@ -1358,17 +1362,20 @@ impl BackendStorage for MetalStorage {
|
|||||||
DType::U8 => candle_metal_kernels::unary::strided::copy::U8,
|
DType::U8 => candle_metal_kernels::unary::strided::copy::U8,
|
||||||
dtype => crate::bail!("Metal copy_strided {dtype:?} not implemented"),
|
dtype => crate::bail!("Metal copy_strided {dtype:?} not implemented"),
|
||||||
};
|
};
|
||||||
|
let src = buffer_o(&self.buffer, src_l, self.dtype);
|
||||||
|
let dst = BufferOffset {
|
||||||
|
buffer: &dst.buffer,
|
||||||
|
offset_in_bytes: dst_offset * dst.dtype.size_in_bytes(),
|
||||||
|
};
|
||||||
candle_metal_kernels::call_unary_strided(
|
candle_metal_kernels::call_unary_strided(
|
||||||
&self.device.device,
|
&self.device.device,
|
||||||
&command_buffer,
|
&command_buffer,
|
||||||
&self.device.kernels,
|
&self.device.kernels,
|
||||||
kernel_name,
|
kernel_name,
|
||||||
src_l.dims(),
|
src_l.dims(),
|
||||||
&self.buffer,
|
src,
|
||||||
src_l.stride(),
|
src_l.stride(),
|
||||||
src_l.start_offset() * self.dtype.size_in_bytes(),
|
dst,
|
||||||
&dst.buffer,
|
|
||||||
dst_offset * dst.dtype.size_in_bytes(),
|
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
command_buffer.set_label("copy_strided");
|
command_buffer.set_label("copy_strided");
|
||||||
@ -1402,10 +1409,9 @@ impl MetalStorage {
|
|||||||
let shape = lhs_l.shape();
|
let shape = lhs_l.shape();
|
||||||
let el_count = shape.elem_count();
|
let el_count = shape.elem_count();
|
||||||
let command_buffer = device.command_buffer()?;
|
let command_buffer = device.command_buffer()?;
|
||||||
let (buffer, dtype) = if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
|
let lhs = buffer_o(&self.buffer, lhs_l, self.dtype);
|
||||||
&& (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
|
let rhs = buffer_o(&rhs.buffer, rhs_l, rhs.dtype);
|
||||||
&& &op[..1] != "b"
|
let (buffer, dtype) = if lhs_l.is_contiguous() && rhs_l.is_contiguous() && &op[..1] != "b" {
|
||||||
{
|
|
||||||
use candle_metal_kernels::binary::contiguous;
|
use candle_metal_kernels::binary::contiguous;
|
||||||
|
|
||||||
let (kernel_name, dtype) = match (op, self.dtype) {
|
let (kernel_name, dtype) = match (op, self.dtype) {
|
||||||
@ -1486,8 +1492,8 @@ impl MetalStorage {
|
|||||||
&device.kernels,
|
&device.kernels,
|
||||||
kernel_name,
|
kernel_name,
|
||||||
el_count,
|
el_count,
|
||||||
&self.buffer,
|
lhs,
|
||||||
&rhs.buffer,
|
rhs,
|
||||||
&buffer,
|
&buffer,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
@ -1585,12 +1591,10 @@ impl MetalStorage {
|
|||||||
&device.kernels,
|
&device.kernels,
|
||||||
kernel_name,
|
kernel_name,
|
||||||
lhs_l.dims(),
|
lhs_l.dims(),
|
||||||
&self.buffer,
|
lhs,
|
||||||
lhs_l.stride(),
|
lhs_l.stride(),
|
||||||
lhs_l.start_offset() * self.dtype.size_in_bytes(),
|
rhs,
|
||||||
&rhs.buffer,
|
|
||||||
rhs_l.stride(),
|
rhs_l.stride(),
|
||||||
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
|
||||||
&buffer,
|
&buffer,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
|
@ -1,11 +1,15 @@
|
|||||||
use metal::{
|
use metal::{
|
||||||
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
|
Buffer, CommandBufferRef, CompileOptions, ComputePipelineState, Device, Function,
|
||||||
Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
|
FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
|
||||||
};
|
};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::ffi::c_void;
|
use std::ffi::c_void;
|
||||||
use std::sync::RwLock;
|
use std::sync::RwLock;
|
||||||
|
|
||||||
|
mod utils;
|
||||||
|
pub use utils::BufferOffset;
|
||||||
|
use utils::{get_block_dims, linear_split};
|
||||||
|
|
||||||
const AFFINE: &str = include_str!("affine.metal");
|
const AFFINE: &str = include_str!("affine.metal");
|
||||||
const INDEXING: &str = include_str!("indexing.metal");
|
const INDEXING: &str = include_str!("indexing.metal");
|
||||||
const UNARY: &str = include_str!("unary.metal");
|
const UNARY: &str = include_str!("unary.metal");
|
||||||
@ -18,138 +22,6 @@ const RANDOM: &str = include_str!("random.metal");
|
|||||||
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
||||||
const QUANTIZED: &str = include_str!("quantized.metal");
|
const QUANTIZED: &str = include_str!("quantized.metal");
|
||||||
|
|
||||||
/// Most kernels apply similarly across the tensors
|
|
||||||
/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
|
|
||||||
/// actual total buffer length).
|
|
||||||
/// Then kernels can just do their op on their single point in the buffer.
///
/// Returns `(thread_group_count, thread_group_size)`. Both sizes are
/// one-dimensional: only `width` varies, `height` and `depth` are always 1.
|
|
||||||
fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
|
|
||||||
let size = length as u64;
|
|
||||||
// Threads per group: the pipeline's maximum, but never more than `size`.
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size);
|
|
||||||
// Ceiling division: enough groups of `width` threads to cover `size`.
// NOTE(review): when `length` is 0, `width` is 0 as well, so `size + width - 1`
// underflows and the division is by zero — confirm callers never pass 0.
let count = (size + width - 1) / width;
|
|
||||||
let thread_group_count = MTLSize {
|
|
||||||
width: count,
|
|
||||||
height: 1,
|
|
||||||
depth: 1,
|
|
||||||
};
|
|
||||||
|
|
||||||
let thread_group_size = MTLSize {
|
|
||||||
width,
|
|
||||||
height: 1,
|
|
||||||
depth: 1,
|
|
||||||
};
|
|
||||||
(thread_group_count, thread_group_size)
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96
//
// Picks a power-of-two block size for each of the three dimensions.
//
// Each dimension keeps an exponent (`pows0`, `pows1`, `pows2`), all starting
// at 0. On every pass of the loop the dimensions are visited in order and a
// dimension's exponent is bumped by one when the dimension is at least
// `2^(exponent + 1)`. The loop ends as soon as the exponents sum to 10 (so the
// product of the three returned sizes is at most 2^10 = 1024) or when a whole
// pass bumps nothing.
//
// Returns an `MTLSize` whose `width`, `height` and `depth` are
// `2^pows0`, `2^pows1` and `2^pows2`.
|
|
||||||
fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
|
|
||||||
let mut pows0 = 0u64;
|
|
||||||
let mut pows1 = 0u64;
|
|
||||||
let mut pows2 = 0u64;
|
|
||||||
// Running total of the three exponents.
let mut sum = 0u64;
|
|
||||||
loop {
|
|
||||||
// Remembered so the end of the pass can detect "nothing grew".
let presum = sum;
|
|
||||||
// Check all the pows
|
|
||||||
if dim0 >= (1 << (pows0 + 1)) {
|
|
||||||
pows0 += 1;
|
|
||||||
sum += 1;
|
|
||||||
}
|
|
||||||
if sum == 10 {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if dim1 >= (1 << (pows1 + 1)) {
|
|
||||||
pows1 += 1;
|
|
||||||
sum += 1;
|
|
||||||
}
|
|
||||||
if sum == 10 {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if dim2 >= (1 << (pows2 + 1)) {
|
|
||||||
pows2 += 1;
|
|
||||||
sum += 1;
|
|
||||||
}
|
|
||||||
// Stop when the budget is used up or no dimension could grow this pass.
if sum == presum || sum == 10 {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
MTLSize {
|
|
||||||
width: 1 << pows0,
|
|
||||||
height: 1 << pows1,
|
|
||||||
depth: 1 << pows2,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Sets `data` as the encoder argument at index `position`.
///
/// Thin free-function wrapper that forwards to the [`EncoderParam`]
/// implementation of `P`, so the parameter type is picked by inference at the
/// call site.
fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {
|
|
||||||
<P as EncoderParam>::set_param(encoder, position, data)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Helper functions to create the various objects on the compute command encoder
|
|
||||||
/// on a single line.
|
|
||||||
/// Prevents getting wrong some arguments number and mixing length and size in bytes.
///
/// Implemented for each type that can be passed as a kernel argument; every
/// implementation decides how `data` is handed to the encoder.
|
|
||||||
trait EncoderParam {
|
|
||||||
/// Sets `data` on `encoder` as the argument at index `position`.
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
|
|
||||||
}
|
|
||||||
// Implements `EncoderParam` for a plain value type.
//
// The generated `set_param` passes the value inline through
// `encoder.set_bytes`, giving it `size_of::<$type>()` as the length in bytes
// and a pointer to the local `data`.
macro_rules! primitive {
|
|
||||||
($type:ty) => {
|
|
||||||
impl EncoderParam for $type {
|
|
||||||
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
|
||||||
encoder.set_bytes(
|
|
||||||
position,
|
|
||||||
core::mem::size_of::<$type>() as u64,
|
|
||||||
&data as *const $type as *const c_void,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
// Scalar types that can be passed directly as kernel arguments.
primitive!(bool);
|
|
||||||
primitive!(usize);
|
|
||||||
primitive!(i32);
|
|
||||||
primitive!(i64);
|
|
||||||
primitive!(u32);
|
|
||||||
primitive!(u64);
|
|
||||||
primitive!(f32);
|
|
||||||
|
|
||||||
/// Slices are passed inline through `set_bytes`: the length given to the
/// encoder is the size of the whole slice in bytes (`size_of_val`), not its
/// element count.
impl<T> EncoderParam for &[T] {
|
|
||||||
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
|
||||||
encoder.set_bytes(
|
|
||||||
position,
|
|
||||||
core::mem::size_of_val(data) as u64,
|
|
||||||
data.as_ptr() as *const c_void,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A bare buffer reference is bound with `set_buffer` using 0 as the last
/// argument (the offset slot used by the `(&Buffer, usize)` implementation).
impl EncoderParam for &Buffer {
|
|
||||||
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
|
||||||
encoder.set_buffer(position, Some(data), 0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
/// A `(buffer, offset)` pair: the buffer is bound with `set_buffer` and the
/// second tuple field is forwarded as its offset argument (presumably in
/// bytes — confirm against the `metal` crate).
impl EncoderParam for (&Buffer, usize) {
|
|
||||||
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
|
||||||
encoder.set_buffer(position, Some(data.0), data.1 as u64);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
/// Same binding as the `&Buffer` implementation, for callers holding a
/// mutable reference: `set_buffer` with 0 as the last argument.
impl EncoderParam for &mut Buffer {
|
|
||||||
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
|
||||||
encoder.set_buffer(position, Some(data), 0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
/// Same binding as the `(&Buffer, usize)` implementation, for callers holding
/// a mutable reference: the second tuple field is forwarded as the last
/// argument of `set_buffer`.
impl EncoderParam for (&mut Buffer, usize) {
|
|
||||||
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
|
||||||
encoder.set_buffer(position, Some(data.0), data.1 as u64);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sets a list of kernel arguments on an encoder in one statement.
//
// Usage: `set_params!(encoder, (a, b, c));`
// Expands to one `set_param` call per listed expression. The argument index
// starts at 0 and is incremented after each call, so the arguments land at
// consecutive positions in the order they are written.
macro_rules! set_params {
|
|
||||||
($encoder:ident, ($($param:expr),+)) => (
|
|
||||||
let mut _index = 0;
|
|
||||||
$(
|
|
||||||
set_param($encoder, _index, $param);
|
|
||||||
_index += 1;
|
|
||||||
)*
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||||
pub enum Source {
|
pub enum Source {
|
||||||
Affine,
|
Affine,
|
||||||
@ -273,6 +145,12 @@ pub struct Kernels {
|
|||||||
pipelines: RwLock<Pipelines>,
|
pipelines: RwLock<Pipelines>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// `Kernels::default()` is the same as [`Kernels::new`].
impl Default for Kernels {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Kernels {
|
impl Kernels {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
let libraries = RwLock::new(Libraries::new());
|
let libraries = RwLock::new(Libraries::new());
|
||||||
@ -396,17 +274,17 @@ pub fn call_unary_contiguous(
|
|||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
kernel_name: unary::contiguous::Kernel,
|
kernel_name: unary::contiguous::Kernel,
|
||||||
length: usize,
|
length: usize,
|
||||||
input: &Buffer,
|
input: BufferOffset,
|
||||||
output: &Buffer,
|
output: &Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(encoder, (length, input, output));
|
set_params!(encoder, (length, &input, output));
|
||||||
|
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
||||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
@ -463,11 +341,9 @@ pub fn call_unary_strided(
|
|||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
name: unary::strided::Kernel,
|
name: unary::strided::Kernel,
|
||||||
shape: &[usize],
|
shape: &[usize],
|
||||||
input: &Buffer,
|
input: BufferOffset,
|
||||||
strides: &[usize],
|
strides: &[usize],
|
||||||
offset: usize,
|
output: BufferOffset,
|
||||||
output: &Buffer,
|
|
||||||
output_offset: usize,
|
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
|
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
|
||||||
|
|
||||||
@ -476,23 +352,13 @@ pub fn call_unary_strided(
|
|||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
let length: usize = shape.iter().product();
|
let length: usize = shape.iter().product();
|
||||||
set_params!(
|
set_params!(encoder, (length, num_dims, shape, strides, &input, &output));
|
||||||
encoder,
|
|
||||||
(
|
|
||||||
length,
|
|
||||||
num_dims,
|
|
||||||
shape,
|
|
||||||
strides,
|
|
||||||
(input, offset),
|
|
||||||
(output, output_offset)
|
|
||||||
)
|
|
||||||
);
|
|
||||||
|
|
||||||
let width: usize = shape.iter().product();
|
let width: usize = shape.iter().product();
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
|
||||||
|
|
||||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -505,8 +371,8 @@ pub fn call_binary_contiguous(
|
|||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
kernel_name: binary::contiguous::Kernel,
|
kernel_name: binary::contiguous::Kernel,
|
||||||
length: usize,
|
length: usize,
|
||||||
left: &Buffer,
|
left: BufferOffset,
|
||||||
right: &Buffer,
|
right: BufferOffset,
|
||||||
output: &Buffer,
|
output: &Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
|
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
|
||||||
@ -514,12 +380,12 @@ pub fn call_binary_contiguous(
|
|||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(encoder, (length, left, right, output));
|
set_params!(encoder, (length, &left, &right, output));
|
||||||
|
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
||||||
|
|
||||||
encoder.use_resource(left, metal::MTLResourceUsage::Read);
|
encoder.use_resource(left.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(right, metal::MTLResourceUsage::Read);
|
encoder.use_resource(right.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
@ -533,12 +399,10 @@ pub fn call_binary_strided(
|
|||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
name: binary::strided::Kernel,
|
name: binary::strided::Kernel,
|
||||||
shape: &[usize],
|
shape: &[usize],
|
||||||
left_input: &Buffer,
|
left_input: BufferOffset,
|
||||||
left_strides: &[usize],
|
left_strides: &[usize],
|
||||||
left_offset: usize,
|
right_input: BufferOffset,
|
||||||
right_input: &Buffer,
|
|
||||||
right_strides: &[usize],
|
right_strides: &[usize],
|
||||||
right_offset: usize,
|
|
||||||
output: &Buffer,
|
output: &Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?;
|
let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?;
|
||||||
@ -558,16 +422,16 @@ pub fn call_binary_strided(
|
|||||||
shape,
|
shape,
|
||||||
left_strides,
|
left_strides,
|
||||||
right_strides,
|
right_strides,
|
||||||
(left_input, left_offset),
|
&left_input,
|
||||||
(right_input, right_offset),
|
&right_input,
|
||||||
output
|
output
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
|
||||||
|
|
||||||
encoder.use_resource(left_input, metal::MTLResourceUsage::Read);
|
encoder.use_resource(left_input.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(right_input, metal::MTLResourceUsage::Read);
|
encoder.use_resource(right_input.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
@ -581,8 +445,7 @@ pub fn call_cast_contiguous(
|
|||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
kernel_name: &'static str,
|
kernel_name: &'static str,
|
||||||
length: usize,
|
length: usize,
|
||||||
input: &Buffer,
|
input: BufferOffset,
|
||||||
input_offset: usize,
|
|
||||||
output: &Buffer,
|
output: &Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
||||||
@ -590,10 +453,10 @@ pub fn call_cast_contiguous(
|
|||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(encoder, (length, (input, input_offset), output));
|
set_params!(encoder, (length, &input, output));
|
||||||
|
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
||||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
@ -607,9 +470,8 @@ pub fn call_cast_strided(
|
|||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
kernel_name: &'static str,
|
kernel_name: &'static str,
|
||||||
shape: &[usize],
|
shape: &[usize],
|
||||||
input: &Buffer,
|
input: BufferOffset,
|
||||||
input_strides: &[usize],
|
input_strides: &[usize],
|
||||||
input_offset: usize,
|
|
||||||
output: &Buffer,
|
output: &Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
||||||
@ -621,25 +483,19 @@ pub fn call_cast_strided(
|
|||||||
|
|
||||||
set_params!(
|
set_params!(
|
||||||
encoder,
|
encoder,
|
||||||
(
|
(length, shape.len(), shape, input_strides, &input, output)
|
||||||
length,
|
|
||||||
shape.len(),
|
|
||||||
shape,
|
|
||||||
input_strides,
|
|
||||||
(input, input_offset),
|
|
||||||
output
|
|
||||||
)
|
|
||||||
);
|
);
|
||||||
|
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
||||||
|
|
||||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_reduce_contiguous(
|
pub fn call_reduce_contiguous(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
command_buffer: &CommandBufferRef,
|
||||||
@ -687,6 +543,7 @@ pub fn call_reduce_contiguous(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_reduce_strided(
|
pub fn call_reduce_strided(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
command_buffer: &CommandBufferRef,
|
||||||
@ -985,7 +842,7 @@ pub fn call_affine(
|
|||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
name: &'static str,
|
name: &'static str,
|
||||||
size: usize,
|
size: usize,
|
||||||
input: &Buffer,
|
input: BufferOffset,
|
||||||
output: &Buffer,
|
output: &Buffer,
|
||||||
mul: f32,
|
mul: f32,
|
||||||
add: f32,
|
add: f32,
|
||||||
@ -995,10 +852,10 @@ pub fn call_affine(
|
|||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(encoder, (size, mul, add, input, output));
|
set_params!(encoder, (size, mul, add, &input, output));
|
||||||
|
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
||||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
@ -1012,9 +869,8 @@ pub fn call_affine_strided(
|
|||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
name: &'static str,
|
name: &'static str,
|
||||||
shape: &[usize],
|
shape: &[usize],
|
||||||
input: &Buffer,
|
input: BufferOffset,
|
||||||
input_stride: &[usize],
|
input_stride: &[usize],
|
||||||
input_offset: usize,
|
|
||||||
output: &Buffer,
|
output: &Buffer,
|
||||||
mul: f32,
|
mul: f32,
|
||||||
add: f32,
|
add: f32,
|
||||||
@ -1034,13 +890,13 @@ pub fn call_affine_strided(
|
|||||||
input_stride,
|
input_stride,
|
||||||
mul,
|
mul,
|
||||||
add,
|
add,
|
||||||
(input, input_offset),
|
&input,
|
||||||
output
|
output
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
||||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
@ -1054,7 +910,7 @@ pub fn call_powf(
|
|||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
name: &'static str,
|
name: &'static str,
|
||||||
size: usize,
|
size: usize,
|
||||||
input: &Buffer,
|
input: BufferOffset,
|
||||||
output: &Buffer,
|
output: &Buffer,
|
||||||
mul: f32,
|
mul: f32,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
@ -1063,10 +919,10 @@ pub fn call_powf(
|
|||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(encoder, (size, mul, input, output));
|
set_params!(encoder, (size, mul, &input, output));
|
||||||
|
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
||||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
@ -1080,9 +936,8 @@ pub fn call_powf_strided(
|
|||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
name: &'static str,
|
name: &'static str,
|
||||||
shape: &[usize],
|
shape: &[usize],
|
||||||
input: &Buffer,
|
input: BufferOffset,
|
||||||
input_stride: &[usize],
|
input_stride: &[usize],
|
||||||
input_offset: usize,
|
|
||||||
output: &Buffer,
|
output: &Buffer,
|
||||||
mul: f32,
|
mul: f32,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
@ -1094,19 +949,11 @@ pub fn call_powf_strided(
|
|||||||
|
|
||||||
set_params!(
|
set_params!(
|
||||||
encoder,
|
encoder,
|
||||||
(
|
(size, shape.len(), shape, input_stride, mul, &input, output)
|
||||||
size,
|
|
||||||
shape.len(),
|
|
||||||
shape,
|
|
||||||
input_stride,
|
|
||||||
mul,
|
|
||||||
(input, input_offset),
|
|
||||||
output
|
|
||||||
)
|
|
||||||
);
|
);
|
||||||
|
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
||||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
@ -1120,7 +967,7 @@ pub fn call_elu(
|
|||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
name: &'static str,
|
name: &'static str,
|
||||||
size: usize,
|
size: usize,
|
||||||
input: &Buffer,
|
input: BufferOffset,
|
||||||
output: &Buffer,
|
output: &Buffer,
|
||||||
mul: f32,
|
mul: f32,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
@ -1129,10 +976,10 @@ pub fn call_elu(
|
|||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
set_params!(encoder, (size, mul, input, output));
|
set_params!(encoder, (size, mul, &input, output));
|
||||||
|
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
||||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
@ -1146,9 +993,8 @@ pub fn call_elu_strided(
|
|||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
name: &'static str,
|
name: &'static str,
|
||||||
shape: &[usize],
|
shape: &[usize],
|
||||||
input: &Buffer,
|
input: BufferOffset,
|
||||||
input_stride: &[usize],
|
input_stride: &[usize],
|
||||||
input_offset: usize,
|
|
||||||
output: &Buffer,
|
output: &Buffer,
|
||||||
mul: f32,
|
mul: f32,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
@ -1160,25 +1006,18 @@ pub fn call_elu_strided(
|
|||||||
|
|
||||||
set_params!(
|
set_params!(
|
||||||
encoder,
|
encoder,
|
||||||
(
|
(size, shape.len(), shape, input_stride, mul, &input, output)
|
||||||
size,
|
|
||||||
shape.len(),
|
|
||||||
shape,
|
|
||||||
input_stride,
|
|
||||||
mul,
|
|
||||||
(input, input_offset),
|
|
||||||
output
|
|
||||||
)
|
|
||||||
);
|
);
|
||||||
|
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
|
||||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_where_cond_strided(
|
pub fn call_where_cond_strided(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
command_buffer: &CommandBufferRef,
|
||||||
@ -1334,6 +1173,7 @@ pub fn call_gather(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_scatter_add(
|
pub fn call_scatter_add(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
command_buffer: &CommandBufferRef,
|
||||||
@ -1384,6 +1224,7 @@ pub fn call_scatter_add(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_index_add(
|
pub fn call_index_add(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
command_buffer: &CommandBufferRef,
|
||||||
@ -1910,6 +1751,7 @@ pub enum GgmlDType {
|
|||||||
F32,
|
F32,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_quantized_matmul_t(
|
pub fn call_quantized_matmul_t(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
command_buffer: &CommandBufferRef,
|
||||||
@ -1925,16 +1767,16 @@ pub fn call_quantized_matmul_t(
|
|||||||
let ne00 = k as i64;
|
let ne00 = k as i64;
|
||||||
let ne01 = n as i64;
|
let ne01 = n as i64;
|
||||||
let ne02 = b as i64;
|
let ne02 = b as i64;
|
||||||
let ne03 = 1 as i64;
|
let ne03 = 1i64;
|
||||||
|
|
||||||
let nb00 = 0i64;
|
let nb00 = 0i64;
|
||||||
let nb01 = 0 as i64;
|
let nb01 = 0i64;
|
||||||
let nb02 = 0 as i64;
|
let nb02 = 0i64;
|
||||||
|
|
||||||
let ne10 = k as i64;
|
let ne10 = k as i64;
|
||||||
let ne11 = m as i64;
|
let ne11 = m as i64;
|
||||||
let ne12 = b as i64;
|
let ne12 = b as i64;
|
||||||
let ne13 = 1 as i64;
|
let ne13 = 1i64;
|
||||||
|
|
||||||
let nb10 = 0i64;
|
let nb10 = 0i64;
|
||||||
let nb11 = 0i64;
|
let nb11 = 0i64;
|
||||||
@ -2169,6 +2011,7 @@ pub struct CallConvTranspose2dCfg<'a> {
|
|||||||
pub kernel_offset: usize,
|
pub kernel_offset: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_conv_transpose2d(
|
pub fn call_conv_transpose2d(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
command_buffer: &CommandBufferRef,
|
||||||
|
@ -12,7 +12,7 @@ fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
|
|||||||
fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
|
fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
|
||||||
let options = MTLResourceOptions::StorageModeManaged;
|
let options = MTLResourceOptions::StorageModeManaged;
|
||||||
let ptr = data.as_ptr() as *const c_void;
|
let ptr = data.as_ptr() as *const c_void;
|
||||||
let size = (data.len() * std::mem::size_of::<T>()) as u64;
|
let size = std::mem::size_of_val(data) as u64;
|
||||||
device.new_buffer_with_data(ptr, size, options)
|
device.new_buffer_with_data(ptr, size, options)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -41,6 +41,10 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
|
|||||||
let command_queue = device.new_command_queue();
|
let command_queue = device.new_command_queue();
|
||||||
let command_buffer = command_queue.new_command_buffer();
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
let input = new_buffer(&device, v);
|
let input = new_buffer(&device, v);
|
||||||
|
let input = BufferOffset {
|
||||||
|
buffer: &input,
|
||||||
|
offset_in_bytes: 0,
|
||||||
|
};
|
||||||
let output = new_buffer(&device, v);
|
let output = new_buffer(&device, v);
|
||||||
call_unary_contiguous(
|
call_unary_contiguous(
|
||||||
&device,
|
&device,
|
||||||
@ -48,7 +52,7 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
|
|||||||
&kernels,
|
&kernels,
|
||||||
name,
|
name,
|
||||||
v.len(),
|
v.len(),
|
||||||
&input,
|
input,
|
||||||
&output,
|
&output,
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@ -72,8 +76,8 @@ fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> V
|
|||||||
&kernels,
|
&kernels,
|
||||||
name,
|
name,
|
||||||
x.len(),
|
x.len(),
|
||||||
&left,
|
BufferOffset::zero_offset(&left),
|
||||||
&right,
|
BufferOffset::zero_offset(&right),
|
||||||
&output,
|
&output,
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@ -93,7 +97,15 @@ fn run_strided<T: Clone>(
|
|||||||
let command_queue = device.new_command_queue();
|
let command_queue = device.new_command_queue();
|
||||||
let command_buffer = command_queue.new_command_buffer();
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
let input = new_buffer(&device, v);
|
let input = new_buffer(&device, v);
|
||||||
let output = new_buffer(&device, v);
|
let input = BufferOffset {
|
||||||
|
buffer: &input,
|
||||||
|
offset_in_bytes: offset,
|
||||||
|
};
|
||||||
|
let output_b = new_buffer(&device, v);
|
||||||
|
let output = BufferOffset {
|
||||||
|
buffer: &output_b,
|
||||||
|
offset_in_bytes: 0,
|
||||||
|
};
|
||||||
let kernels = Kernels::new();
|
let kernels = Kernels::new();
|
||||||
call_unary_strided(
|
call_unary_strided(
|
||||||
&device,
|
&device,
|
||||||
@ -101,16 +113,14 @@ fn run_strided<T: Clone>(
|
|||||||
&kernels,
|
&kernels,
|
||||||
kernel,
|
kernel,
|
||||||
shape,
|
shape,
|
||||||
&input,
|
input,
|
||||||
strides,
|
strides,
|
||||||
offset,
|
output,
|
||||||
&output,
|
|
||||||
0,
|
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
command_buffer.commit();
|
command_buffer.commit();
|
||||||
command_buffer.wait_until_completed();
|
command_buffer.wait_until_completed();
|
||||||
read_to_vec(&output, v.len())
|
read_to_vec(&output_b, v.len())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -308,8 +318,7 @@ fn run_cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
|
|||||||
&kernels,
|
&kernels,
|
||||||
name,
|
name,
|
||||||
v.len(),
|
v.len(),
|
||||||
&input,
|
BufferOffset::zero_offset(&input),
|
||||||
0,
|
|
||||||
&output,
|
&output,
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@ -521,7 +530,7 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
|
|||||||
&kernels,
|
&kernels,
|
||||||
"affine_f32",
|
"affine_f32",
|
||||||
size,
|
size,
|
||||||
&input,
|
BufferOffset::zero_offset(&input),
|
||||||
&output,
|
&output,
|
||||||
mul as f32,
|
mul as f32,
|
||||||
add as f32,
|
add as f32,
|
||||||
@ -554,9 +563,8 @@ fn run_affine_strided<T: Clone>(
|
|||||||
&kernels,
|
&kernels,
|
||||||
"affine_f32_strided",
|
"affine_f32_strided",
|
||||||
shape,
|
shape,
|
||||||
&input,
|
BufferOffset::zero_offset(&input),
|
||||||
strides,
|
strides,
|
||||||
0,
|
|
||||||
&output,
|
&output,
|
||||||
mul as f32,
|
mul as f32,
|
||||||
add as f32,
|
add as f32,
|
||||||
@ -633,7 +641,7 @@ fn index_select_strided() {
|
|||||||
fn index_select_f16() {
|
fn index_select_f16() {
|
||||||
let embedding: Vec<_> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
|
let embedding: Vec<_> = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|x| f16::from_f32(x))
|
.map(f16::from_f32)
|
||||||
.collect();
|
.collect();
|
||||||
let shape = [5, 2];
|
let shape = [5, 2];
|
||||||
let stride = [2, 1];
|
let stride = [2, 1];
|
||||||
@ -700,8 +708,8 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
|
|||||||
|
|
||||||
let command_queue = device.new_command_queue();
|
let command_queue = device.new_command_queue();
|
||||||
let command_buffer = command_queue.new_command_buffer();
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
let embeddings_buffer = new_buffer(&device, &embeddings);
|
let embeddings_buffer = new_buffer(&device, embeddings);
|
||||||
let ids_buffer = new_buffer(&device, &ids);
|
let ids_buffer = new_buffer(&device, ids);
|
||||||
|
|
||||||
let left_size: usize = shape[..dim].iter().product();
|
let left_size: usize = shape[..dim].iter().product();
|
||||||
let right_size: usize = shape[dim + 1..].iter().product();
|
let right_size: usize = shape[dim + 1..].iter().product();
|
||||||
@ -711,7 +719,7 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
|
|||||||
let kernels = Kernels::new();
|
let kernels = Kernels::new();
|
||||||
call_index_select(
|
call_index_select(
|
||||||
&device,
|
&device,
|
||||||
&command_buffer,
|
command_buffer,
|
||||||
&kernels,
|
&kernels,
|
||||||
name,
|
name,
|
||||||
shape,
|
shape,
|
||||||
@ -746,8 +754,8 @@ fn run_index_select_strided<T: Clone, I: Clone + std::fmt::Debug>(
|
|||||||
|
|
||||||
let command_queue = device.new_command_queue();
|
let command_queue = device.new_command_queue();
|
||||||
let command_buffer = command_queue.new_command_buffer();
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
let embeddings_buffer = new_buffer(&device, &embeddings);
|
let embeddings_buffer = new_buffer(&device, embeddings);
|
||||||
let ids_buffer = new_buffer(&device, &ids);
|
let ids_buffer = new_buffer(&device, ids);
|
||||||
|
|
||||||
let left_size: usize = shape[..dim].iter().product();
|
let left_size: usize = shape[..dim].iter().product();
|
||||||
let right_size: usize = shape[dim + 1..].iter().product();
|
let right_size: usize = shape[dim + 1..].iter().product();
|
||||||
@ -757,7 +765,7 @@ fn run_index_select_strided<T: Clone, I: Clone + std::fmt::Debug>(
|
|||||||
let kernels = Kernels::new();
|
let kernels = Kernels::new();
|
||||||
call_index_select(
|
call_index_select(
|
||||||
&device,
|
&device,
|
||||||
&command_buffer,
|
command_buffer,
|
||||||
&kernels,
|
&kernels,
|
||||||
name,
|
name,
|
||||||
shape,
|
shape,
|
||||||
@ -931,6 +939,7 @@ fn softmax() {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
fn run_where_cond<I: Clone, T: Clone>(
|
fn run_where_cond<I: Clone, T: Clone>(
|
||||||
shape: &[usize],
|
shape: &[usize],
|
||||||
cond: &[I],
|
cond: &[I],
|
||||||
@ -1148,7 +1157,7 @@ fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b:
|
|||||||
#[test]
|
#[test]
|
||||||
fn random() {
|
fn random() {
|
||||||
fn calc_mean(data: &[f32]) -> f32 {
|
fn calc_mean(data: &[f32]) -> f32 {
|
||||||
let sum = data.iter().sum::<f32>() as f32;
|
let sum = data.iter().sum::<f32>();
|
||||||
let count = data.len();
|
let count = data.len();
|
||||||
assert!(count > 0);
|
assert!(count > 0);
|
||||||
sum / count as f32
|
sum / count as f32
|
||||||
@ -1162,7 +1171,7 @@ fn random() {
|
|||||||
let variance = data
|
let variance = data
|
||||||
.iter()
|
.iter()
|
||||||
.map(|value| {
|
.map(|value| {
|
||||||
let diff = mean - (*value as f32);
|
let diff = mean - *value;
|
||||||
diff * diff
|
diff * diff
|
||||||
})
|
})
|
||||||
.sum::<f32>()
|
.sum::<f32>()
|
||||||
@ -1787,6 +1796,7 @@ fn avg_pool2d_u32() {
|
|||||||
assert_eq!(results, expected);
|
assert_eq!(results, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
fn run_conv_transpose1d<T: Clone>(
|
fn run_conv_transpose1d<T: Clone>(
|
||||||
input: &[T],
|
input: &[T],
|
||||||
input_shape: &[usize],
|
input_shape: &[usize],
|
||||||
|
162
candle-metal-kernels/src/utils.rs
Normal file
162
candle-metal-kernels/src/utils.rs
Normal file
@ -0,0 +1,162 @@
|
|||||||
|
use metal::{Buffer, ComputeCommandEncoderRef, ComputePipelineState, MTLSize};
|
||||||
|
use std::ffi::c_void;
|
||||||
|
|
||||||
|
/// Most kernels apply similarly across the tensors
|
||||||
|
/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
|
||||||
|
/// actual total buffer length).
|
||||||
|
/// Then kernels can just do their op on their single point in the buffer.
|
||||||
|
pub(crate) fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (MTLSize, MTLSize) {
|
||||||
|
let size = length as u64;
|
||||||
|
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size);
|
||||||
|
let count = (size + width - 1) / width;
|
||||||
|
let thread_group_count = MTLSize {
|
||||||
|
width: count,
|
||||||
|
height: 1,
|
||||||
|
depth: 1,
|
||||||
|
};
|
||||||
|
|
||||||
|
let thread_group_size = MTLSize {
|
||||||
|
width,
|
||||||
|
height: 1,
|
||||||
|
depth: 1,
|
||||||
|
};
|
||||||
|
(thread_group_count, thread_group_size)
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96
|
||||||
|
pub(crate) fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
|
||||||
|
let mut pows0 = 0u64;
|
||||||
|
let mut pows1 = 0u64;
|
||||||
|
let mut pows2 = 0u64;
|
||||||
|
let mut sum = 0u64;
|
||||||
|
loop {
|
||||||
|
let presum = sum;
|
||||||
|
// Check all the pows
|
||||||
|
if dim0 >= (1 << (pows0 + 1)) {
|
||||||
|
pows0 += 1;
|
||||||
|
sum += 1;
|
||||||
|
}
|
||||||
|
if sum == 10 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if dim1 >= (1 << (pows1 + 1)) {
|
||||||
|
pows1 += 1;
|
||||||
|
sum += 1;
|
||||||
|
}
|
||||||
|
if sum == 10 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if dim2 >= (1 << (pows2 + 1)) {
|
||||||
|
pows2 += 1;
|
||||||
|
sum += 1;
|
||||||
|
}
|
||||||
|
if sum == presum || sum == 10 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MTLSize {
|
||||||
|
width: 1 << pows0,
|
||||||
|
height: 1 << pows1,
|
||||||
|
depth: 1 << pows2,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn set_param<P: EncoderParam>(
|
||||||
|
encoder: &ComputeCommandEncoderRef,
|
||||||
|
position: u64,
|
||||||
|
data: P,
|
||||||
|
) {
|
||||||
|
<P as EncoderParam>::set_param(encoder, position, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper functions to create the various objects on the compute command encoder
|
||||||
|
/// on a single line.
|
||||||
|
/// Prevents getting wrong some arguments number and mixing length and size in bytes.
|
||||||
|
pub(crate) trait EncoderParam {
|
||||||
|
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
|
||||||
|
}
|
||||||
|
macro_rules! primitive {
|
||||||
|
($type:ty) => {
|
||||||
|
impl EncoderParam for $type {
|
||||||
|
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
||||||
|
encoder.set_bytes(
|
||||||
|
position,
|
||||||
|
core::mem::size_of::<$type>() as u64,
|
||||||
|
&data as *const $type as *const c_void,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
primitive!(bool);
|
||||||
|
primitive!(usize);
|
||||||
|
primitive!(i32);
|
||||||
|
primitive!(i64);
|
||||||
|
primitive!(u32);
|
||||||
|
primitive!(u64);
|
||||||
|
primitive!(f32);
|
||||||
|
|
||||||
|
pub struct BufferOffset<'a> {
|
||||||
|
pub buffer: &'a Buffer,
|
||||||
|
pub offset_in_bytes: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> BufferOffset<'a> {
|
||||||
|
pub fn zero_offset(buffer: &'a Buffer) -> Self {
|
||||||
|
Self {
|
||||||
|
buffer,
|
||||||
|
offset_in_bytes: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> EncoderParam for &[T] {
|
||||||
|
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
||||||
|
encoder.set_bytes(
|
||||||
|
position,
|
||||||
|
core::mem::size_of_val(data) as u64,
|
||||||
|
data.as_ptr() as *const c_void,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EncoderParam for &Buffer {
|
||||||
|
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
||||||
|
encoder.set_buffer(position, Some(data), 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EncoderParam for (&Buffer, usize) {
|
||||||
|
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
||||||
|
encoder.set_buffer(position, Some(data.0), data.1 as u64);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> EncoderParam for &BufferOffset<'a> {
|
||||||
|
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
||||||
|
encoder.set_buffer(position, Some(data.buffer), data.offset_in_bytes as u64);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EncoderParam for &mut Buffer {
|
||||||
|
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
||||||
|
encoder.set_buffer(position, Some(data), 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EncoderParam for (&mut Buffer, usize) {
|
||||||
|
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
|
||||||
|
encoder.set_buffer(position, Some(data.0), data.1 as u64);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Binds every expression in the parenthesised list to consecutive encoder argument
/// positions, starting at 0, by calling `set_param` once per expression.
///
/// The expansion is a sequence of statements, so invoke it in statement position.
#[macro_export]
macro_rules! set_params {
    ($encoder:ident, ($($param:expr),+)) => (
        // Underscore-prefixed so the final, unused increment does not trigger a warning.
        let mut _position = 0;
        $(
            $crate::utils::set_param($encoder, _position, $param);
            _position += 1;
        )+
    );
}
|
Reference in New Issue
Block a user