Debugging index_add.

This commit is contained in:
Ivar Flakstad
2023-11-03 12:08:58 +01:00
parent f57e3164ae
commit 0794e70a19
5 changed files with 183 additions and 28 deletions

View File

@ -9,7 +9,7 @@ use half::{bf16, f16};
use metal;
use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};
use metal::mps::{Float32, MPSDataType};
use metal::{MTLResourceOptions, Buffer};
use metal::{Buffer, MTLResourceOptions};
/// Metal related errors
#[derive(thiserror::Error, Debug)]
@ -48,12 +48,9 @@ impl MetalDevice {
self.registry_id()
}
fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer{
fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
let size = (element_count * dtype.size_in_bytes()) as u64;
self.device.new_buffer(
size,
MTLResourceOptions::empty(),
)
self.device.new_buffer(size, MTLResourceOptions::empty())
}
}
@ -80,9 +77,11 @@ impl BackendStorage for MetalStorage {
}
fn to_cpu_storage(&self) -> Result<CpuStorage> {
match self.dtype{
DType::F32 => Ok(CpuStorage::F32(self.buffer.read_to_vec(self.buffer.length() as usize / 4))),
dtype => todo!("Unsupported dtype {dtype:?}")
match self.dtype {
DType::F32 => Ok(CpuStorage::F32(
self.buffer.read_to_vec(self.buffer.length() as usize / 4),
)),
dtype => todo!("Unsupported dtype {dtype:?}"),
}
}
@ -123,7 +122,11 @@ impl BackendStorage for MetalStorage {
let mut buffer = device.new_buffer(el_count, dtype);
todo!("Implement the kernel calling");
// device.kernels.call_unary(U::KERNEL, &self.buffer, &mut buffer, el_count, dtype);
Ok(Self { buffer, device, dtype })
Ok(Self {
buffer,
device,
dtype,
})
}
fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
@ -295,7 +298,11 @@ impl MetalStorage {
});
}
if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() {
println!("Didn't implemented non contiguous matmul yet {:?} {:?}", lhs_l.is_contiguous(), rhs_l.is_contiguous());
println!(
"Didn't implemented non contiguous matmul yet {:?} {:?}",
lhs_l.is_contiguous(),
rhs_l.is_contiguous()
);
return Ok(Self {
buffer: out_buffer,
device: self.device.clone(),
@ -361,7 +368,6 @@ impl MetalStorage {
}
}
impl BackendDevice for MetalDevice {
type Storage = MetalStorage;
@ -446,13 +452,25 @@ impl BackendDevice for MetalDevice {
})
}
fn rand_uniform(&self, shape: &Shape, dtype: DType, mean: f64, stddev: f64) -> Result<Self::Storage> {
fn rand_uniform(
&self,
shape: &Shape,
dtype: DType,
mean: f64,
stddev: f64,
) -> Result<Self::Storage> {
// TODO is there a better way ?
let cpu_storage = crate::cpu_backend::CpuDevice.rand_uniform(shape, dtype, mean, stddev)?;
self.storage_from_cpu_storage(&cpu_storage)
}
fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, stddev: f64) -> Result<Self::Storage> {
fn rand_normal(
&self,
shape: &Shape,
dtype: DType,
mean: f64,
stddev: f64,
) -> Result<Self::Storage> {
// TODO is there a better way ?
let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?;
self.storage_from_cpu_storage(&cpu_storage)

View File

@ -349,12 +349,9 @@ impl crate::CustomOp1 for QTensor {
// )?;
let cpu_storage = crate::CpuStorage::F32(dst_storage);
use crate::backend::{BackendDevice, BackendStorage};
if let Device::Metal(device) = &self.device{
Ok((
device.storage_from_cpu_storage(&cpu_storage)?,
dst_shape,
))
}else{
if let Device::Metal(device) = &self.device {
Ok((device.storage_from_cpu_storage(&cpu_storage)?, dst_shape))
} else {
crate::bail!("qtensor not on metal device")
}
}

View File

@ -9,7 +9,7 @@ use std::io::Write;
use tokenizers::Tokenizer;
use candle::quantized::{ggml_file, gguf_file};
use candle::{Tensor};
use candle::Tensor;
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::quantized_llama as model;

View File

@ -0,0 +1,49 @@
#include <metal_stdlib>
#include <metal_config>
#define METAL_FUNC inline __attribute__((__always_inline__))
using namespace metal;
struct fault_counter {
uint counter;
uint tolerance;
fault_counter(uint tolerance) {
this->counter = 0;
this->tolerance = tolerance;
}
bool quit() {
counter += 1;
return (counter > tolerance);
}
};
constant uint IDS_DIM_SIZE [[function_constant(0)]];
constant uint SRC_DIM_SIZE [[function_constant(1)]];
constant uint DST_DIM_SIZE [[function_constant(2)]];
constant uint LEFT_SIZE [[function_constant(3)]];
constant uint RIGHT_SIZE [[function_constant(4)]];
constant uint NUMEL = LEFT_SIZE * RIGHT_SIZE;
kernel void index_add(
device uint *ids [[buffer(0)]],
device float *inp [[buffer(1)]],
device float *out [[buffer(2)]],
uint grid_size [[threadgroups_per_grid]], // gridDim
uint gid [[thread_position_in_grid]], // blockIdx
uint num_threads [[threads_per_grid]], // blockDim
uint thread_index [[thread_index_in_threadgroup]] // threadIdx
) {
for (uint i = gid * num_threads + thread_index; i < NUMEL; i += num_threads * grid_size) {
const uint pre = i / RIGHT_SIZE;
const uint post = i % RIGHT_SIZE;
for (uint j = 0; j < IDS_DIM_SIZE; j++) {
const uint idx = ids[j];
const uint src_i = (pre * IDS_DIM_SIZE + j) * RIGHT_SIZE + post;
const uint dst_i = (pre * DST_DIM_SIZE + idx) * RIGHT_SIZE + post;
out[dst_i] += inp[src_i];
}
}
}

View File

@ -1,8 +1,9 @@
use metal::{Buffer, Device, Function, Library, CompileOptions};
use metal::{Buffer, CompileOptions, Device, Function, Library, NSUInteger};
use std::collections::HashMap;
use std::sync::RwLock;
static UNARY: &'static str = include_str!("unary.metal");
pub const INDEXING: &str = include_str!("indexing.metal");
pub const UNARY: &str = include_str!("unary.metal");
pub enum Error {}
@ -63,10 +64,16 @@ fn call_unary(func: &Function, input: &Buffer, output: &Buffer, length: usize) {
mod tests {
use super::*;
use metal::{
ComputePipelineDescriptor, MTLResourceOptions, MTLResourceUsage, MTLSize,
CompileOptions, ComputePipelineDescriptor, Device, FunctionConstantValues, MTLDataType,
MTLResourceOptions, MTLResourceUsage, MTLSize, NSUInteger,
};
use std::ffi::c_void;
use std::mem;
fn approx(v: Vec<f32>, digits: i32) -> Vec<f32>{
pub fn void_ptr<T>(v: &T) -> *const c_void {
(v as *const T).cast()
}
fn approx(v: Vec<f32>, digits: i32) -> Vec<f32> {
let b = 10f32.powi(digits);
v.iter().map(|t| f32::round(t * b) / b).collect()
}
@ -80,11 +87,11 @@ mod tests {
let command_buffer = command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder();
let input = device.new_buffer_with_data(
v.as_ptr() as *const core::ffi::c_void,
(v.len() * core::mem::size_of::<f32>()) as u64,
v.as_ptr() as *const c_void,
(v.len() * mem::size_of::<f32>()) as u64,
option,
);
let output = device.new_buffer((v.len() * core::mem::size_of::<f32>()) as u64, option);
let output = device.new_buffer((v.len() * mem::size_of::<f32>()) as u64, option);
let library = device
.new_library_with_source(UNARY, &CompileOptions::new())
.expect("Failed to load unary library");
@ -130,9 +137,93 @@ mod tests {
encoder.end_encoding();
command_buffer.commit();
command_buffer.wait_until_completed();
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
let results = output.read_to_vec::<f32>(v.len());
assert_eq!(approx(results, 4), vec![0.5403, -0.4161, -0.99]);
assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]);
}
#[test]
fn index_add() {
let device = Device::system_default().expect("no device found");
let options = CompileOptions::new();
let library = device.new_library_with_source(INDEXING, &options).unwrap();
let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
let right = [1.0f32; 15];
let index = [0u32, 4, 2];
let ids_dim_size = index.len() as u32;
let src_dim_size = 2u32;
let dst_dim_size = 2u32;
let left_size = left.len() as u32;
let right_size = right.len() as u32;
let numel = left_size * right_size;
let fcv = FunctionConstantValues::new();
fcv.set_constant_value_at_index(void_ptr(&ids_dim_size), MTLDataType::UInt, 0);
fcv.set_constant_value_at_index(void_ptr(&src_dim_size), MTLDataType::UInt, 1);
fcv.set_constant_value_at_index(void_ptr(&dst_dim_size), MTLDataType::UInt, 2);
fcv.set_constant_value_at_index(void_ptr(&left_size), MTLDataType::UInt, 3);
fcv.set_constant_value_at_index(void_ptr(&right_size), MTLDataType::UInt, 4);
let function = library.get_function("index_add", Some(fcv)).unwrap();
let pipeline = device
.new_compute_pipeline_state_with_function(&function)
.unwrap();
let options = MTLResourceOptions::StorageModeShared;
let ids_size = (index.len() * mem::size_of::<u32>()) as NSUInteger;
let input_size = (left.len() * mem::size_of::<f32>()) as NSUInteger;
let output_size = (right.len() * mem::size_of::<f32>()) as NSUInteger;
let ids = device.new_buffer_with_data(void_ptr(&index), ids_size, options);
let inputs = device.new_buffer_with_data(void_ptr(&left), input_size, options);
let outputs = device.new_buffer_with_data(void_ptr(&right), output_size, options);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
let thread_group_memory_length = output_size;
encoder.set_threadgroup_memory_length(0, thread_group_memory_length as NSUInteger);
encoder.use_resource(&ids, MTLResourceUsage::Read);
encoder.use_resource(&inputs, MTLResourceUsage::Read);
encoder.use_resource(&outputs, MTLResourceUsage::Write);
encoder.set_buffer(0, Some(&ids), 0);
encoder.set_buffer(1, Some(&inputs), 0);
encoder.set_buffer(2, Some(&outputs), 0);
let width = 16;
let thread_group_count = MTLSize {
width,
height: 1,
depth: 1,
};
let thread_group_size = MTLSize {
width: (numel as NSUInteger + width) / width,
height: 1,
depth: 1,
};
println!("{:?}", thread_group_count);
println!("{:?}", thread_group_size);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
command_buffer.commit();
command_buffer.wait_until_completed();
let expected = vec![
2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
];
let result = outputs.read_to_vec::<f32>(right.len());
println!("{:?}", result);
assert_eq!(result, expected);
}
}