mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 20:22:49 +00:00
Reuse buffers on our own reference counts.
This commit is contained in:
@ -298,7 +298,7 @@ pub fn call_unary_contiguous(
|
||||
kernel_name: unary::contiguous::Kernel,
|
||||
length: usize,
|
||||
input: &Buffer,
|
||||
- output: &mut Buffer,
|
||||
+ output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
@ -320,7 +320,7 @@ pub fn call_unary_strided(
|
||||
input: &Buffer,
|
||||
strides: &[usize],
|
||||
offset: usize,
|
||||
- output: &mut Buffer,
|
||||
+ output: &Buffer,
|
||||
output_offset: usize,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
|
||||
@ -358,7 +358,7 @@ pub fn call_binary_contiguous(
|
||||
length: usize,
|
||||
left: &Buffer,
|
||||
right: &Buffer,
|
||||
- output: &mut Buffer,
|
||||
+ output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
|
||||
|
||||
@ -386,7 +386,7 @@ pub fn call_binary_strided(
|
||||
right_input: &Buffer,
|
||||
right_strides: &[usize],
|
||||
right_offset: usize,
|
||||
- output: &mut Buffer,
|
||||
+ output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Binary, name.0)?;
|
||||
|
||||
@ -425,7 +425,7 @@ pub fn call_cast_contiguous(
|
||||
kernel_name: &'static str,
|
||||
length: usize,
|
||||
input: &Buffer,
|
||||
- output: &mut Buffer,
|
||||
+ output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
|
||||
|
||||
@ -450,7 +450,7 @@ pub fn call_cast_strided(
|
||||
input: &Buffer,
|
||||
input_strides: &[usize],
|
||||
input_offset: usize,
|
||||
- output: &mut Buffer,
|
||||
+ output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
// println!("Kernel {:?}", kernel_name.0);
|
||||
// assert_eq!(input.length(), output.length());
|
||||
@ -482,7 +482,7 @@ pub fn call_reduce_contiguous(
|
||||
out_length: usize,
|
||||
input: &Buffer,
|
||||
input_offset: usize,
|
||||
- output: &mut Buffer,
|
||||
+ output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||
let elements_to_sum = length / out_length;
|
||||
@ -523,7 +523,7 @@ pub fn call_last_softmax(
|
||||
length: usize,
|
||||
elements_to_sum: usize,
|
||||
input: &Buffer,
|
||||
- output: &mut Buffer,
|
||||
+ output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
@ -564,7 +564,7 @@ pub fn call_affine(
|
||||
name: &'static str,
|
||||
size: usize,
|
||||
input: &Buffer,
|
||||
- output: &mut Buffer,
|
||||
+ output: &Buffer,
|
||||
mul: f32,
|
||||
add: f32,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
@ -590,7 +590,7 @@ pub fn call_affine_strided(
|
||||
input: &Buffer,
|
||||
input_stride: &[usize],
|
||||
input_offset: usize,
|
||||
- output: &mut Buffer,
|
||||
+ output: &Buffer,
|
||||
mul: f32,
|
||||
add: f32,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
@ -632,7 +632,7 @@ pub fn call_where_cond_strided(
|
||||
(left_stride, left_offset): (&[usize], usize),
|
||||
right: &Buffer,
|
||||
(right_stride, right_offset): (&[usize], usize),
|
||||
- output: &mut Buffer,
|
||||
+ output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
|
||||
|
||||
@ -675,7 +675,7 @@ pub fn call_index_select(
|
||||
dim: usize,
|
||||
input: &Buffer,
|
||||
ids: &Buffer,
|
||||
- output: &mut Buffer,
|
||||
+ output: &Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let left_size: usize = shape[..dim].iter().product();
|
||||
let right_size: usize = shape[dim + 1..].iter().product();
|
||||
@ -750,7 +750,7 @@ mod tests {
|
||||
name,
|
||||
v.len(),
|
||||
&input,
|
||||
- &mut output,
|
||||
+ &output,
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
@ -775,7 +775,7 @@ mod tests {
|
||||
x.len(),
|
||||
&left,
|
||||
&right,
|
||||
- &mut output,
|
||||
+ &output,
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
@ -805,7 +805,7 @@ mod tests {
|
||||
&input,
|
||||
strides,
|
||||
offset,
|
||||
- &mut output,
|
||||
+ &output,
|
||||
0,
|
||||
)
|
||||
.unwrap();
|
||||
@ -943,7 +943,7 @@ mod tests {
|
||||
name,
|
||||
v.len(),
|
||||
&input,
|
||||
- &mut output,
|
||||
+ &output,
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
@ -984,7 +984,7 @@ mod tests {
|
||||
"affine_float",
|
||||
size,
|
||||
&input,
|
||||
- &mut output,
|
||||
+ &output,
|
||||
mul as f32,
|
||||
add as f32,
|
||||
)
|
||||
@ -1021,7 +1021,7 @@ mod tests {
|
||||
&input,
|
||||
strides,
|
||||
0,
|
||||
- &mut output,
|
||||
+ &output,
|
||||
mul as f32,
|
||||
add as f32,
|
||||
)
|
||||
@ -1119,7 +1119,7 @@ mod tests {
|
||||
dim,
|
||||
&embeddings_buffer,
|
||||
&ids_buffer,
|
||||
- &mut dst_buffer,
|
||||
+ &dst_buffer,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@ -1227,7 +1227,7 @@ mod tests {
|
||||
out_length,
|
||||
&input,
|
||||
0,
|
||||
- &mut output,
|
||||
+ &output,
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
@ -1255,7 +1255,7 @@ mod tests {
|
||||
v.len(),
|
||||
last_dim,
|
||||
&input,
|
||||
- &mut output,
|
||||
+ &output,
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
@ -1355,7 +1355,7 @@ mod tests {
|
||||
(&left_stride, left_offset),
|
||||
&right,
|
||||
(&cond_stride, cond_offset),
|
||||
- &mut output,
|
||||
+ &output,
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
|
Reference in New Issue
Block a user