mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
Affine metal kernel works. Need to extract buffer contents based on layout offset (like CudaSlice.slice) for candle integration
This commit is contained in:
@ -4,12 +4,13 @@ use crate::error::Error;
|
|||||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||||
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
||||||
use candle_metal_kernels;
|
use candle_metal_kernels;
|
||||||
|
use candle_metal_kernels::{void_ptr, AFFINE};
|
||||||
use core::mem;
|
use core::mem;
|
||||||
use half::{bf16, f16};
|
use half::{bf16, f16};
|
||||||
use metal;
|
use metal;
|
||||||
use metal::mps::matrix::encode_gemm;
|
use metal::mps::matrix::encode_gemm;
|
||||||
use metal::mps::Float32;
|
use metal::mps::Float32;
|
||||||
use metal::{Buffer, MTLResourceOptions, NSUInteger};
|
use metal::{Buffer, CompileOptions, MTLResourceOptions, MTLSize, NSUInteger};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
/// Metal related errors
|
/// Metal related errors
|
||||||
@ -86,10 +87,58 @@ impl BackendStorage for MetalStorage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self> {
|
fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
|
||||||
println!("TODO Affine");
|
let device = self.device().clone();
|
||||||
|
|
||||||
|
/*
|
||||||
|
let shape = layout.shape();
|
||||||
|
let dims = shape.dims();
|
||||||
|
let el = shape.elem_count();
|
||||||
|
|
||||||
|
// TODO: Don't load library every time
|
||||||
|
let library = device.new_library_with_source(AFFINE, &CompileOptions::new()).unwrap();
|
||||||
|
let function = library.get_function("affine", None).unwrap();
|
||||||
|
let pipeline = device
|
||||||
|
.new_compute_pipeline_state_with_function(&function)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let encoder = device.command_buffer.new_compute_command_encoder();
|
||||||
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
|
let output_size = el * self.dtype.size_in_bytes();
|
||||||
|
encoder.set_threadgroup_memory_length(0, output_size as NSUInteger);
|
||||||
|
|
||||||
|
let output_buffer = device.new_buffer(output_size, self.dtype);
|
||||||
|
|
||||||
|
encoder.set_bytes(0, 4, void_ptr(&el));
|
||||||
|
encoder.set_bytes(1, 4, void_ptr(&dims));
|
||||||
|
let info = [dims, layout.stride()].concat();
|
||||||
|
let info_len = (info.len() * mem::size_of::<usize>()) as NSUInteger;
|
||||||
|
encoder.set_bytes(2, info_len, info.as_slice().as_ptr().cast());
|
||||||
|
|
||||||
|
encoder.set_buffer(3, Some(&self.buffer), 0);
|
||||||
|
encoder.set_buffer(4, Some(&output_buffer), 0);
|
||||||
|
|
||||||
|
encoder.set_bytes(5, 4, void_ptr(&(mul as f32)));
|
||||||
|
encoder.set_bytes(6, 4, void_ptr(&(add as f32)));
|
||||||
|
|
||||||
|
let grid_size = MTLSize {
|
||||||
|
width: output_size as NSUInteger,
|
||||||
|
height: 1,
|
||||||
|
depth: 1,
|
||||||
|
};
|
||||||
|
|
||||||
|
let thread_group_size = MTLSize {
|
||||||
|
width: 1,
|
||||||
|
height: 1,
|
||||||
|
depth: 1,
|
||||||
|
};
|
||||||
|
|
||||||
|
encoder.dispatch_threads(grid_size, thread_group_size);
|
||||||
|
encoder.end_encoding();
|
||||||
|
*/
|
||||||
|
|
||||||
Ok(self.clone())
|
Ok(self.clone())
|
||||||
// todo!()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
|
fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
|
||||||
|
62
candle-metal-kernels/src/affine.metal
Normal file
62
candle-metal-kernels/src/affine.metal
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
#include <metal_stdlib>
|
||||||
|
using namespace metal;
|
||||||
|
|
||||||
|
METAL_FUNC bool is_contiguous(
|
||||||
|
constant size_t &num_dims,
|
||||||
|
constant size_t *dims,
|
||||||
|
constant size_t *strides
|
||||||
|
) {
|
||||||
|
size_t acc = 1;
|
||||||
|
for (uint d = 0; d < num_dims; d++) {
|
||||||
|
uint dim_idx = num_dims - 1 - d;
|
||||||
|
if (acc != strides[dim_idx]) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
acc *= dims[dim_idx];
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
METAL_FUNC uint get_strided_index(
|
||||||
|
uint idx,
|
||||||
|
constant size_t &num_dims,
|
||||||
|
constant size_t *dims,
|
||||||
|
constant size_t *strides
|
||||||
|
) {
|
||||||
|
uint strided_i = 0;
|
||||||
|
for (uint d = 0; d < num_dims; d++) {
|
||||||
|
uint dim_idx = num_dims - 1 - d;
|
||||||
|
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
|
||||||
|
idx /= dims[dim_idx];
|
||||||
|
}
|
||||||
|
return strided_i;
|
||||||
|
}
|
||||||
|
|
||||||
|
kernel void affine(
|
||||||
|
constant size_t &dim,
|
||||||
|
constant size_t &num_dims,
|
||||||
|
constant size_t *info,
|
||||||
|
|
||||||
|
device float *inp [[buffer(3)]],
|
||||||
|
device float *out [[buffer(4)]],
|
||||||
|
|
||||||
|
constant float &mul,
|
||||||
|
constant float &add
|
||||||
|
) {
|
||||||
|
|
||||||
|
constant size_t *dims = info;
|
||||||
|
constant size_t *strides = info + num_dims;
|
||||||
|
|
||||||
|
if (is_contiguous(num_dims, dims, strides)) {
|
||||||
|
for (size_t i = 0; i < dim; i++) {
|
||||||
|
float x = inp ? inp[i] : out[i];
|
||||||
|
out[i] = x * mul + add;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (size_t i = 0; i < dim; i++) {
|
||||||
|
uint strided_i = get_strided_index(i, num_dims, dims, strides);
|
||||||
|
float x = inp ? inp[strided_i] : out[strided_i];
|
||||||
|
out[strided_i] = x * mul + add;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -1,7 +1,9 @@
|
|||||||
use metal::{Buffer, CompileOptions, Device, Function, Library};
|
use metal::{Buffer, CompileOptions, Device, Function, Library};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use std::ffi::c_void;
|
||||||
use std::sync::RwLock;
|
use std::sync::RwLock;
|
||||||
|
|
||||||
|
pub const AFFINE: &str = include_str!("affine.metal");
|
||||||
pub const INDEXING: &str = include_str!("indexing.metal");
|
pub const INDEXING: &str = include_str!("indexing.metal");
|
||||||
pub const UNARY: &str = include_str!("unary.metal");
|
pub const UNARY: &str = include_str!("unary.metal");
|
||||||
|
|
||||||
@ -60,6 +62,10 @@ fn call_unary(_func: &Function, _input: &Buffer, _output: &Buffer, _length: usiz
|
|||||||
todo!("Call unary");
|
todo!("Call unary");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn void_ptr<T>(v: &T) -> *const c_void {
|
||||||
|
(v as *const T).cast()
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
@ -70,9 +76,6 @@ mod tests {
|
|||||||
use std::ffi::c_void;
|
use std::ffi::c_void;
|
||||||
use std::mem;
|
use std::mem;
|
||||||
|
|
||||||
pub fn void_ptr<T>(v: &T) -> *const c_void {
|
|
||||||
(v as *const T).cast()
|
|
||||||
}
|
|
||||||
fn approx(v: Vec<f32>, digits: i32) -> Vec<f32> {
|
fn approx(v: Vec<f32>, digits: i32) -> Vec<f32> {
|
||||||
let b = 10f32.powi(digits);
|
let b = 10f32.powi(digits);
|
||||||
v.iter().map(|t| f32::round(t * b) / b).collect()
|
v.iter().map(|t| f32::round(t * b) / b).collect()
|
||||||
@ -144,6 +147,72 @@ mod tests {
|
|||||||
assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]);
|
assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn affine() {
|
||||||
|
let device = Device::system_default().expect("no device found");
|
||||||
|
|
||||||
|
let options = CompileOptions::new();
|
||||||
|
let library = device.new_library_with_source(AFFINE, &options).unwrap();
|
||||||
|
|
||||||
|
let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
|
||||||
|
let output = [2.0f32, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
|
||||||
|
let dim: u32 = 8;
|
||||||
|
let num_dims: u32 = 4;
|
||||||
|
let info = [1u32, 2, 3];
|
||||||
|
let mul: f32 = 1.5;
|
||||||
|
let add: f32 = 1.1;
|
||||||
|
|
||||||
|
let function = library.get_function("affine", None).unwrap();
|
||||||
|
let pipeline = device
|
||||||
|
.new_compute_pipeline_state_with_function(&function)
|
||||||
|
.unwrap();
|
||||||
|
let options = MTLResourceOptions::StorageModeShared;
|
||||||
|
|
||||||
|
let command_queue = device.new_command_queue();
|
||||||
|
let command_buffer = command_queue.new_command_buffer();
|
||||||
|
let encoder = command_buffer.new_compute_command_encoder();
|
||||||
|
|
||||||
|
let input_size = (input.len() * mem::size_of::<f32>()) as NSUInteger;
|
||||||
|
let output_size = (output.len() * mem::size_of::<f32>()) as NSUInteger;
|
||||||
|
|
||||||
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
encoder.set_threadgroup_memory_length(0, output_size as NSUInteger);
|
||||||
|
|
||||||
|
let inputs_buffer = device.new_buffer_with_data(void_ptr(&input), input_size, options);
|
||||||
|
let outputs_buffer = device.new_buffer_with_data(void_ptr(&output), output_size, options);
|
||||||
|
|
||||||
|
encoder.set_bytes(0, 4, void_ptr(&dim));
|
||||||
|
encoder.set_bytes(1, 4, void_ptr(&num_dims));
|
||||||
|
encoder.set_bytes(2, 4, void_ptr(&info));
|
||||||
|
|
||||||
|
encoder.set_buffer(3, Some(&inputs_buffer), 0);
|
||||||
|
encoder.set_buffer(4, Some(&outputs_buffer), 0);
|
||||||
|
|
||||||
|
encoder.set_bytes(5, 4, void_ptr(&mul));
|
||||||
|
encoder.set_bytes(6, 4, void_ptr(&add));
|
||||||
|
|
||||||
|
let grid_size = MTLSize {
|
||||||
|
width: output.len() as NSUInteger,
|
||||||
|
height: 1,
|
||||||
|
depth: 1,
|
||||||
|
};
|
||||||
|
|
||||||
|
let thread_group_size = MTLSize {
|
||||||
|
width: pipeline.max_total_threads_per_threadgroup(),
|
||||||
|
height: 1,
|
||||||
|
depth: 1,
|
||||||
|
};
|
||||||
|
|
||||||
|
encoder.dispatch_threads(grid_size, thread_group_size);
|
||||||
|
encoder.end_encoding();
|
||||||
|
command_buffer.commit();
|
||||||
|
command_buffer.wait_until_completed();
|
||||||
|
|
||||||
|
let expected = vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1];
|
||||||
|
let result = outputs_buffer.read_to_vec::<f32>(output.len());
|
||||||
|
assert_eq!(result, expected);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn index_add() {
|
fn index_add() {
|
||||||
let device = Device::system_default().expect("no device found");
|
let device = Device::system_default().expect("no device found");
|
||||||
|
Reference in New Issue
Block a user