mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 12:20:46 +00:00
Stash for debugging
This commit is contained in:
@ -795,14 +795,16 @@ impl BackendStorage for MetalStorage {
|
|||||||
rhs_l: &Layout,
|
rhs_l: &Layout,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
// Create descriptors
|
// Create descriptors
|
||||||
let (type_id, size) = match self.dtype {
|
let (type_id, size, name) = match self.dtype {
|
||||||
DType::F32 => (
|
DType::F32 => (
|
||||||
metal::mps::MPS_FLOATBIT_ENCODING | 32,
|
metal::mps::MPS_FLOATBIT_ENCODING | 32,
|
||||||
core::mem::size_of::<f32>() as NSUInteger,
|
core::mem::size_of::<f32>() as NSUInteger,
|
||||||
|
"sgemm",
|
||||||
),
|
),
|
||||||
DType::F16 => (
|
DType::F16 => (
|
||||||
metal::mps::MPS_FLOATBIT_ENCODING | 16,
|
metal::mps::MPS_FLOATBIT_ENCODING | 16,
|
||||||
core::mem::size_of::<f16>() as NSUInteger,
|
core::mem::size_of::<f16>() as NSUInteger,
|
||||||
|
"hgemm",
|
||||||
),
|
),
|
||||||
dtype => todo!("Dtype for matmul {dtype:?} is not supported"),
|
dtype => todo!("Dtype for matmul {dtype:?} is not supported"),
|
||||||
};
|
};
|
||||||
@ -836,60 +838,37 @@ impl BackendStorage for MetalStorage {
|
|||||||
mnk: (m, n, k),
|
mnk: (m, n, k),
|
||||||
})?
|
})?
|
||||||
};
|
};
|
||||||
let b = b as NSUInteger;
|
|
||||||
let m = m as NSUInteger;
|
|
||||||
let n = n as NSUInteger;
|
|
||||||
let k = k as NSUInteger;
|
|
||||||
|
|
||||||
let left_matrix = self.matrix(
|
let result_buffer = self.device.new_buffer(b * m * n, self.dtype);
|
||||||
(b, m, k),
|
|
||||||
transpose_left,
|
|
||||||
size,
|
|
||||||
lhs_l.start_offset() as NSUInteger * size,
|
|
||||||
type_id,
|
|
||||||
)?;
|
|
||||||
let right_matrix = rhs.matrix(
|
|
||||||
(b, k, n),
|
|
||||||
transpose_right,
|
|
||||||
size,
|
|
||||||
rhs_l.start_offset() as NSUInteger * size,
|
|
||||||
type_id,
|
|
||||||
)?;
|
|
||||||
let (result_matrix, out_buffer) =
|
|
||||||
self.device
|
|
||||||
.new_matrix((b, m, n), size, type_id, self.dtype)?;
|
|
||||||
|
|
||||||
let command_buffer = self.device.command_buffer();
|
let command_buffer = self.device.command_buffer();
|
||||||
|
|
||||||
let alpha = 1.0f64;
|
command_buffer.set_label("mfa gemm");
|
||||||
let beta = 0.0f64;
|
|
||||||
// Create kernel
|
candle_metal_kernels::call_mfa_gemm(
|
||||||
let matrix_multiplication = MatrixMultiplication::init(
|
&self.device.device,
|
||||||
&self.device,
|
&command_buffer,
|
||||||
|
&self.device.kernels,
|
||||||
|
name,
|
||||||
|
&self.buffer,
|
||||||
|
lhs_l.shape().dims(),
|
||||||
|
&rhs.buffer,
|
||||||
|
rhs_l.shape().dims(),
|
||||||
|
&result_buffer,
|
||||||
|
(b, m, n, k),
|
||||||
transpose_left,
|
transpose_left,
|
||||||
transpose_right,
|
transpose_right,
|
||||||
m,
|
|
||||||
n,
|
|
||||||
k,
|
|
||||||
alpha,
|
|
||||||
beta,
|
|
||||||
)
|
)
|
||||||
.ok_or_else(|| {
|
.map_err(MetalError::from)?;
|
||||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
|
||||||
})?;
|
|
||||||
|
|
||||||
// Encode kernel to command buffer
|
|
||||||
matrix_multiplication.encode_to_command_buffer(
|
|
||||||
&command_buffer,
|
|
||||||
&left_matrix,
|
|
||||||
&right_matrix,
|
|
||||||
&result_matrix,
|
|
||||||
);
|
|
||||||
command_buffer.set_label("matmul");
|
|
||||||
drop(command_buffer);
|
drop(command_buffer);
|
||||||
self.device.commit();
|
self.device.commit();
|
||||||
|
|
||||||
Ok(Self::new(out_buffer, self.device.clone(), self.dtype()))
|
Ok(Self::new(
|
||||||
|
self.buffer.clone(),
|
||||||
|
self.device.clone(),
|
||||||
|
self.dtype(),
|
||||||
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
|
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
|
||||||
|
@ -1,7 +1,12 @@
|
|||||||
use metal::{Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState, Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger};
|
use metal::{
|
||||||
use std::collections::HashMap;
|
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
|
||||||
|
Device, Function, FunctionConstantValues, Library, MTLDataType, MTLResourceUsage, MTLSize,
|
||||||
|
NSUInteger,
|
||||||
|
};
|
||||||
|
use std::collections::{BTreeMap, HashMap};
|
||||||
use std::ffi::c_void;
|
use std::ffi::c_void;
|
||||||
use std::hash::Hash;
|
use std::hash::Hash;
|
||||||
|
use std::io::{stdout, Write};
|
||||||
use std::sync::RwLock;
|
use std::sync::RwLock;
|
||||||
|
|
||||||
const AFFINE: &str = include_str!("affine.metal");
|
const AFFINE: &str = include_str!("affine.metal");
|
||||||
@ -259,7 +264,10 @@ impl Kernels {
|
|||||||
) -> Result<Function, MetalKernelError> {
|
) -> Result<Function, MetalKernelError> {
|
||||||
let func = self
|
let func = self
|
||||||
.load_library(device, source)?
|
.load_library(device, source)?
|
||||||
.get_function(key.name, key.constants.map(|c| c.create_function_constant_values()))
|
.get_function(
|
||||||
|
key.name,
|
||||||
|
key.constants.map(|c| c.create_function_constant_values()),
|
||||||
|
)
|
||||||
.map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
|
.map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
|
||||||
Ok(func)
|
Ok(func)
|
||||||
}
|
}
|
||||||
@ -292,7 +300,21 @@ struct KernelKey {
|
|||||||
constants: Option<ConstantValues>,
|
constants: Option<ConstantValues>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
impl KernelKey {
|
||||||
|
fn new(name: &'static str) -> Self {
|
||||||
|
Self {
|
||||||
|
name,
|
||||||
|
constants: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn with_constants(mut self, constants: ConstantValues) -> Self {
|
||||||
|
self.constants = Some(constants);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||||
enum ConstantValueId {
|
enum ConstantValueId {
|
||||||
Index(NSUInteger),
|
Index(NSUInteger),
|
||||||
Name(&'static str),
|
Name(&'static str),
|
||||||
@ -306,7 +328,7 @@ macro_rules! metal_dtype {
|
|||||||
impl MetalDType for $ty {
|
impl MetalDType for $ty {
|
||||||
const MTL_DATA_TYPE: MTLDataType = MTLDataType::$mtl_data_type;
|
const MTL_DATA_TYPE: MTLDataType = MTLDataType::$mtl_data_type;
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
}
|
}
|
||||||
metal_dtype!(f32, Float);
|
metal_dtype!(f32, Float);
|
||||||
metal_dtype!(u32, UInt);
|
metal_dtype!(u32, UInt);
|
||||||
@ -314,18 +336,18 @@ metal_dtype!(u16, UShort);
|
|||||||
metal_dtype!(bool, Bool);
|
metal_dtype!(bool, Bool);
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
enum ConstantValue {
|
enum ConstantValueType {
|
||||||
Float(f32),
|
Float(f32),
|
||||||
Uint(u32),
|
Uint(u32),
|
||||||
UShort(u16),
|
UShort(u16),
|
||||||
Bool(bool),
|
Bool(bool),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Hash for ConstantValue {
|
impl Hash for ConstantValueType {
|
||||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||||
use ConstantValue::*;
|
use ConstantValueType::*;
|
||||||
match self {
|
match self {
|
||||||
Float(_) => {}, // do nothing
|
Float(v) => v.to_bits().hash(state),
|
||||||
Uint(v) => v.hash(state),
|
Uint(v) => v.hash(state),
|
||||||
UShort(v) => v.hash(state),
|
UShort(v) => v.hash(state),
|
||||||
Bool(v) => v.hash(state),
|
Bool(v) => v.hash(state),
|
||||||
@ -333,10 +355,10 @@ impl Hash for ConstantValue {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Eq for ConstantValue {}
|
impl Eq for ConstantValueType {}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
struct ConstantValues(Vec<(ConstantValueId, ConstantValue)>);
|
struct ConstantValues(BTreeMap<ConstantValueId, ConstantValueType>);
|
||||||
|
|
||||||
macro_rules! add_indexed_constant {
|
macro_rules! add_indexed_constant {
|
||||||
($fcv:expr, $value:expr, $ty:ty, $idx:expr) => {
|
($fcv:expr, $value:expr, $ty:ty, $idx:expr) => {
|
||||||
@ -356,14 +378,33 @@ macro_rules! add_named_constant {
|
|||||||
)
|
)
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Hash for ConstantValues {
|
||||||
|
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||||
|
for (id, value) in &self.0 {
|
||||||
|
id.hash(state);
|
||||||
|
value.hash(state);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl ConstantValues {
|
impl ConstantValues {
|
||||||
|
fn new() -> Self {
|
||||||
|
Self(BTreeMap::new())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set(mut self, id: impl Into<ConstantValueId>, value: impl Into<ConstantValueType>) -> Self {
|
||||||
|
self.0.insert(id.into(), value.into());
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
fn create_function_constant_values(&self) -> FunctionConstantValues {
|
fn create_function_constant_values(&self) -> FunctionConstantValues {
|
||||||
use ConstantValueId::*;
|
use ConstantValueId::*;
|
||||||
use ConstantValue::*;
|
use ConstantValueType::*;
|
||||||
let mut function_values = FunctionConstantValues::new();
|
let mut function_values = FunctionConstantValues::new();
|
||||||
|
|
||||||
for (id, value) in &self.0 {
|
for (id, value) in &self.0 {
|
||||||
match (id, value) {
|
match (&id, &value) {
|
||||||
(Index(index), Float(value)) => {
|
(Index(index), Float(value)) => {
|
||||||
add_indexed_constant!(function_values, value, f32, *index);
|
add_indexed_constant!(function_values, value, f32, *index);
|
||||||
}
|
}
|
||||||
@ -839,42 +880,227 @@ pub fn call_index_select(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<NSUInteger> for ConstantValueId {
|
||||||
|
fn from(idx: NSUInteger) -> Self {
|
||||||
|
Self::Index(idx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<usize> for ConstantValueId {
|
||||||
|
fn from(idx: usize) -> Self {
|
||||||
|
ConstantValueId::from(idx as NSUInteger)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<i32> for ConstantValueId {
|
||||||
|
fn from(idx: i32) -> Self {
|
||||||
|
ConstantValueId::from(idx as NSUInteger)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<&'static str> for ConstantValueId {
|
||||||
|
fn from(name: &'static str) -> Self {
|
||||||
|
Self::Name(name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
macro_rules! to_constant_value {
|
||||||
|
($ty:ty, $constant_value_type:ident) => {
|
||||||
|
to_constant_value!($ty, $ty, $constant_value_type);
|
||||||
|
};
|
||||||
|
($ty:ty, $via:ty, $constant_value_type:ident) => {
|
||||||
|
impl From<$ty> for ConstantValueType {
|
||||||
|
fn from(v: $ty) -> Self {
|
||||||
|
Self::$constant_value_type(v as $via)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
to_constant_value!(f32, Float);
|
||||||
|
to_constant_value!(u32, Uint);
|
||||||
|
to_constant_value!(usize, u32, Uint);
|
||||||
|
to_constant_value!(u16, UShort);
|
||||||
|
to_constant_value!(bool, Bool);
|
||||||
|
|
||||||
|
struct MFAGemmConfig {
|
||||||
|
m: usize,
|
||||||
|
k: usize,
|
||||||
|
n: usize,
|
||||||
|
transpose_left: bool,
|
||||||
|
transpose_right: bool,
|
||||||
|
batched: bool,
|
||||||
|
m_simd: u16,
|
||||||
|
n_simd: u16,
|
||||||
|
k_simd: u16,
|
||||||
|
m_splits: u16,
|
||||||
|
n_splits: u16,
|
||||||
|
m_group: u16,
|
||||||
|
n_group: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<MFAGemmConfig> for ConstantValues {
|
||||||
|
fn from(conf: MFAGemmConfig) -> Self {
|
||||||
|
ConstantValues::new()
|
||||||
|
.set(0, conf.m)
|
||||||
|
.set(1, conf.k)
|
||||||
|
.set(2, conf.n)
|
||||||
|
.set(10, conf.transpose_left)
|
||||||
|
.set(11, conf.transpose_right)
|
||||||
|
.set(12, false)
|
||||||
|
.set(20, 1.0)
|
||||||
|
.set(21, 0.0)
|
||||||
|
.set(100, conf.batched)
|
||||||
|
.set(101, false)
|
||||||
|
.set(50001, false)
|
||||||
|
.set(200, conf.m_simd)
|
||||||
|
.set(201, conf.n_simd)
|
||||||
|
.set(202, conf.k_simd)
|
||||||
|
.set(210, conf.m_splits)
|
||||||
|
.set(211, conf.n_splits)
|
||||||
|
// garbage
|
||||||
|
.set(102, false)
|
||||||
|
.set(103, false)
|
||||||
|
.set(113, false)
|
||||||
|
.set(50000, false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_mfa_gemm(
|
pub fn call_mfa_gemm(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBufferRef,
|
command_buffer: &CommandBufferRef,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
name: &'static str,
|
name: &'static str,
|
||||||
shape: &[usize],
|
lhs: &Buffer,
|
||||||
input: &Buffer,
|
lhs_dims: &[usize],
|
||||||
strides: &[usize],
|
rhs: &Buffer,
|
||||||
offset: usize,
|
rhs_dims: &[usize],
|
||||||
output: &Buffer,
|
output: &Buffer,
|
||||||
output_offset: usize,
|
(b, m, n, k): (usize, usize, usize, usize),
|
||||||
|
transpose_left: bool,
|
||||||
|
transpose_right: bool,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::MetalFlashAttention, name)?;
|
let batched = b > 1;
|
||||||
|
|
||||||
|
let mut c_elements = m * n;
|
||||||
|
if batched {
|
||||||
|
c_elements *= 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
let is_half = name == "hgemm";
|
||||||
|
let is_float = name == "sgemm";
|
||||||
|
|
||||||
|
let mut m_group = 32;
|
||||||
|
let mut n_group = 32;
|
||||||
|
let mut k_simd = 32;
|
||||||
|
if c_elements > 10 ^ 6 {
|
||||||
|
m_group = 48;
|
||||||
|
n_group = 48;
|
||||||
|
}
|
||||||
|
// If K_simd is perfectly equal to matrix K, the compiler can elide a large
|
||||||
|
// amount of logic in the kernel.
|
||||||
|
if k >= 33 && k <= 40 {
|
||||||
|
k_simd = 40;
|
||||||
|
} else if is_half && k >= 73 && k >= 80 {
|
||||||
|
k_simd = 80;
|
||||||
|
} else if c_elements > 10 ^ 6 {
|
||||||
|
if k <= 16 {
|
||||||
|
k_simd = 16;
|
||||||
|
} else if k <= 24 {
|
||||||
|
k_simd = 24;
|
||||||
|
} else if k <= 32 {
|
||||||
|
k_simd = 32;
|
||||||
|
} else if k <= 48 {
|
||||||
|
k_simd = 24;
|
||||||
|
} else if k <= 64 {
|
||||||
|
k_simd = 32;
|
||||||
|
} else if is_float {
|
||||||
|
k_simd = 24;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let m_splits = 2;
|
||||||
|
let n_splits = 2;
|
||||||
|
let m_simd = m_group / m_splits;
|
||||||
|
let n_simd = n_group / n_splits;
|
||||||
|
|
||||||
|
let config = MFAGemmConfig {
|
||||||
|
m,
|
||||||
|
k,
|
||||||
|
n,
|
||||||
|
transpose_left,
|
||||||
|
transpose_right,
|
||||||
|
batched,
|
||||||
|
m_simd,
|
||||||
|
n_simd,
|
||||||
|
k_simd,
|
||||||
|
m_splits,
|
||||||
|
n_splits,
|
||||||
|
m_group,
|
||||||
|
n_group,
|
||||||
|
};
|
||||||
|
|
||||||
|
let pipeline = kernels.load_pipeline(
|
||||||
|
device,
|
||||||
|
Source::MetalFlashAttention,
|
||||||
|
KernelKey::new(name).with_constants(config.into()),
|
||||||
|
)?;
|
||||||
|
let block_type_size = if is_half { 2 } else { 4 };
|
||||||
|
let a_block_bytes = m_group * k_simd * block_type_size;
|
||||||
|
let b_block_bytes = k_simd * n_group * block_type_size;
|
||||||
|
let c_block_bytes = m_group * n_group * block_type_size;
|
||||||
|
let mut thread_group_memory_length = a_block_bytes + b_block_bytes;
|
||||||
|
|
||||||
|
if m % 8 > 0 && n % 8 > 0 {
|
||||||
|
thread_group_memory_length = core::cmp::max(thread_group_memory_length, c_block_bytes);
|
||||||
|
}
|
||||||
|
|
||||||
let num_dims: usize = shape.len();
|
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
encoder.set_threadgroup_memory_length(0, thread_group_memory_length as NSUInteger);
|
||||||
|
encoder.use_resources(&[&lhs, &rhs], MTLResourceUsage::Read);
|
||||||
|
encoder.use_resource(&output, MTLResourceUsage::Write);
|
||||||
|
encoder.set_buffers(0, &[Some(lhs), Some(rhs), Some(output)], &[0; 3]);
|
||||||
|
|
||||||
let length: usize = shape.iter().product();
|
let ceil_divide = |a, b| (a + b - 1) / b;
|
||||||
set_params!(
|
|
||||||
encoder,
|
let mut grid_z = 1;
|
||||||
(
|
|
||||||
length,
|
if batched {
|
||||||
num_dims,
|
grid_z = lhs_dims[..lhs_dims.len() - 2].iter().product();
|
||||||
shape,
|
let byte_stride = |shape: &[usize]| -> u64 {
|
||||||
strides,
|
let rank = shape.len();
|
||||||
(input, offset),
|
let mut output = core::mem::size_of::<f32>() * shape[rank - 2] * shape[rank - 1];
|
||||||
(output, output_offset)
|
if shape[..shape.len() - 2].iter().product::<usize>() == 1 {
|
||||||
)
|
output = 0;
|
||||||
|
}
|
||||||
|
output as u64
|
||||||
|
};
|
||||||
|
let byte_stride_a = byte_stride(lhs_dims);
|
||||||
|
let byte_stride_b = byte_stride(rhs_dims);
|
||||||
|
let byte_stride_c = byte_stride(&[m, n]);
|
||||||
|
|
||||||
|
type BatchConfig = (u64, u64, u64, u64);
|
||||||
|
let mut batching_conf: Vec<BatchConfig> = vec![];
|
||||||
|
for i in 0..grid_z {
|
||||||
|
batching_conf.push((
|
||||||
|
i as u64 * byte_stride_a,
|
||||||
|
i as u64 * byte_stride_b,
|
||||||
|
i as u64 * byte_stride_c,
|
||||||
|
0u64,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
set_param(encoder, 10, batching_conf.as_slice());
|
||||||
|
}
|
||||||
|
|
||||||
|
let grid_size = MTLSize::new(
|
||||||
|
ceil_divide(n as NSUInteger, n_group as NSUInteger),
|
||||||
|
ceil_divide(m as NSUInteger, m_group as NSUInteger),
|
||||||
|
grid_z as NSUInteger,
|
||||||
);
|
);
|
||||||
|
|
||||||
let width: usize = shape.iter().product();
|
let group_size = MTLSize::new((32 * m_splits * n_splits) as NSUInteger, 1, 1);
|
||||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
|
encoder.dispatch_thread_groups(grid_size, group_size);
|
||||||
|
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user