mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
Broken metal ?
This commit is contained in:
@ -3,7 +3,7 @@ use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose2D};
|
|||||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||||
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
||||||
use candle_metal_kernels;
|
use candle_metal_kernels;
|
||||||
use candle_metal_kernels::{void_ptr, Kernels, AFFINE};
|
use candle_metal_kernels::{void_ptr, Kernels};
|
||||||
use core::mem;
|
use core::mem;
|
||||||
use half::{bf16, f16};
|
use half::{bf16, f16};
|
||||||
use metal;
|
use metal;
|
||||||
@ -36,8 +36,7 @@ impl MetalError {
|
|||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct MetalDevice {
|
pub struct MetalDevice {
|
||||||
device: metal::Device,
|
device: metal::Device,
|
||||||
_command_queue: metal::CommandQueue,
|
command_queue: metal::CommandQueue,
|
||||||
command_buffer: metal::CommandBuffer,
|
|
||||||
kernels: Arc<candle_metal_kernels::Kernels>,
|
kernels: Arc<candle_metal_kernels::Kernels>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -66,13 +65,14 @@ impl MetalDevice {
|
|||||||
|
|
||||||
fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
|
fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
|
||||||
let size = (element_count * dtype.size_in_bytes()) as u64;
|
let size = (element_count * dtype.size_in_bytes()) as u64;
|
||||||
self.device.new_buffer(size, MTLResourceOptions::empty())
|
self.device
|
||||||
|
.new_buffer(size, MTLResourceOptions::StorageModeManaged)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct MetalStorage {
|
pub struct MetalStorage {
|
||||||
buffer: Arc<metal::Buffer>,
|
buffer: metal::Buffer,
|
||||||
device: MetalDevice,
|
device: MetalDevice,
|
||||||
dtype: DType,
|
dtype: DType,
|
||||||
}
|
}
|
||||||
@ -103,6 +103,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
|
|
||||||
fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
|
fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
|
||||||
let device = self.device().clone();
|
let device = self.device().clone();
|
||||||
|
let command_buffer = self.device.command_queue.new_owned_command_buffer();
|
||||||
|
|
||||||
let shape = layout.shape();
|
let shape = layout.shape();
|
||||||
let dims = shape.dims();
|
let dims = shape.dims();
|
||||||
@ -123,7 +124,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
|
|
||||||
let src_length = self.buffer.length() as usize - layout.start_offset();
|
let src_length = self.buffer.length() as usize - layout.start_offset();
|
||||||
let src = self.device.new_buffer(src_length, self.dtype);
|
let src = self.device.new_buffer(src_length, self.dtype);
|
||||||
let blit_encoder = self.device.command_buffer.new_blit_command_encoder();
|
let blit_encoder = command_buffer.new_blit_command_encoder();
|
||||||
blit_encoder.copy_from_buffer(
|
blit_encoder.copy_from_buffer(
|
||||||
self.buffer.as_ref(),
|
self.buffer.as_ref(),
|
||||||
layout.start_offset() as NSUInteger,
|
layout.start_offset() as NSUInteger,
|
||||||
@ -133,7 +134,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
);
|
);
|
||||||
blit_encoder.end_encoding();
|
blit_encoder.end_encoding();
|
||||||
|
|
||||||
let encoder = device.command_buffer.new_compute_command_encoder();
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
encoder.set_compute_pipeline_state(&pipeline);
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
encoder.set_threadgroup_memory_length(0, output_size as NSUInteger);
|
encoder.set_threadgroup_memory_length(0, output_size as NSUInteger);
|
||||||
|
|
||||||
@ -164,6 +165,10 @@ impl BackendStorage for MetalStorage {
|
|||||||
encoder.dispatch_threads(grid_size, thread_group_size);
|
encoder.dispatch_threads(grid_size, thread_group_size);
|
||||||
encoder.end_encoding();
|
encoder.end_encoding();
|
||||||
|
|
||||||
|
command_buffer.commit();
|
||||||
|
// command_buffer.wait_until_completed();
|
||||||
|
println!("Affine");
|
||||||
|
|
||||||
Ok(self.clone())
|
Ok(self.clone())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -190,17 +195,45 @@ impl BackendStorage for MetalStorage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
|
fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
|
||||||
let device = self.device().clone();
|
let device = self.device();
|
||||||
let dtype = self.dtype;
|
let dtype = self.dtype;
|
||||||
let shape = layout.shape();
|
let shape = layout.shape();
|
||||||
let dims = shape.dims();
|
let dims = shape.dims();
|
||||||
let el_count = shape.elem_count();
|
let el_count = shape.elem_count();
|
||||||
let mut buffer = device.new_buffer(el_count, dtype);
|
let mut buffer = device.new_buffer(el_count, dtype);
|
||||||
//todo!("Implement the kernel calling");
|
let command_buffer = device.command_queue.new_command_buffer();
|
||||||
// device.kernels.call_unary(U::KERNEL, &self.buffer, &mut buffer, el_count, dtype);
|
if layout.is_contiguous() {
|
||||||
|
use candle_metal_kernels::unary::contiguous;
|
||||||
|
|
||||||
|
let kernel_name = match (B::KERNEL, dtype) {
|
||||||
|
("ucos", DType::F32) => contiguous::cos::FLOAT,
|
||||||
|
("usin", DType::F32) => contiguous::sin::FLOAT,
|
||||||
|
("usqr", DType::F32) => contiguous::sqr::FLOAT,
|
||||||
|
("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
|
||||||
|
("uneg", DType::F32) => contiguous::neg::FLOAT,
|
||||||
|
("uexp", DType::F32) => contiguous::exp::FLOAT,
|
||||||
|
(name, dtype) => todo!("Match {name} - {dtype:?}"),
|
||||||
|
};
|
||||||
|
candle_metal_kernels::call_unary_contiguous(
|
||||||
|
&device.device,
|
||||||
|
&command_buffer,
|
||||||
|
&device.kernels,
|
||||||
|
kernel_name,
|
||||||
|
el_count,
|
||||||
|
&self.buffer,
|
||||||
|
&mut buffer,
|
||||||
|
)
|
||||||
|
.map_err(MetalError::from)?;
|
||||||
|
} else {
|
||||||
|
todo!("TODO Implement the kernel calling {}", B::KERNEL);
|
||||||
|
}
|
||||||
|
command_buffer.commit();
|
||||||
|
// command_buffer.wait_until_completed();
|
||||||
|
println!("Unary {:?}", B::KERNEL);
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
buffer: Arc::new(buffer),
|
buffer,
|
||||||
device,
|
device: device.clone(),
|
||||||
dtype,
|
dtype,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -368,7 +401,7 @@ impl MetalStorage {
|
|||||||
println!("TODO implement batched matmul for B={b}");
|
println!("TODO implement batched matmul for B={b}");
|
||||||
// bail!("Didn't implemented strided matmul yet");
|
// bail!("Didn't implemented strided matmul yet");
|
||||||
return Ok(Self {
|
return Ok(Self {
|
||||||
buffer: Arc::new(out_buffer),
|
buffer: out_buffer,
|
||||||
device: self.device.clone(),
|
device: self.device.clone(),
|
||||||
dtype: self.dtype(),
|
dtype: self.dtype(),
|
||||||
});
|
});
|
||||||
@ -380,15 +413,17 @@ impl MetalStorage {
|
|||||||
rhs_l.is_contiguous()
|
rhs_l.is_contiguous()
|
||||||
);
|
);
|
||||||
return Ok(Self {
|
return Ok(Self {
|
||||||
buffer: Arc::new(out_buffer),
|
buffer: out_buffer,
|
||||||
device: self.device.clone(),
|
device: self.device.clone(),
|
||||||
dtype: self.dtype(),
|
dtype: self.dtype(),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
println!("GEMM");
|
||||||
|
let command_buffer = self.device.command_queue.new_command_buffer();
|
||||||
encode_gemm::<Float32, Float32, Float32>(
|
encode_gemm::<Float32, Float32, Float32>(
|
||||||
&self.device,
|
&self.device,
|
||||||
&self.device.command_buffer,
|
&command_buffer,
|
||||||
transpose_left,
|
transpose_left,
|
||||||
transpose_right,
|
transpose_right,
|
||||||
&self.buffer,
|
&self.buffer,
|
||||||
@ -402,13 +437,15 @@ impl MetalStorage {
|
|||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
|
|
||||||
println!("lhs {:?} {m} {k}", self.buffer.length());
|
command_buffer.commit();
|
||||||
println!("rhs {:?} {k} {n}", rhs.buffer.length());
|
|
||||||
println!("out {:?} {m} {n}", out_buffer.length());
|
// println!("lhs {:?} {m} {k}", self.buffer.length());
|
||||||
println!("lhs {:?}", lhs_l.shape());
|
// println!("rhs {:?} {k} {n}", rhs.buffer.length());
|
||||||
|
// println!("out {:?} {m} {n}", out_buffer.length());
|
||||||
|
// println!("lhs {:?}", lhs_l.shape());
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
buffer: Arc::new(out_buffer),
|
buffer: out_buffer,
|
||||||
device: self.device.clone(),
|
device: self.device.clone(),
|
||||||
dtype: self.dtype(),
|
dtype: self.dtype(),
|
||||||
})
|
})
|
||||||
@ -423,13 +460,13 @@ impl BackendDevice for MetalDevice {
|
|||||||
|
|
||||||
fn new(ordinal: usize) -> Result<Self> {
|
fn new(ordinal: usize) -> Result<Self> {
|
||||||
let device = metal::Device::all().swap_remove(ordinal);
|
let device = metal::Device::all().swap_remove(ordinal);
|
||||||
let _command_queue = device.new_command_queue();
|
let command_queue = device.new_command_queue();
|
||||||
let command_buffer = _command_queue.new_owned_command_buffer();
|
// let command_buffer = _command_queue.new_owned_command_buffer();
|
||||||
let kernels = Arc::new(Kernels::init(&device).map_err(MetalError::from)?);
|
let kernels = Arc::new(Kernels::new());
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
device,
|
device,
|
||||||
_command_queue,
|
command_queue,
|
||||||
command_buffer,
|
// command_buffer,
|
||||||
kernels,
|
kernels,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -498,7 +535,7 @@ impl BackendDevice for MetalDevice {
|
|||||||
),
|
),
|
||||||
};
|
};
|
||||||
Ok(Self::Storage {
|
Ok(Self::Storage {
|
||||||
buffer: Arc::new(buffer),
|
buffer,
|
||||||
device: self.clone(),
|
device: self.clone(),
|
||||||
dtype: storage.dtype(),
|
dtype: storage.dtype(),
|
||||||
})
|
})
|
||||||
|
@ -1,8 +1,7 @@
|
|||||||
use metal::{
|
use metal::{
|
||||||
Buffer, CommandBuffer, CompileOptions, ComputePipelineDescriptor, Device, Function, Library,
|
Buffer, CommandBufferRef, CompileOptions, ComputePipelineDescriptor, Device, Function, Library,
|
||||||
MTLSize,
|
MTLSize,
|
||||||
};
|
};
|
||||||
use once_cell::sync::Lazy;
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::ffi::c_void;
|
use std::ffi::c_void;
|
||||||
use std::sync::RwLock;
|
use std::sync::RwLock;
|
||||||
@ -11,11 +10,18 @@ const AFFINE: &str = include_str!("affine.metal");
|
|||||||
const INDEXING: &str = include_str!("indexing.metal");
|
const INDEXING: &str = include_str!("indexing.metal");
|
||||||
const UNARY: &str = include_str!("unary.metal");
|
const UNARY: &str = include_str!("unary.metal");
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||||
|
pub enum Source {
|
||||||
|
Affine,
|
||||||
|
Indexing,
|
||||||
|
Unary,
|
||||||
|
}
|
||||||
|
|
||||||
macro_rules! unary{
|
macro_rules! unary{
|
||||||
($($name:ident),+) => {
|
($($name:ident),+) => {
|
||||||
|
|
||||||
pub mod contiguous {
|
pub mod contiguous {
|
||||||
pub struct Kernel(pub &'static str);
|
pub struct Kernel(pub(crate) &'static str);
|
||||||
$(
|
$(
|
||||||
pub mod $name {
|
pub mod $name {
|
||||||
use super::Kernel;
|
use super::Kernel;
|
||||||
@ -27,7 +33,7 @@ macro_rules! unary{
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub mod strided {
|
pub mod strided {
|
||||||
pub struct Kernel(pub &'static str);
|
pub struct Kernel(pub(crate) &'static str);
|
||||||
$(
|
$(
|
||||||
pub mod $name {
|
pub mod $name {
|
||||||
use super::Kernel;
|
use super::Kernel;
|
||||||
@ -41,17 +47,17 @@ macro_rules! unary{
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub mod unary {
|
pub mod unary {
|
||||||
unary!(cos, sin, exp);
|
unary!(cos, sin, exp, sqr, sqrt, neg);
|
||||||
}
|
}
|
||||||
|
|
||||||
static LIBRARY_SOURCES: Lazy<HashMap<&'static str, &'static str>> = Lazy::new(|| {
|
// static LIBRARY_SOURCES: Lazy<HashMap<&'static str, &'static str>> = Lazy::new(|| {
|
||||||
let mut l = HashMap::new();
|
// let mut l = HashMap::new();
|
||||||
l.insert("affine", AFFINE);
|
// l.insert("affine", AFFINE);
|
||||||
l.insert("indexing", INDEXING);
|
// l.insert("indexing", INDEXING);
|
||||||
l.insert("unary", UNARY);
|
// l.insert("unary", UNARY);
|
||||||
l
|
// l
|
||||||
});
|
// });
|
||||||
|
//
|
||||||
#[derive(thiserror::Error, Debug)]
|
#[derive(thiserror::Error, Debug)]
|
||||||
pub enum MetalKernelError {
|
pub enum MetalKernelError {
|
||||||
#[error("Could not lock kernel map: {0}")]
|
#[error("Could not lock kernel map: {0}")]
|
||||||
@ -69,7 +75,7 @@ impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type KernelMap<T> = HashMap<&'static str, T>;
|
type KernelMap<T> = HashMap<&'static str, T>;
|
||||||
type Libraries = KernelMap<Library>;
|
type Libraries = HashMap<Source, Library>;
|
||||||
type Functions = KernelMap<Function>;
|
type Functions = KernelMap<Function>;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@ -85,39 +91,42 @@ impl Kernels {
|
|||||||
Self { libraries, funcs }
|
Self { libraries, funcs }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn init(device: &Device) -> Result<Self, MetalKernelError> {
|
// pub fn init(device: &Device) -> Result<Self, MetalKernelError> {
|
||||||
let kernels = Self::new();
|
// let kernels = Self::new();
|
||||||
kernels.load_libraries(device)?;
|
// kernels.load_libraries(device)?;
|
||||||
Ok(kernels)
|
// Ok(kernels)
|
||||||
}
|
// }
|
||||||
|
|
||||||
fn load_libraries(&self, device: &Device) -> Result<(), MetalKernelError> {
|
// fn load_libraries(&self, device: &Device) -> Result<(), MetalKernelError> {
|
||||||
for name in LIBRARY_SOURCES.keys() {
|
// for name in LIBRARY_SOURCES.keys() {
|
||||||
self.load_library(device, name)?;
|
// self.load_library(device, name)?;
|
||||||
}
|
// }
|
||||||
Ok(())
|
// Ok(())
|
||||||
}
|
// }
|
||||||
|
|
||||||
fn get_library_source(&self, name: &'static str) -> Option<&'static str> {
|
fn get_library_source(&self, source: Source) -> &'static str {
|
||||||
LIBRARY_SOURCES.get(name).cloned()
|
// LIBRARY_SOURCES.get(name).cloned()
|
||||||
|
match source {
|
||||||
|
Source::Affine => AFFINE,
|
||||||
|
Source::Unary => UNARY,
|
||||||
|
Source::Indexing => INDEXING,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load_library(
|
pub fn load_library(
|
||||||
&self,
|
&self,
|
||||||
device: &Device,
|
device: &Device,
|
||||||
name: &'static str,
|
source: Source,
|
||||||
) -> Result<Library, MetalKernelError> {
|
) -> Result<Library, MetalKernelError> {
|
||||||
let mut libraries = self.libraries.write()?;
|
let mut libraries = self.libraries.write()?;
|
||||||
if let Some(lib) = libraries.get(name) {
|
if let Some(lib) = libraries.get(&source) {
|
||||||
Ok(lib.clone())
|
Ok(lib.clone())
|
||||||
} else {
|
} else {
|
||||||
let source = self.get_library_source(name).ok_or_else(|| {
|
let source_content = self.get_library_source(source);
|
||||||
MetalKernelError::LoadLibraryError(format!("No source found for {}", name))
|
|
||||||
})?;
|
|
||||||
let lib = device
|
let lib = device
|
||||||
.new_library_with_source(source, &CompileOptions::new())
|
.new_library_with_source(source_content, &CompileOptions::new())
|
||||||
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?;
|
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?;
|
||||||
libraries.insert(name, lib.clone());
|
libraries.insert(source, lib.clone());
|
||||||
Ok(lib)
|
Ok(lib)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -125,7 +134,7 @@ impl Kernels {
|
|||||||
pub fn load_function(
|
pub fn load_function(
|
||||||
&self,
|
&self,
|
||||||
device: &Device,
|
device: &Device,
|
||||||
library_name: &'static str,
|
source: Source,
|
||||||
name: &'static str,
|
name: &'static str,
|
||||||
) -> Result<Function, MetalKernelError> {
|
) -> Result<Function, MetalKernelError> {
|
||||||
let mut funcs = self.funcs.write()?;
|
let mut funcs = self.funcs.write()?;
|
||||||
@ -133,7 +142,7 @@ impl Kernels {
|
|||||||
Ok(func.clone())
|
Ok(func.clone())
|
||||||
} else {
|
} else {
|
||||||
let func = self
|
let func = self
|
||||||
.load_library(device, library_name)?
|
.load_library(device, source)?
|
||||||
.get_function(name, None)
|
.get_function(name, None)
|
||||||
.map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
|
.map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
|
||||||
funcs.insert(name, func.clone());
|
funcs.insert(name, func.clone());
|
||||||
@ -144,15 +153,16 @@ impl Kernels {
|
|||||||
|
|
||||||
pub fn call_unary_contiguous(
|
pub fn call_unary_contiguous(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBuffer,
|
command_buffer: &CommandBufferRef,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
kernel_name: unary::contiguous::Kernel,
|
kernel_name: unary::contiguous::Kernel,
|
||||||
length: usize,
|
length: usize,
|
||||||
input: &Buffer,
|
input: &Buffer,
|
||||||
output: &mut Buffer,
|
output: &mut Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
assert_eq!(input.length(), output.length());
|
// println!("Kernel {:?}", kernel_name.0);
|
||||||
let func = kernels.load_function(device, "unary", kernel_name.0)?;
|
// assert_eq!(input.length(), output.length());
|
||||||
|
let func = kernels.load_function(device, Source::Unary, kernel_name.0)?;
|
||||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||||
|
|
||||||
@ -188,7 +198,7 @@ pub fn call_unary_contiguous(
|
|||||||
}
|
}
|
||||||
pub fn call_unary_strided(
|
pub fn call_unary_strided(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
command_buffer: &CommandBuffer,
|
command_buffer: &CommandBufferRef,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
name: unary::strided::Kernel,
|
name: unary::strided::Kernel,
|
||||||
input: &Buffer,
|
input: &Buffer,
|
||||||
@ -197,7 +207,7 @@ pub fn call_unary_strided(
|
|||||||
offset: usize,
|
offset: usize,
|
||||||
output: &mut Buffer,
|
output: &mut Buffer,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let func = kernels.load_function(device, "unary", name.0)?;
|
let func = kernels.load_function(device, Source::Unary, name.0)?;
|
||||||
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
|
||||||
pipeline_state_descriptor.set_compute_function(Some(&func));
|
pipeline_state_descriptor.set_compute_function(Some(&func));
|
||||||
|
|
||||||
@ -277,7 +287,7 @@ mod tests {
|
|||||||
let device = device();
|
let device = device();
|
||||||
let kernels = Kernels::new();
|
let kernels = Kernels::new();
|
||||||
let command_queue = device.new_command_queue();
|
let command_queue = device.new_command_queue();
|
||||||
let command_buffer = command_queue.new_owned_command_buffer();
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
let options = MTLResourceOptions::StorageModeManaged;
|
let options = MTLResourceOptions::StorageModeManaged;
|
||||||
let input = device.new_buffer_with_data(
|
let input = device.new_buffer_with_data(
|
||||||
v.as_ptr() as *const core::ffi::c_void,
|
v.as_ptr() as *const core::ffi::c_void,
|
||||||
@ -310,7 +320,7 @@ mod tests {
|
|||||||
let device = device();
|
let device = device();
|
||||||
let options = MTLResourceOptions::StorageModeManaged;
|
let options = MTLResourceOptions::StorageModeManaged;
|
||||||
let command_queue = device.new_command_queue();
|
let command_queue = device.new_command_queue();
|
||||||
let command_buffer = command_queue.new_owned_command_buffer();
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
let input = device.new_buffer_with_data(
|
let input = device.new_buffer_with_data(
|
||||||
v.as_ptr() as *const core::ffi::c_void,
|
v.as_ptr() as *const core::ffi::c_void,
|
||||||
(v.len() * core::mem::size_of::<T>()) as u64,
|
(v.len() * core::mem::size_of::<T>()) as u64,
|
||||||
|
@ -22,6 +22,9 @@ METAL_FUNC uint get_strided_index(
|
|||||||
return strided_i;
|
return strided_i;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
|
||||||
|
template <typename T> METAL_FUNC T neg(T in){ return -in; }
|
||||||
|
|
||||||
|
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
@ -59,12 +62,26 @@ kernel void FN_NAME_STRIDED( \
|
|||||||
} \
|
} \
|
||||||
}
|
}
|
||||||
|
|
||||||
UNARY(cos, float, cos_float, cos_float_strided);
|
#define UNARY_OP(NAME) \
|
||||||
UNARY(cos, half, cos_half, cos_half_strided);
|
UNARY(NAME, float, NAME##_float, NAME##_float_strided); \
|
||||||
UNARY(sin, float, sin_float, sin_float_strided);
|
UNARY(NAME, half, NAME##_half, NAME##_half_strided);
|
||||||
UNARY(sin, half, sin_half, sin_half_strided);
|
|
||||||
|
#define BFLOAT_UNARY_OP(NAME) \
|
||||||
|
UNARY(NAME, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
|
||||||
|
|
||||||
|
|
||||||
|
UNARY_OP(cos)
|
||||||
|
UNARY_OP(sin)
|
||||||
|
UNARY_OP(sqr)
|
||||||
|
UNARY_OP(sqrt)
|
||||||
|
UNARY_OP(neg)
|
||||||
|
UNARY_OP(exp)
|
||||||
|
|
||||||
#if __METAL_VERSION__ >= 310
|
#if __METAL_VERSION__ >= 310
|
||||||
UNARY(cos, bfloat, cos_bfloat, cos_bfloat_strided);
|
BFLOAT_UNARY_OP(cos)
|
||||||
UNARY(sin, bfloat, sin_bfloat, sin_bfloat_strided);
|
BFLOAT_UNARY_OP(sin)
|
||||||
|
BFLOAT_UNARY_OP(sqr)
|
||||||
|
BFLOAT_UNARY_OP(sqrt)
|
||||||
|
BFLOAT_UNARY_OP(neg)
|
||||||
|
BFLOAT_UNARY_OP(exp)
|
||||||
#endif
|
#endif
|
||||||
|
Reference in New Issue
Block a user