mirror of https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Reworked affine and it works ? No idea how it's different.
@@ -55,8 +55,7 @@ tracing-subscriber = "0.3.7"
 wav = "1.0.0"
 yoke = { version = "0.7.2", features = ["derive"] }
 zip = { version = "0.6.6", default-features = false }
-# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
-metal = { path = "../metal-rs", features = ["mps"] }
+metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
 
 [profile.release-with-debug]
 inherits = "release"
@@ -111,89 +111,28 @@ impl BackendStorage for MetalStorage {
         let el = shape.elem_count();
         let dtype = self.dtype;
 
-        debug!("{shape:?} {el:?} {:?}", layout.stride());
-        let output_buffer = device.new_buffer(el, self.dtype);
+        assert!(layout.is_contiguous());
+        assert_eq!(dtype, DType::F32);
 
+        let mut buffer = device.new_buffer(el, self.dtype);
+        let command_buffer = self.device.command_queue.new_command_buffer();
+        candle_metal_kernels::call_affine(
+            &device.device,
+            &command_buffer,
+            &device.kernels,
+            el,
+            &self.buffer,
+            &mut buffer,
+            mul as f32,
+            add as f32,
+        )
+        .unwrap();
+        command_buffer.commit();
         return Ok(Self {
-            buffer: output_buffer,
+            buffer,
             device: device.clone(),
             dtype,
         });
-        let function = self
-            .device
-            .kernels
-            .load_function(&device.device, Source::Affine, "affine")
-            .map_err(MetalError::from)?;
-
-        let pipeline = device
-            .new_compute_pipeline_state_with_function(&function)
-            .map_err(MetalError::msg)?;
-        let command_buffer = self.device.command_queue.new_command_buffer();
-
-        assert_eq!(output_buffer.length(), self.buffer.length());
-
-        let length = el;
-        let encoder = command_buffer.new_compute_command_encoder();
-        encoder.set_compute_pipeline_state(&pipeline);
-        // encoder.set_threadgroup_memory_length(0, output_size as NSUInteger);
-
-        encoder.set_bytes(0, 4, void_ptr(&el));
-        encoder.set_bytes(1, 4, void_ptr(&dims));
-        encoder.set_bytes(
-            2,
-            (mem::size_of::<usize>() * dims.len()) as u64,
-            dims.as_ptr() as *const core::ffi::c_void,
-        );
-        encoder.set_bytes(
-            3,
-            (mem::size_of::<usize>() * layout.stride().len()) as u64,
-            layout.stride().as_ptr() as *const core::ffi::c_void,
-        );
-        encoder.set_buffer(4, Some(&self.buffer), 0);
-        encoder.set_buffer(5, Some(&output_buffer), 0);
-
-        encoder.set_bytes(6, mem::size_of::<f32>() as u64, void_ptr(&(mul as f32)));
-        encoder.set_bytes(7, mem::size_of::<f32>() as u64, void_ptr(&(add as f32)));
-
-        let grid_size = MTLSize {
-            width: 1,
-            height: 1,
-            depth: 1,
-        };
-
-        let thread_group_size = MTLSize {
-            width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), el as u64),
-            height: 1,
-            depth: 1,
-        };
-
-        encoder.dispatch_thread_groups(grid_size, thread_group_size);
-        encoder.end_encoding();
-
-        let start = std::time::Instant::now();
-        command_buffer.commit();
-        // debug!(
-        //     "Affine {:?}({:?}, {:?}) - {:?}",
-        //     command_buffer.status(),
-        //     self.buffer.length(),
-        //     output_buffer.length(),
-        //     start.elapsed()
-        // );
-        // command_buffer.wait_until_completed();
-        debug!(
-            "Affine {:?} - {:?}",
-            command_buffer.status(),
-            start.elapsed()
-        );
-
-        // let capture = metal::CaptureManager::shared();
-        // capture.stop_capture();
-        // panic!("Done");
-
-        Ok(Self {
-            buffer: output_buffer,
-            device: device.clone(),
-            dtype,
-        })
     }
 
     fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
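
Note: the affine op the rewritten path above has to reproduce is just out[i] = in[i] * mul + add over every element. A minimal CPU-side reference for sanity-checking the Metal output, not part of the commit (affine_ref is a made-up name), could look like this:

    // CPU reference for affine: out[i] = in[i] * mul + add, with the math done
    // in f32 to match the `mul as f32` / `add as f32` casts in the backend above.
    fn affine_ref(inp: &[f32], mul: f64, add: f64) -> Vec<f32> {
        inp.iter().map(|&x| x * (mul as f32) + (add as f32)).collect()
    }

    fn main() {
        // Same constants as the test added later in this diff.
        let out = affine_ref(&[1.0, 2.0, 3.0, 4.0], 1.5, 1.1);
        assert_eq!(out, vec![2.6, 4.1, 5.6, 7.1]);
    }
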
@@ -288,12 +227,6 @@ impl BackendStorage for MetalStorage {
         let dims = shape.dims();
         let el_count = shape.elem_count();
         let mut buffer = device.new_buffer(el_count, dtype);
-        // TODO remove
-        // return Ok(Self {
-        //     buffer,
-        //     device: device.clone(),
-        //     dtype,
-        // });
         let command_buffer = device.command_queue.new_command_buffer();
         if layout.is_contiguous() {
             use candle_metal_kernels::unary::contiguous;
@@ -547,7 +480,11 @@ impl BackendStorage for MetalStorage {
     }
 
     fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
-        // todo!("TODO Index select {:?} {ids:?} {l:?} {ids_l:?} {dim:?}", self.buffer.length());
+        debug!(
+            "TODO Index select {:?} {:?} {src_l:?} {ids_l:?} {dim:?}",
+            self.buffer.length(),
+            ids.buffer.length(),
+        );
         let src = self;
         let ids_shape = ids_l.shape();
         let ids_dims = ids_shape.dims();
@@ -607,8 +544,46 @@ impl BackendStorage for MetalStorage {
         )
     }
 
-    fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
-        debug!("TODO Copy strided");
+    fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
+        let src_shape = src_l.shape();
+        let dims = src_shape.dims();
+        let el_count = src_shape.elem_count();
+        if el_count == 0 {
+            return Ok(());
+        }
+        if src_l.is_contiguous() {
+            let command_buffer = self.device.command_queue.new_command_buffer();
+            let blip = command_buffer.new_blit_command_encoder();
+            blip.copy_from_buffer(
+                &self.buffer,
+                src_l.start_offset() as u64,
+                &dst.buffer,
+                dst_offset as u64,
+                self.buffer.length(),
+            );
+        } else {
+            let command_buffer = self.device.command_queue.new_command_buffer();
+            let kernel_name = match self.dtype {
+                DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
+                DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
+                DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT,
+                dtype => todo!("copy_strided not implemented for {dtype:?}"),
+            };
+            candle_metal_kernels::call_unary_strided(
+                &self.device.device,
+                &command_buffer,
+                &self.device.kernels,
+                kernel_name,
+                src_l.dims(),
+                &self.buffer,
+                &src_l.stride(),
+                src_l.start_offset(),
+                &mut dst.buffer,
+                dst_offset,
+            )
+            .map_err(MetalError::from)?;
+            command_buffer.commit();
+        }
         Ok(())
     }
 }
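
The contiguous branch above copies through Metal's blit encoder. For orientation, a self-contained blit copy with the metal crate this commit depends on (illustrative only, not code taken from the commit) looks roughly like this, including the end_encoding/commit/wait_until_completed bracketing:

    use metal::{Device, MTLResourceOptions};

    fn main() {
        let device = Device::system_default().expect("no Metal device");
        let queue = device.new_command_queue();

        // Source buffer with four f32 values, destination of the same byte length.
        let data = [1.0f32, 2.0, 3.0, 4.0];
        let options = MTLResourceOptions::StorageModeManaged;
        let src = device.new_buffer_with_data(
            data.as_ptr() as *const core::ffi::c_void,
            (data.len() * core::mem::size_of::<f32>()) as u64,
            options,
        );
        let dst = device.new_buffer(src.length(), options);

        let command_buffer = queue.new_command_buffer();
        let blit = command_buffer.new_blit_command_encoder();
        // copy_from_buffer(source, source_offset, destination, destination_offset, size_in_bytes)
        blit.copy_from_buffer(&src, 0, &dst, 0, src.length());
        blit.end_encoding();
        command_buffer.commit();
        command_buffer.wait_until_completed();
    }
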
@@ -662,7 +637,7 @@ impl MetalStorage {
         }
         if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() {
             debug!(
-                "Didn't implemented non contiguous matmul yet {:?} {:?}",
+                "TODO non contiguous matmul yet {:?} {:?}",
                 lhs_l.is_contiguous(),
                 rhs_l.is_contiguous()
             );
@@ -674,31 +649,27 @@ impl MetalStorage {
         }
 
         debug!("GEMM");
-        // let command_buffer = self.device.command_queue.new_command_buffer();
-        // encode_gemm::<Float32, Float32, Float32>(
-        //     &self.device,
-        //     &command_buffer,
-        //     transpose_left,
-        //     transpose_right,
-        //     &self.buffer,
-        //     &rhs.buffer,
-        //     &mut out_buffer,
-        //     m as NSUInteger,
-        //     n as NSUInteger,
-        //     k as NSUInteger,
-        //     alpha,
-        //     beta,
-        // )
-        // .map_err(MetalError::from)?;
+        let command_buffer = self.device.command_queue.new_command_buffer();
+        encode_gemm::<Float32, Float32, Float32>(
+            &self.device,
+            &command_buffer,
+            transpose_left,
+            transpose_right,
+            &self.buffer,
+            &rhs.buffer,
+            &mut out_buffer,
+            m as NSUInteger,
+            n as NSUInteger,
+            k as NSUInteger,
+            alpha as f32,
+            beta as f32,
+            Some(b as NSUInteger),
+        )
+        .map_err(MetalError::from)?;
 
-        // command_buffer.commit();
+        command_buffer.commit();
         // command_buffer.wait_until_scheduled();
 
-        // println!("lhs {:?} {m} {k}", self.buffer.length());
-        // println!("rhs {:?} {k} {n}", rhs.buffer.length());
-        // println!("out {:?} {m} {n}", out_buffer.length());
-        // println!("lhs {:?}", lhs_l.shape());
-
         Ok(Self {
             buffer: out_buffer,
             device: self.device.clone(),
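
For reference, the GEMM being encoded above computes out = alpha * op(lhs) * op(rhs) + beta * out, with op(lhs) of shape (m, k), op(rhs) of shape (k, n), and Some(b as NSUInteger) presumably the batch count. A naive single-batch, row-major CPU sketch of that contract (illustrative, not code from this commit):

    // Naive single-batch GEMM reference: out = alpha * a.dot(b) + beta * out,
    // with `a` of shape (m, k), `b` of shape (k, n) and `out` of shape (m, n),
    // all stored row-major.
    fn gemm_ref(
        m: usize, n: usize, k: usize,
        alpha: f32, beta: f32,
        a: &[f32], b: &[f32], out: &mut [f32],
    ) {
        assert_eq!(a.len(), m * k);
        assert_eq!(b.len(), k * n);
        assert_eq!(out.len(), m * n);
        for i in 0..m {
            for j in 0..n {
                let mut acc = 0.0f32;
                for p in 0..k {
                    acc += a[i * k + p] * b[p * n + j];
                }
                out[i * n + j] = alpha * acc + beta * out[i * n + j];
            }
        }
    }
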
@@ -719,7 +690,6 @@ impl BackendDevice for MetalDevice {
         // let capture = metal::CaptureManager::shared();
         // let descriptor = metal::CaptureDescriptor::new();
         // descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
-        // println!("{:?}", std::env::current_dir()?);
         // descriptor.set_capture_device(&device);
         // let mut dir = std::env::current_dir()?;
         // dir.push("out.gputrace");
@@ -1,21 +1,4 @@
 #include <metal_stdlib>
-using namespace metal;
-
-METAL_FUNC bool is_contiguous(
-    constant size_t &num_dims,
-    constant size_t *dims,
-    constant size_t *strides
-) {
-    size_t acc = 1;
-    for (uint d = 0; d < num_dims; d++) {
-        uint dim_idx = num_dims - 1 - d;
-        if (acc != strides[dim_idx]) {
-            return false;
-        }
-        acc *= dims[dim_idx];
-    }
-    return true;
-}
 
 METAL_FUNC uint get_strided_index(
     uint idx,
@@ -32,33 +15,30 @@ METAL_FUNC uint get_strided_index(
     return strided_i;
 }
 
-kernel void affine(
-    constant size_t &dim,
-    constant size_t &num_dims,
-    constant size_t *dims,
-    constant size_t *strides,
-
-    device float *inp [[buffer(4)]],
-    device float *out [[buffer(5)]],
-    constant float &mul,
-    constant float &add,
-    uint threadgroup_size [[threads_per_threadgroup]],
-    uint thread_index [[thread_index_in_threadgroup]]
-) {
-    const size_t length = (dim + threadgroup_size - 1) / threadgroup_size;
-    const size_t start = thread_index * length;
-    const size_t stop = min(start + length, dim);
-    if (is_contiguous(num_dims, dims, strides)) {
-        for (size_t i = start; i < stop; i++) {
-            float x = inp ? inp[i] : out[i];
-            out[i] = x * mul + add;
-        }
-    } else {
-        for (size_t i = start; i < stop; i++) {
-            uint strided_i = get_strided_index(i, num_dims, dims, strides);
-            float x = inp ? inp[strided_i] : out[strided_i];
-            out[strided_i] = x * mul + add;
-        }
-    }
-}
+using namespace metal;
+
+#define AFFINE(FN_NAME, TYPENAME) \
+kernel void FN_NAME( \
+    constant size_t &dim, \
+    constant float &mul, \
+    constant float &add, \
+    device const TYPENAME *input, \
+    device TYPENAME *output, \
+    uint threadgroup_size [[threads_per_threadgroup]], \
+    uint thread_index [[thread_index_in_threadgroup]] \
+) { \
+    const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
+    const size_t start = thread_index * length; \
+    const size_t stop = min(start + length, dim); \
+    for (size_t i = start; i < stop; i++){ \
+        output[i] = input[i] * mul + add; \
+    } \
+} \
+
+AFFINE(affine_float, float)
+AFFINE(affine_half, half)
+
+#if __METAL_VERSION__ >= 310
+AFFINE(affine_bfloat, bfloat);
+#endif
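
The macro now instantiates affine_float, affine_half and, behind the version guard, affine_bfloat, while call_affine later in this diff loads only "affine_float". If a dtype dispatch were added on the Rust side, it could mirror the copy-kernel match in copy_strided_src; a purely hypothetical sketch, not part of this commit:

    // Hypothetical mapping from dtype to the kernel names generated by the
    // AFFINE macro above (DType as in candle's core crate); the commit itself
    // hardcodes "affine_float" in call_affine.
    fn affine_kernel_name(dtype: DType) -> Option<&'static str> {
        match dtype {
            DType::F32 => Some("affine_float"),
            DType::F16 => Some("affine_half"),
            // Only compiled when __METAL_VERSION__ >= 310.
            DType::BF16 => Some("affine_bfloat"),
            _ => None,
        }
    }
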
@@ -62,7 +62,7 @@ BINARY(FN, float, float, NAME##_float, NAME##_float_strided); \
 BINARY(FN, half, half, NAME##_half, NAME##_half_strided);
 
 #define BFLOAT_BINARY_OP(FN, NAME) \
-BINARY(NAME, bfloat, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
+BINARY(FN, bfloat, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
 
 
 BINARY_OP(x + y, add)
@@ -71,8 +71,8 @@ BINARY_OP(x * y, mul)
 BINARY_OP(x / y, div)
 
 #if __METAL_VERSION__ >= 310
-BFLOAT_BINARY_OP(x + y, badd)
-BFLOAT_BINARY_OP(x - y, bsub)
-BFLOAT_BINARY_OP(x * y, bmul)
-BFLOAT_BINARY_OP(x / y, bdiv)
+BFLOAT_BINARY_OP(x + y, add)
+BFLOAT_BINARY_OP(x - y, sub)
+BFLOAT_BINARY_OP(x * y, mul)
+BFLOAT_BINARY_OP(x / y, div)
 #endif
@@ -51,7 +51,7 @@ macro_rules! ops{
 }
 
 pub mod unary {
-    ops!(cos, sin, exp, sqr, sqrt, neg);
+    ops!(cos, sin, exp, sqr, sqrt, neg, copy);
 }
 pub mod binary {
     ops!(add, sub, mul, div);
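
Adding copy to the ops! list is what lets the unary::strided::copy constants used by copy_strided_src earlier in this diff resolve to real kernels. Assuming the macro keeps the naming scheme visible in the UNARY(id, float, copy_float, copy_float_strided) instantiations near the end of this diff, the expected resolution can be sketched as a test inside the crate's test module (hypothetical, not part of the commit):

    #[test]
    fn copy_kernel_names() {
        // Assumed naming scheme: `<op>_<type>` for contiguous kernels and
        // `<op>_<type>_strided` for strided ones; `.0` is the kernel-name field,
        // as used by call_unary_strided below.
        assert_eq!(unary::strided::copy::FLOAT.0, "copy_float_strided");
        assert_eq!(unary::strided::copy::HALF.0, "copy_half_strided");
    }
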
@@ -210,11 +210,12 @@ pub fn call_unary_strided(
     command_buffer: &CommandBufferRef,
     kernels: &Kernels,
     name: unary::strided::Kernel,
-    input: &Buffer,
     shape: &[usize],
+    input: &Buffer,
     strides: &[usize],
     offset: usize,
     output: &mut Buffer,
+    output_offset: usize,
 ) -> Result<(), MetalKernelError> {
     let func = kernels.load_function(device, Source::Unary, name.0)?;
     let pipeline_state_descriptor = ComputePipelineDescriptor::new();
@@ -245,7 +246,7 @@ pub fn call_unary_strided(
     );
 
     encoder.set_buffer(4, Some(&input), offset as u64);
-    encoder.set_buffer(5, Some(&output), 0);
+    encoder.set_buffer(5, Some(&output), output_offset as u64);
 
     let width = output.length();
 
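
With shape now passed before the input buffer and the new trailing output_offset, a call site reads as sketched below. This mirrors the strided-copy call added to the backend earlier in this diff; device, command_buffer, kernels, input, output, shape, strides and offset are assumed to be set up as they are there, so this is a fragment rather than a complete program:

    call_unary_strided(
        &device,
        &command_buffer,
        &kernels,
        unary::strided::copy::FLOAT,
        shape,       // &[usize]: now precedes the input buffer
        &input,
        strides,     // &[usize]
        offset,      // input offset
        &mut output,
        0,           // output offset: the newly added parameter
    )
    .unwrap();
    command_buffer.commit();
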
@@ -434,6 +435,53 @@ pub fn void_ptr<T>(v: &T) -> *const c_void {
     (v as *const T).cast()
 }
 
+pub fn call_affine(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    size: usize,
+    input: &Buffer,
+    output: &mut Buffer,
+    mul: f32,
+    add: f32,
+) -> Result<(), MetalKernelError> {
+    let func = kernels.load_function(device, Source::Affine, "affine_float")?;
+    let pipeline_state_descriptor = ComputePipelineDescriptor::new();
+    pipeline_state_descriptor.set_compute_function(Some(&func));
+
+    let pipeline = device
+        .new_compute_pipeline_state_with_function(
+            pipeline_state_descriptor.compute_function().unwrap(),
+        )
+        .unwrap();
+
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    encoder.set_bytes(0, core::mem::size_of::<usize>() as u64, void_ptr(&size));
+    encoder.set_bytes(1, core::mem::size_of::<f32>() as u64, void_ptr(&mul));
+    encoder.set_bytes(2, core::mem::size_of::<f32>() as u64, void_ptr(&add));
+    encoder.set_buffer(3, Some(&input), 0);
+    encoder.set_buffer(4, Some(&output), 0);
+
+    let thread_group_count = MTLSize {
+        width: 1,
+        height: 1,
+        depth: 1,
+    };
+
+    let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size as u64);
+    let thread_group_size = MTLSize {
+        width,
+        height: 1,
+        depth: 1,
+    };
+
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.end_encoding();
+    Ok(())
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
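
A condensed end-to-end use of call_affine, following the run_affine test helper added further down in this diff (read_to_vec comes from the metal-rs fork this commit builds against; the sketch itself is illustrative, not part of the commit):

    use candle_metal_kernels::{call_affine, Kernels};
    use metal::{Device, MTLResourceOptions};

    fn main() {
        let device = Device::system_default().expect("no Metal device");
        let kernels = Kernels::new();
        let command_queue = device.new_command_queue();
        let command_buffer = command_queue.new_command_buffer();

        let v = [1.0f32, 2.0, 3.0, 4.0];
        let options = MTLResourceOptions::StorageModeManaged;
        let input = device.new_buffer_with_data(
            v.as_ptr() as *const core::ffi::c_void,
            (v.len() * core::mem::size_of::<f32>()) as u64,
            options,
        );
        let mut output = device.new_buffer((v.len() * core::mem::size_of::<f32>()) as u64, options);

        call_affine(
            &device,
            &command_buffer,
            &kernels,
            v.len(),
            &input,
            &mut output,
            1.5,
            1.1,
        )
        .unwrap();
        command_buffer.commit();
        command_buffer.wait_until_completed();

        // read_to_vec is provided by the metal-rs fork used here.
        let result = output.read_to_vec::<f32>(v.len());
        assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1]);
    }
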
@@ -538,11 +586,12 @@ mod tests {
             &command_buffer,
             &kernels,
             kernel,
-            &input,
             shape,
+            &input,
             strides,
             offset,
             &mut output,
+            0,
         )
         .unwrap();
         command_buffer.commit();
@@ -682,82 +731,52 @@ mod tests {
         assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
     }
 
-    #[test]
-    fn affine() {
+    fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
         let device = device();
-        let options = CompileOptions::new();
-        let library = device.new_library_with_source(AFFINE, &options).unwrap();
-
-        let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
-        let output = [2.0f32, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
-        let shape = vec![4usize, 2];
-        let strides = vec![2usize, 1];
-        let mul: f32 = 1.5;
-        let add: f32 = 1.1;
-
-        let function = library.get_function("affine", None).unwrap();
-        let pipeline = device
-            .new_compute_pipeline_state_with_function(&function)
-            .unwrap();
-        let options = MTLResourceOptions::StorageModeManaged;
-
+        let kernels = Kernels::new();
         let command_queue = device.new_command_queue();
         let command_buffer = command_queue.new_command_buffer();
-        let encoder = command_buffer.new_compute_command_encoder();
-
-        let input_size = (input.len() * mem::size_of::<f32>()) as NSUInteger;
-        let output_size = (output.len() * mem::size_of::<f32>()) as NSUInteger;
-
-        encoder.set_compute_pipeline_state(&pipeline);
-        encoder.set_threadgroup_memory_length(0, output_size as NSUInteger);
-
-        let inputs_buffer = device.new_buffer_with_data(void_ptr(&input), input_size, options);
-        let outputs_buffer = device.new_buffer_with_data(void_ptr(&output), output_size, options);
-
-        let dim: usize = shape.iter().product();
-        let num_dims = shape.len();
-        encoder.set_bytes(0, core::mem::size_of::<usize>() as u64, void_ptr(&dim));
-        encoder.set_bytes(1, core::mem::size_of::<usize>() as u64, void_ptr(&num_dims));
-        encoder.set_bytes(
-            2,
-            (core::mem::size_of::<usize>() * shape.len()) as u64,
-            shape.as_ptr() as *const c_void,
-        );
-        encoder.set_bytes(
-            3,
-            (core::mem::size_of::<usize>() * strides.len()) as u64,
-            strides.as_ptr() as *const c_void,
-        );
-
-        encoder.set_buffer(4, Some(&inputs_buffer), 0);
-        encoder.set_buffer(5, Some(&outputs_buffer), 0);
-
-        encoder.set_bytes(6, core::mem::size_of::<f32>() as u64, void_ptr(&mul));
-        encoder.set_bytes(7, core::mem::size_of::<f32>() as u64, void_ptr(&add));
-
-        let thread_group_count = MTLSize {
-            width: 1,
-            height: 1,
-            depth: 1,
-        };
-
-        let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), dim as u64);
-        println!("WIDTH {width}");
-        let thread_group_size = MTLSize {
-            width,
-            height: 1,
-            depth: 1,
-        };
-
-        encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-        encoder.end_encoding();
+        let options = MTLResourceOptions::StorageModeManaged;
+        let input = device.new_buffer_with_data(
+            v.as_ptr() as *const core::ffi::c_void,
+            (v.len() * core::mem::size_of::<T>()) as u64,
+            options,
+        );
+        let mut output = device.new_buffer((v.len() * core::mem::size_of::<T>()) as u64, options);
+
+        let size = v.len();
+        call_affine(
+            &device,
+            &command_buffer,
+            &kernels,
+            size,
+            &input,
+            &mut output,
+            mul as f32,
+            add as f32,
+        )
+        .unwrap();
         command_buffer.commit();
         command_buffer.wait_until_completed();
 
-        let expected = vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1];
-        let result = outputs_buffer.read_to_vec::<f32>(output.len());
-        println!("Result {:?}", result.as_ptr());
-        assert_eq!(result, expected);
+        output.read_to_vec::<T>(v.len())
+    }
+
+    #[test]
+    fn affine() {
+        let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
+        let mul = 1.5;
+        let add = 1.1;
+        let result = run_affine(&input, mul, add);
+        assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]);
+
+        let input = [1.0f32; 40_000];
+        let mul = 1.5;
+        let add = 1.1;
+        let result = run_affine(&input, mul, add);
+        assert_eq!(result, vec![2.6; 40_000]);
     }
@@ -826,7 +845,6 @@ mod tests {
             2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
         ];
         let result = outputs_buffer.read_to_vec::<f32>(right.len());
-        println!("Result {:?}", result.as_ptr());
         assert_eq!(result, expected);
     }
 
@@ -17,6 +17,7 @@ METAL_FUNC uint get_strided_index(
 
 template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
 template <typename T> METAL_FUNC T neg(T in){ return -in; }
+template <typename T> METAL_FUNC T id(T in){ return in; }
 
 
 using namespace metal;
@@ -68,6 +69,8 @@ UNARY_OP(sqr)
 UNARY_OP(sqrt)
 UNARY_OP(neg)
 UNARY_OP(exp)
+UNARY(id, float, copy_float, copy_float_strided)
+UNARY(id, half, copy_half, copy_half_strided)
 
 #if __METAL_VERSION__ >= 310
 BFLOAT_UNARY_OP(cos)