mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
Affine metal kernel works. Need to extract buffer contents based on layout offset (like CudaSlice.slice) for candle intergration
This commit is contained in:
@ -4,12 +4,13 @@ use crate::error::Error;
|
||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
||||
use candle_metal_kernels;
|
||||
use candle_metal_kernels::{void_ptr, AFFINE};
|
||||
use core::mem;
|
||||
use half::{bf16, f16};
|
||||
use metal;
|
||||
use metal::mps::matrix::encode_gemm;
|
||||
use metal::mps::Float32;
|
||||
use metal::{Buffer, MTLResourceOptions, NSUInteger};
|
||||
use metal::{Buffer, CompileOptions, MTLResourceOptions, MTLSize, NSUInteger};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Metal related errors
|
||||
@ -86,10 +87,58 @@ impl BackendStorage for MetalStorage {
|
||||
}
|
||||
}
|
||||
|
||||
fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self> {
|
||||
println!("TODO Affine");
|
||||
fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
|
||||
let device = self.device().clone();
|
||||
|
||||
/*
|
||||
let shape = layout.shape();
|
||||
let dims = shape.dims();
|
||||
let el = shape.elem_count();
|
||||
|
||||
// TODO: Don't load library every time
|
||||
let library = device.new_library_with_source(AFFINE, &CompileOptions::new()).unwrap();
|
||||
let function = library.get_function("affine", None).unwrap();
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(&function)
|
||||
.unwrap();
|
||||
|
||||
let encoder = device.command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
let output_size = el * self.dtype.size_in_bytes();
|
||||
encoder.set_threadgroup_memory_length(0, output_size as NSUInteger);
|
||||
|
||||
let output_buffer = device.new_buffer(output_size, self.dtype);
|
||||
|
||||
encoder.set_bytes(0, 4, void_ptr(&el));
|
||||
encoder.set_bytes(1, 4, void_ptr(&dims));
|
||||
let info = [dims, layout.stride()].concat();
|
||||
let info_len = (info.len() * mem::size_of::<usize>()) as NSUInteger;
|
||||
encoder.set_bytes(2, info_len, info.as_slice().as_ptr().cast());
|
||||
|
||||
encoder.set_buffer(3, Some(&self.buffer), 0);
|
||||
encoder.set_buffer(4, Some(&output_buffer), 0);
|
||||
|
||||
encoder.set_bytes(5, 4, void_ptr(&(mul as f32)));
|
||||
encoder.set_bytes(6, 4, void_ptr(&(add as f32)));
|
||||
|
||||
let grid_size = MTLSize {
|
||||
width: output_size as NSUInteger,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
let thread_group_size = MTLSize {
|
||||
width: 1,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
encoder.dispatch_threads(grid_size, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
*/
|
||||
|
||||
Ok(self.clone())
|
||||
// todo!()
|
||||
}
|
||||
|
||||
fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
|
||||
|
62
candle-metal-kernels/src/affine.metal
Normal file
62
candle-metal-kernels/src/affine.metal
Normal file
@ -0,0 +1,62 @@
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
|
||||
METAL_FUNC bool is_contiguous(
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides
|
||||
) {
|
||||
size_t acc = 1;
|
||||
for (uint d = 0; d < num_dims; d++) {
|
||||
uint dim_idx = num_dims - 1 - d;
|
||||
if (acc != strides[dim_idx]) {
|
||||
return false;
|
||||
}
|
||||
acc *= dims[dim_idx];
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
METAL_FUNC uint get_strided_index(
|
||||
uint idx,
|
||||
constant size_t &num_dims,
|
||||
constant size_t *dims,
|
||||
constant size_t *strides
|
||||
) {
|
||||
uint strided_i = 0;
|
||||
for (uint d = 0; d < num_dims; d++) {
|
||||
uint dim_idx = num_dims - 1 - d;
|
||||
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
|
||||
idx /= dims[dim_idx];
|
||||
}
|
||||
return strided_i;
|
||||
}
|
||||
|
||||
kernel void affine(
|
||||
constant size_t &dim,
|
||||
constant size_t &num_dims,
|
||||
constant size_t *info,
|
||||
|
||||
device float *inp [[buffer(3)]],
|
||||
device float *out [[buffer(4)]],
|
||||
|
||||
constant float &mul,
|
||||
constant float &add
|
||||
) {
|
||||
|
||||
constant size_t *dims = info;
|
||||
constant size_t *strides = info + num_dims;
|
||||
|
||||
if (is_contiguous(num_dims, dims, strides)) {
|
||||
for (size_t i = 0; i < dim; i++) {
|
||||
float x = inp ? inp[i] : out[i];
|
||||
out[i] = x * mul + add;
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < dim; i++) {
|
||||
uint strided_i = get_strided_index(i, num_dims, dims, strides);
|
||||
float x = inp ? inp[strided_i] : out[strided_i];
|
||||
out[strided_i] = x * mul + add;
|
||||
}
|
||||
}
|
||||
}
|
@ -1,7 +1,9 @@
|
||||
use metal::{Buffer, CompileOptions, Device, Function, Library};
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::c_void;
|
||||
use std::sync::RwLock;
|
||||
|
||||
pub const AFFINE: &str = include_str!("affine.metal");
|
||||
pub const INDEXING: &str = include_str!("indexing.metal");
|
||||
pub const UNARY: &str = include_str!("unary.metal");
|
||||
|
||||
@ -60,6 +62,10 @@ fn call_unary(_func: &Function, _input: &Buffer, _output: &Buffer, _length: usiz
|
||||
todo!("Call unary");
|
||||
}
|
||||
|
||||
pub fn void_ptr<T>(v: &T) -> *const c_void {
|
||||
(v as *const T).cast()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@ -70,9 +76,6 @@ mod tests {
|
||||
use std::ffi::c_void;
|
||||
use std::mem;
|
||||
|
||||
pub fn void_ptr<T>(v: &T) -> *const c_void {
|
||||
(v as *const T).cast()
|
||||
}
|
||||
fn approx(v: Vec<f32>, digits: i32) -> Vec<f32> {
|
||||
let b = 10f32.powi(digits);
|
||||
v.iter().map(|t| f32::round(t * b) / b).collect()
|
||||
@ -144,6 +147,72 @@ mod tests {
|
||||
assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn affine() {
|
||||
let device = Device::system_default().expect("no device found");
|
||||
|
||||
let options = CompileOptions::new();
|
||||
let library = device.new_library_with_source(AFFINE, &options).unwrap();
|
||||
|
||||
let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
|
||||
let output = [2.0f32, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
|
||||
let dim: u32 = 8;
|
||||
let num_dims: u32 = 4;
|
||||
let info = [1u32, 2, 3];
|
||||
let mul: f32 = 1.5;
|
||||
let add: f32 = 1.1;
|
||||
|
||||
let function = library.get_function("affine", None).unwrap();
|
||||
let pipeline = device
|
||||
.new_compute_pipeline_state_with_function(&function)
|
||||
.unwrap();
|
||||
let options = MTLResourceOptions::StorageModeShared;
|
||||
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
|
||||
let input_size = (input.len() * mem::size_of::<f32>()) as NSUInteger;
|
||||
let output_size = (output.len() * mem::size_of::<f32>()) as NSUInteger;
|
||||
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
encoder.set_threadgroup_memory_length(0, output_size as NSUInteger);
|
||||
|
||||
let inputs_buffer = device.new_buffer_with_data(void_ptr(&input), input_size, options);
|
||||
let outputs_buffer = device.new_buffer_with_data(void_ptr(&output), output_size, options);
|
||||
|
||||
encoder.set_bytes(0, 4, void_ptr(&dim));
|
||||
encoder.set_bytes(1, 4, void_ptr(&num_dims));
|
||||
encoder.set_bytes(2, 4, void_ptr(&info));
|
||||
|
||||
encoder.set_buffer(3, Some(&inputs_buffer), 0);
|
||||
encoder.set_buffer(4, Some(&outputs_buffer), 0);
|
||||
|
||||
encoder.set_bytes(5, 4, void_ptr(&mul));
|
||||
encoder.set_bytes(6, 4, void_ptr(&add));
|
||||
|
||||
let grid_size = MTLSize {
|
||||
width: output.len() as NSUInteger,
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
let thread_group_size = MTLSize {
|
||||
width: pipeline.max_total_threads_per_threadgroup(),
|
||||
height: 1,
|
||||
depth: 1,
|
||||
};
|
||||
|
||||
encoder.dispatch_threads(grid_size, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
let expected = vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1];
|
||||
let result = outputs_buffer.read_to_vec::<f32>(output.len());
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn index_add() {
|
||||
let device = Device::system_default().expect("no device found");
|
||||
|
Reference in New Issue
Block a user