Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 10:38:54 +00:00)
Reuse buffers on our own reference counts.
This commit is contained in:
@ -8,6 +8,7 @@ use half::f16;
|
|||||||
use metal;
|
use metal;
|
||||||
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
|
||||||
use std::sync::{Arc, RwLock};
|
use std::sync::{Arc, RwLock};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
/// Metal related errors
|
/// Metal related errors
|
||||||
#[derive(thiserror::Error, Debug)]
|
#[derive(thiserror::Error, Debug)]
|
||||||
@ -37,6 +38,7 @@ pub struct MetalDevice {
|
|||||||
command_queue: metal::CommandQueue,
|
command_queue: metal::CommandQueue,
|
||||||
command_buffer: Arc<RwLock<metal::CommandBuffer>>,
|
command_buffer: Arc<RwLock<metal::CommandBuffer>>,
|
||||||
kernels: Arc<candle_metal_kernels::Kernels>,
|
kernels: Arc<candle_metal_kernels::Kernels>,
|
||||||
|
buffers: Arc<RwLock<HashMap<usize, Vec<Arc<Buffer>>>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::fmt::Debug for MetalDevice {
|
impl std::fmt::Debug for MetalDevice {
|
||||||
@ -87,8 +89,26 @@ impl MetalDevice {
|
|||||||
&self.device
|
&self.device
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
|
pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Arc<Buffer >{
|
||||||
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
|
let size = element_count * dtype.size_in_bytes();
|
||||||
|
let mut buffers = self.buffers.try_write().unwrap();
|
||||||
|
let subbuffers = buffers.entry(size).or_insert(vec![]);
|
||||||
|
|
||||||
|
for sub in &mut *subbuffers{
|
||||||
|
// if sub.retain_count() == 1{
|
||||||
|
// println!("{size} {:?}", );
|
||||||
|
if Arc::strong_count(sub) == 1{
|
||||||
|
return sub.clone();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let new_buffer = self.device
|
||||||
|
.new_buffer(size as NSUInteger, MTLResourceOptions::StorageModePrivate);
|
||||||
|
let new_buffer = Arc::new(new_buffer);
|
||||||
|
subbuffers.push(new_buffer.clone());
|
||||||
|
new_buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_buffer_managed(&self, size: NSUInteger) -> Buffer {
|
||||||
self.device
|
self.device
|
||||||
.new_buffer(size, MTLResourceOptions::StorageModeManaged)
|
.new_buffer(size, MTLResourceOptions::StorageModeManaged)
|
||||||
}
|
}
|
||||||
@ -105,7 +125,7 @@ impl MetalDevice {
|
|||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct MetalStorage {
|
pub struct MetalStorage {
|
||||||
buffer: metal::Buffer,
|
buffer: Arc<metal::Buffer>,
|
||||||
device: MetalDevice,
|
device: MetalDevice,
|
||||||
dtype: DType,
|
dtype: DType,
|
||||||
}
|
}
|
||||||
@ -126,29 +146,38 @@ impl BackendStorage for MetalStorage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||||
|
let buffer = self.device.new_buffer_managed(self.buffer.length());
|
||||||
|
{
|
||||||
|
let command = self.device.command_buffer();
|
||||||
|
let blit = command.new_blit_command_encoder();
|
||||||
|
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
|
||||||
|
blit.end_encoding();
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
self.device.wait_until_completed();
|
self.device.wait_until_completed();
|
||||||
|
|
||||||
match self.dtype {
|
match self.dtype {
|
||||||
DType::U8 => Ok(CpuStorage::U8(
|
DType::U8 => Ok(CpuStorage::U8(
|
||||||
self.buffer.read_to_vec(self.buffer.length() as usize),
|
buffer.read_to_vec(buffer.length() as usize),
|
||||||
)),
|
)),
|
||||||
DType::U32 => Ok(CpuStorage::U32(
|
DType::U32 => Ok(CpuStorage::U32(
|
||||||
self.buffer.read_to_vec(self.buffer.length() as usize / 4),
|
buffer.read_to_vec(buffer.length() as usize / 4),
|
||||||
)),
|
)),
|
||||||
DType::I64 => Ok(CpuStorage::I64(
|
DType::I64 => Ok(CpuStorage::I64(
|
||||||
self.buffer.read_to_vec(self.buffer.length() as usize / 8),
|
buffer.read_to_vec(buffer.length() as usize / 8),
|
||||||
)),
|
)),
|
||||||
DType::F16 => Ok(CpuStorage::F16(
|
DType::F16 => Ok(CpuStorage::F16(
|
||||||
self.buffer.read_to_vec(self.buffer.length() as usize / 2),
|
buffer.read_to_vec(buffer.length() as usize / 2),
|
||||||
)),
|
)),
|
||||||
DType::BF16 => Ok(CpuStorage::BF16(
|
DType::BF16 => Ok(CpuStorage::BF16(
|
||||||
self.buffer.read_to_vec(self.buffer.length() as usize / 2),
|
buffer.read_to_vec(buffer.length() as usize / 2),
|
||||||
)),
|
)),
|
||||||
DType::F32 => Ok(CpuStorage::F32(
|
DType::F32 => Ok(CpuStorage::F32(
|
||||||
self.buffer.read_to_vec(self.buffer.length() as usize / 4),
|
buffer.read_to_vec(buffer.length() as usize / 4),
|
||||||
)),
|
)),
|
||||||
DType::F64 => Ok(CpuStorage::F64(
|
DType::F64 => Ok(CpuStorage::F64(
|
||||||
self.buffer.read_to_vec(self.buffer.length() as usize / 8),
|
buffer.read_to_vec(buffer.length() as usize / 8),
|
||||||
)),
|
)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -175,7 +204,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
name,
|
name,
|
||||||
el,
|
el,
|
||||||
&self.buffer,
|
&self.buffer,
|
||||||
&mut buffer,
|
&buffer,
|
||||||
mul as f32,
|
mul as f32,
|
||||||
add as f32,
|
add as f32,
|
||||||
)
|
)
|
||||||
@ -195,7 +224,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
&self.buffer,
|
&self.buffer,
|
||||||
layout.stride(),
|
layout.stride(),
|
||||||
layout.start_offset() * dtype.size_in_bytes(),
|
layout.start_offset() * dtype.size_in_bytes(),
|
||||||
&mut buffer,
|
&buffer,
|
||||||
mul as f32,
|
mul as f32,
|
||||||
add as f32,
|
add as f32,
|
||||||
)
|
)
|
||||||
@ -270,7 +299,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
dst_el,
|
dst_el,
|
||||||
&self.buffer,
|
&self.buffer,
|
||||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||||
&mut buffer,
|
&buffer,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
|
|
||||||
@ -305,7 +334,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
kernel_name,
|
kernel_name,
|
||||||
el_count,
|
el_count,
|
||||||
&self.buffer,
|
&self.buffer,
|
||||||
&mut buffer,
|
&buffer,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
} else {
|
} else {
|
||||||
@ -324,7 +353,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
&self.buffer,
|
&self.buffer,
|
||||||
layout.stride(),
|
layout.stride(),
|
||||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||||
&mut buffer,
|
&buffer,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
}
|
}
|
||||||
@ -382,7 +411,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
kernel_name,
|
kernel_name,
|
||||||
el_count,
|
el_count,
|
||||||
&self.buffer,
|
&self.buffer,
|
||||||
&mut buffer,
|
&buffer,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
} else {
|
} else {
|
||||||
@ -425,7 +454,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
&self.buffer,
|
&self.buffer,
|
||||||
layout.stride(),
|
layout.stride(),
|
||||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||||
&mut buffer,
|
&buffer,
|
||||||
0,
|
0,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
@ -481,7 +510,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
el_count,
|
el_count,
|
||||||
&self.buffer,
|
&self.buffer,
|
||||||
&rhs.buffer,
|
&rhs.buffer,
|
||||||
&mut buffer,
|
&buffer,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
} else {
|
} else {
|
||||||
@ -510,7 +539,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
&rhs.buffer,
|
&rhs.buffer,
|
||||||
rhs_l.stride(),
|
rhs_l.stride(),
|
||||||
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
||||||
&mut buffer,
|
&buffer,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
}
|
}
|
||||||
@ -551,7 +580,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
(&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
|
(&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
|
||||||
&f.buffer,
|
&f.buffer,
|
||||||
(&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
|
(&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
|
||||||
&mut buffer,
|
&buffer,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
@ -661,7 +690,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
dim,
|
dim,
|
||||||
&self.buffer,
|
&self.buffer,
|
||||||
&ids.buffer,
|
&ids.buffer,
|
||||||
&mut buffer,
|
&buffer,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
@ -860,7 +889,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
&self.buffer,
|
&self.buffer,
|
||||||
src_l.stride(),
|
src_l.stride(),
|
||||||
src_l.start_offset() * self.dtype.size_in_bytes(),
|
src_l.start_offset() * self.dtype.size_in_bytes(),
|
||||||
&mut dst.buffer,
|
&dst.buffer,
|
||||||
dst_offset * dst.dtype.size_in_bytes(),
|
dst_offset * dst.dtype.size_in_bytes(),
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
@ -869,7 +898,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl MetalStorage {
|
impl MetalStorage {
|
||||||
pub fn new(buffer: Buffer, device: MetalDevice, dtype: DType) -> Self {
|
pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: DType) -> Self {
|
||||||
Self {
|
Self {
|
||||||
buffer,
|
buffer,
|
||||||
device,
|
device,
|
||||||
@ -904,10 +933,12 @@ impl BackendDevice for MetalDevice {
|
|||||||
let command_queue = device.new_command_queue();
|
let command_queue = device.new_command_queue();
|
||||||
let command_buffer = Arc::new(RwLock::new(command_queue.new_command_buffer().to_owned()));
|
let command_buffer = Arc::new(RwLock::new(command_queue.new_command_buffer().to_owned()));
|
||||||
let kernels = Arc::new(Kernels::new());
|
let kernels = Arc::new(Kernels::new());
|
||||||
|
let buffers = Arc::new(RwLock::new(HashMap::new()));
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
device,
|
device,
|
||||||
command_queue,
|
command_queue,
|
||||||
command_buffer,
|
command_buffer,
|
||||||
|
buffers,
|
||||||
kernels,
|
kernels,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -952,7 +983,7 @@ impl BackendDevice for MetalDevice {
|
|||||||
CpuStorage::F64(storage) => self.new_buffer_with_data(storage),
|
CpuStorage::F64(storage) => self.new_buffer_with_data(storage),
|
||||||
};
|
};
|
||||||
Ok(Self::Storage {
|
Ok(Self::Storage {
|
||||||
buffer,
|
buffer: buffer.into(),
|
||||||
device: self.clone(),
|
device: self.clone(),
|
||||||
dtype: storage.dtype(),
|
dtype: storage.dtype(),
|
||||||
})
|
})
|
||||||
|
@ -298,7 +298,7 @@ pub fn call_unary_contiguous(
|
|||||||
kernel_name: unary::contiguous::Kernel,
|
kernel_name: unary::contiguous::Kernel,
|
||||||
length: usize,
|
length: usize,
|
||||||
input: &Buffer,
|
input: &Buffer,
|
||||||
output: &mut Buffer,
|
output: &Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
@ -320,7 +320,7 @@ pub fn call_unary_strided(
|
|||||||
input: &Buffer,
|
input: &Buffer,
|
||||||
strides: &[usize],
|
strides: &[usize],
|
||||||
offset: usize,
|
offset: usize,
|
||||||
output: &mut Buffer,
|
output: &Buffer,
|
||||||
output_offset: usize,
|
output_offset: usize,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
|
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
|
||||||
@ -358,7 +358,7 @@ pub fn call_binary_contiguous(
|
|||||||
length: usize,
|
length: usize,
|
||||||
left: &Buffer,
|
left: &Buffer,
|
||||||
right: &Buffer,
|
right: &Buffer,
|
||||||
output: &mut Buffer,
|
output: &Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
|
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
|
||||||
|
|
||||||
@ -386,7 +386,7 @@ pub fn call_binary_strided(
|
|||||||
right_input: &Buffer,
|
right_input: &Buffer,
|
||||||
right_strides: &[usize],
|
right_strides: &[usize],
|
||||||
right_offset: usize,
|
right_offset: usize,
|
||||||
output: &mut Buffer,
|
output: &Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?;
|
let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?;
|
||||||
|
|
||||||
@ -425,7 +425,7 @@ pub fn call_cast_contiguous(
|
|||||||
kernel_name: &'static str,
|
kernel_name: &'static str,
|
||||||
length: usize,
|
length: usize,
|
||||||
input: &Buffer,
|
input: &Buffer,
|
||||||
output: &mut Buffer,
|
output: &Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
||||||
|
|
||||||
@ -450,7 +450,7 @@ pub fn call_cast_strided(
|
|||||||
input: &Buffer,
|
input: &Buffer,
|
||||||
input_strides: &[usize],
|
input_strides: &[usize],
|
||||||
input_offset: usize,
|
input_offset: usize,
|
||||||
output: &mut Buffer,
|
output: &Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
// println!("Kernel {:?}", kernel_name.0);
|
// println!("Kernel {:?}", kernel_name.0);
|
||||||
// assert_eq!(input.length(), output.length());
|
// assert_eq!(input.length(), output.length());
|
||||||
@ -482,7 +482,7 @@ pub fn call_reduce_contiguous(
|
|||||||
out_length: usize,
|
out_length: usize,
|
||||||
input: &Buffer,
|
input: &Buffer,
|
||||||
input_offset: usize,
|
input_offset: usize,
|
||||||
output: &mut Buffer,
|
output: &Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||||
let elements_to_sum = length / out_length;
|
let elements_to_sum = length / out_length;
|
||||||
@ -523,7 +523,7 @@ pub fn call_last_softmax(
|
|||||||
length: usize,
|
length: usize,
|
||||||
elements_to_sum: usize,
|
elements_to_sum: usize,
|
||||||
input: &Buffer,
|
input: &Buffer,
|
||||||
output: &mut Buffer,
|
output: &Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||||
let encoder = command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
@ -564,7 +564,7 @@ pub fn call_affine(
|
|||||||
name: &'static str,
|
name: &'static str,
|
||||||
size: usize,
|
size: usize,
|
||||||
input: &Buffer,
|
input: &Buffer,
|
||||||
output: &mut Buffer,
|
output: &Buffer,
|
||||||
mul: f32,
|
mul: f32,
|
||||||
add: f32,
|
add: f32,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
@ -590,7 +590,7 @@ pub fn call_affine_strided(
|
|||||||
input: &Buffer,
|
input: &Buffer,
|
||||||
input_stride: &[usize],
|
input_stride: &[usize],
|
||||||
input_offset: usize,
|
input_offset: usize,
|
||||||
output: &mut Buffer,
|
output: &Buffer,
|
||||||
mul: f32,
|
mul: f32,
|
||||||
add: f32,
|
add: f32,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
@ -632,7 +632,7 @@ pub fn call_where_cond_strided(
|
|||||||
(left_stride, left_offset): (&[usize], usize),
|
(left_stride, left_offset): (&[usize], usize),
|
||||||
right: &Buffer,
|
right: &Buffer,
|
||||||
(right_stride, right_offset): (&[usize], usize),
|
(right_stride, right_offset): (&[usize], usize),
|
||||||
output: &mut Buffer,
|
output: &Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
|
||||||
|
|
||||||
@ -675,7 +675,7 @@ pub fn call_index_select(
|
|||||||
dim: usize,
|
dim: usize,
|
||||||
input: &Buffer,
|
input: &Buffer,
|
||||||
ids: &Buffer,
|
ids: &Buffer,
|
||||||
output: &mut Buffer,
|
output: &Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let left_size: usize = shape[..dim].iter().product();
|
let left_size: usize = shape[..dim].iter().product();
|
||||||
let right_size: usize = shape[dim + 1..].iter().product();
|
let right_size: usize = shape[dim + 1..].iter().product();
|
||||||
@ -750,7 +750,7 @@ mod tests {
|
|||||||
name,
|
name,
|
||||||
v.len(),
|
v.len(),
|
||||||
&input,
|
&input,
|
||||||
&mut output,
|
&output,
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
command_buffer.commit();
|
command_buffer.commit();
|
||||||
@ -775,7 +775,7 @@ mod tests {
|
|||||||
x.len(),
|
x.len(),
|
||||||
&left,
|
&left,
|
||||||
&right,
|
&right,
|
||||||
&mut output,
|
&output,
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
command_buffer.commit();
|
command_buffer.commit();
|
||||||
@ -805,7 +805,7 @@ mod tests {
|
|||||||
&input,
|
&input,
|
||||||
strides,
|
strides,
|
||||||
offset,
|
offset,
|
||||||
&mut output,
|
&output,
|
||||||
0,
|
0,
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@ -943,7 +943,7 @@ mod tests {
|
|||||||
name,
|
name,
|
||||||
v.len(),
|
v.len(),
|
||||||
&input,
|
&input,
|
||||||
&mut output,
|
&output,
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
command_buffer.commit();
|
command_buffer.commit();
|
||||||
@ -984,7 +984,7 @@ mod tests {
|
|||||||
"affine_float",
|
"affine_float",
|
||||||
size,
|
size,
|
||||||
&input,
|
&input,
|
||||||
&mut output,
|
&output,
|
||||||
mul as f32,
|
mul as f32,
|
||||||
add as f32,
|
add as f32,
|
||||||
)
|
)
|
||||||
@ -1021,7 +1021,7 @@ mod tests {
|
|||||||
&input,
|
&input,
|
||||||
strides,
|
strides,
|
||||||
0,
|
0,
|
||||||
&mut output,
|
&output,
|
||||||
mul as f32,
|
mul as f32,
|
||||||
add as f32,
|
add as f32,
|
||||||
)
|
)
|
||||||
@ -1119,7 +1119,7 @@ mod tests {
|
|||||||
dim,
|
dim,
|
||||||
&embeddings_buffer,
|
&embeddings_buffer,
|
||||||
&ids_buffer,
|
&ids_buffer,
|
||||||
&mut dst_buffer,
|
&dst_buffer,
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
@ -1227,7 +1227,7 @@ mod tests {
|
|||||||
out_length,
|
out_length,
|
||||||
&input,
|
&input,
|
||||||
0,
|
0,
|
||||||
&mut output,
|
&output,
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
command_buffer.commit();
|
command_buffer.commit();
|
||||||
@ -1255,7 +1255,7 @@ mod tests {
|
|||||||
v.len(),
|
v.len(),
|
||||||
last_dim,
|
last_dim,
|
||||||
&input,
|
&input,
|
||||||
&mut output,
|
&output,
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
command_buffer.commit();
|
command_buffer.commit();
|
||||||
@ -1355,7 +1355,7 @@ mod tests {
|
|||||||
(&left_stride, left_offset),
|
(&left_stride, left_offset),
|
||||||
&right,
|
&right,
|
||||||
(&cond_stride, cond_offset),
|
(&cond_stride, cond_offset),
|
||||||
&mut output,
|
&output,
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
command_buffer.commit();
|
command_buffer.commit();
|
||||||
|
Reference in New Issue
Block a user