Debugging index_add.

This commit is contained in:
Ivar Flakstad
2023-11-03 12:08:58 +01:00
parent f57e3164ae
commit 0794e70a19
5 changed files with 183 additions and 28 deletions

View File

@ -9,7 +9,7 @@ use half::{bf16, f16};
use metal; use metal;
use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication}; use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};
use metal::mps::{Float32, MPSDataType}; use metal::mps::{Float32, MPSDataType};
use metal::{MTLResourceOptions, Buffer}; use metal::{Buffer, MTLResourceOptions};
/// Metal related errors /// Metal related errors
#[derive(thiserror::Error, Debug)] #[derive(thiserror::Error, Debug)]
@ -48,12 +48,9 @@ impl MetalDevice {
self.registry_id() self.registry_id()
} }
fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer{ fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
let size = (element_count * dtype.size_in_bytes()) as u64; let size = (element_count * dtype.size_in_bytes()) as u64;
self.device.new_buffer( self.device.new_buffer(size, MTLResourceOptions::empty())
size,
MTLResourceOptions::empty(),
)
} }
} }
@ -80,9 +77,11 @@ impl BackendStorage for MetalStorage {
} }
fn to_cpu_storage(&self) -> Result<CpuStorage> { fn to_cpu_storage(&self) -> Result<CpuStorage> {
match self.dtype{ match self.dtype {
DType::F32 => Ok(CpuStorage::F32(self.buffer.read_to_vec(self.buffer.length() as usize / 4))), DType::F32 => Ok(CpuStorage::F32(
dtype => todo!("Unsupported dtype {dtype:?}") self.buffer.read_to_vec(self.buffer.length() as usize / 4),
)),
dtype => todo!("Unsupported dtype {dtype:?}"),
} }
} }
@ -123,7 +122,11 @@ impl BackendStorage for MetalStorage {
let mut buffer = device.new_buffer(el_count, dtype); let mut buffer = device.new_buffer(el_count, dtype);
todo!("Implement the kernel calling"); todo!("Implement the kernel calling");
// device.kernels.call_unary(U::KERNEL, &self.buffer, &mut buffer, el_count, dtype); // device.kernels.call_unary(U::KERNEL, &self.buffer, &mut buffer, el_count, dtype);
Ok(Self { buffer, device, dtype }) Ok(Self {
buffer,
device,
dtype,
})
} }
fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> { fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
@ -295,7 +298,11 @@ impl MetalStorage {
}); });
} }
if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() { if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() {
println!("Didn't implemented non contiguous matmul yet {:?} {:?}", lhs_l.is_contiguous(), rhs_l.is_contiguous()); println!(
"Didn't implemented non contiguous matmul yet {:?} {:?}",
lhs_l.is_contiguous(),
rhs_l.is_contiguous()
);
return Ok(Self { return Ok(Self {
buffer: out_buffer, buffer: out_buffer,
device: self.device.clone(), device: self.device.clone(),
@ -361,7 +368,6 @@ impl MetalStorage {
} }
} }
impl BackendDevice for MetalDevice { impl BackendDevice for MetalDevice {
type Storage = MetalStorage; type Storage = MetalStorage;
@ -446,13 +452,25 @@ impl BackendDevice for MetalDevice {
}) })
} }
fn rand_uniform(&self, shape: &Shape, dtype: DType, mean: f64, stddev: f64) -> Result<Self::Storage> { fn rand_uniform(
&self,
shape: &Shape,
dtype: DType,
mean: f64,
stddev: f64,
) -> Result<Self::Storage> {
// TODO is there a better way ? // TODO is there a better way ?
let cpu_storage = crate::cpu_backend::CpuDevice.rand_uniform(shape, dtype, mean, stddev)?; let cpu_storage = crate::cpu_backend::CpuDevice.rand_uniform(shape, dtype, mean, stddev)?;
self.storage_from_cpu_storage(&cpu_storage) self.storage_from_cpu_storage(&cpu_storage)
} }
fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, stddev: f64) -> Result<Self::Storage> { fn rand_normal(
&self,
shape: &Shape,
dtype: DType,
mean: f64,
stddev: f64,
) -> Result<Self::Storage> {
// TODO is there a better way ? // TODO is there a better way ?
let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?; let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?;
self.storage_from_cpu_storage(&cpu_storage) self.storage_from_cpu_storage(&cpu_storage)

View File

@ -349,12 +349,9 @@ impl crate::CustomOp1 for QTensor {
// )?; // )?;
let cpu_storage = crate::CpuStorage::F32(dst_storage); let cpu_storage = crate::CpuStorage::F32(dst_storage);
use crate::backend::{BackendDevice, BackendStorage}; use crate::backend::{BackendDevice, BackendStorage};
if let Device::Metal(device) = &self.device{ if let Device::Metal(device) = &self.device {
Ok(( Ok((device.storage_from_cpu_storage(&cpu_storage)?, dst_shape))
device.storage_from_cpu_storage(&cpu_storage)?, } else {
dst_shape,
))
}else{
crate::bail!("qtensor not on metal device") crate::bail!("qtensor not on metal device")
} }
} }

View File

@ -9,7 +9,7 @@ use std::io::Write;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
use candle::quantized::{ggml_file, gguf_file}; use candle::quantized::{ggml_file, gguf_file};
use candle::{Tensor}; use candle::Tensor;
use candle_transformers::generation::LogitsProcessor; use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::quantized_llama as model; use candle_transformers::models::quantized_llama as model;

View File

@ -0,0 +1,49 @@
#include <metal_stdlib>
#include <metal_config>
#define METAL_FUNC inline __attribute__((__always_inline__))
using namespace metal;
struct fault_counter {
uint counter;
uint tolerance;
fault_counter(uint tolerance) {
this->counter = 0;
this->tolerance = tolerance;
}
bool quit() {
counter += 1;
return (counter > tolerance);
}
};
constant uint IDS_DIM_SIZE [[function_constant(0)]];
constant uint SRC_DIM_SIZE [[function_constant(1)]];
constant uint DST_DIM_SIZE [[function_constant(2)]];
constant uint LEFT_SIZE [[function_constant(3)]];
constant uint RIGHT_SIZE [[function_constant(4)]];
constant uint NUMEL = LEFT_SIZE * RIGHT_SIZE;
kernel void index_add(
device uint *ids [[buffer(0)]],
device float *inp [[buffer(1)]],
device float *out [[buffer(2)]],
uint grid_size [[threadgroups_per_grid]], // gridDim
uint gid [[thread_position_in_grid]], // blockIdx
uint num_threads [[threads_per_grid]], // blockDim
uint thread_index [[thread_index_in_threadgroup]] // threadIdx
) {
for (uint i = gid * num_threads + thread_index; i < NUMEL; i += num_threads * grid_size) {
const uint pre = i / RIGHT_SIZE;
const uint post = i % RIGHT_SIZE;
for (uint j = 0; j < IDS_DIM_SIZE; j++) {
const uint idx = ids[j];
const uint src_i = (pre * IDS_DIM_SIZE + j) * RIGHT_SIZE + post;
const uint dst_i = (pre * DST_DIM_SIZE + idx) * RIGHT_SIZE + post;
out[dst_i] += inp[src_i];
}
}
}

View File

@ -1,8 +1,9 @@
use metal::{Buffer, Device, Function, Library, CompileOptions}; use metal::{Buffer, CompileOptions, Device, Function, Library, NSUInteger};
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::RwLock; use std::sync::RwLock;
static UNARY: &'static str = include_str!("unary.metal"); pub const INDEXING: &str = include_str!("indexing.metal");
pub const UNARY: &str = include_str!("unary.metal");
pub enum Error {} pub enum Error {}
@ -63,10 +64,16 @@ fn call_unary(func: &Function, input: &Buffer, output: &Buffer, length: usize) {
mod tests { mod tests {
use super::*; use super::*;
use metal::{ use metal::{
ComputePipelineDescriptor, MTLResourceOptions, MTLResourceUsage, MTLSize, CompileOptions, ComputePipelineDescriptor, Device, FunctionConstantValues, MTLDataType,
MTLResourceOptions, MTLResourceUsage, MTLSize, NSUInteger,
}; };
use std::ffi::c_void;
use std::mem;
fn approx(v: Vec<f32>, digits: i32) -> Vec<f32>{ pub fn void_ptr<T>(v: &T) -> *const c_void {
(v as *const T).cast()
}
fn approx(v: Vec<f32>, digits: i32) -> Vec<f32> {
let b = 10f32.powi(digits); let b = 10f32.powi(digits);
v.iter().map(|t| f32::round(t * b) / b).collect() v.iter().map(|t| f32::round(t * b) / b).collect()
} }
@ -80,11 +87,11 @@ mod tests {
let command_buffer = command_queue.new_command_buffer(); let command_buffer = command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
let input = device.new_buffer_with_data( let input = device.new_buffer_with_data(
v.as_ptr() as *const core::ffi::c_void, v.as_ptr() as *const c_void,
(v.len() * core::mem::size_of::<f32>()) as u64, (v.len() * mem::size_of::<f32>()) as u64,
option, option,
); );
let output = device.new_buffer((v.len() * core::mem::size_of::<f32>()) as u64, option); let output = device.new_buffer((v.len() * mem::size_of::<f32>()) as u64, option);
let library = device let library = device
.new_library_with_source(UNARY, &CompileOptions::new()) .new_library_with_source(UNARY, &CompileOptions::new())
.expect("Failed to load unary library"); .expect("Failed to load unary library");
@ -130,9 +137,93 @@ mod tests {
encoder.end_encoding(); encoder.end_encoding();
command_buffer.commit(); command_buffer.commit();
command_buffer.wait_until_completed(); command_buffer.wait_until_completed();
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect(); let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
let results = output.read_to_vec::<f32>(v.len()); let results = output.read_to_vec::<f32>(v.len());
assert_eq!(approx(results, 4), vec![0.5403, -0.4161, -0.99]); assert_eq!(approx(results, 4), vec![0.5403, -0.4161, -0.99]);
assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]); assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]);
} }
#[test]
fn index_add() {
let device = Device::system_default().expect("no device found");
let options = CompileOptions::new();
let library = device.new_library_with_source(INDEXING, &options).unwrap();
let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
let right = [1.0f32; 15];
let index = [0u32, 4, 2];
let ids_dim_size = index.len() as u32;
let src_dim_size = 2u32;
let dst_dim_size = 2u32;
let left_size = left.len() as u32;
let right_size = right.len() as u32;
let numel = left_size * right_size;
let fcv = FunctionConstantValues::new();
fcv.set_constant_value_at_index(void_ptr(&ids_dim_size), MTLDataType::UInt, 0);
fcv.set_constant_value_at_index(void_ptr(&src_dim_size), MTLDataType::UInt, 1);
fcv.set_constant_value_at_index(void_ptr(&dst_dim_size), MTLDataType::UInt, 2);
fcv.set_constant_value_at_index(void_ptr(&left_size), MTLDataType::UInt, 3);
fcv.set_constant_value_at_index(void_ptr(&right_size), MTLDataType::UInt, 4);
let function = library.get_function("index_add", Some(fcv)).unwrap();
let pipeline = device
.new_compute_pipeline_state_with_function(&function)
.unwrap();
let options = MTLResourceOptions::StorageModeShared;
let ids_size = (index.len() * mem::size_of::<u32>()) as NSUInteger;
let input_size = (left.len() * mem::size_of::<f32>()) as NSUInteger;
let output_size = (right.len() * mem::size_of::<f32>()) as NSUInteger;
let ids = device.new_buffer_with_data(void_ptr(&index), ids_size, options);
let inputs = device.new_buffer_with_data(void_ptr(&left), input_size, options);
let outputs = device.new_buffer_with_data(void_ptr(&right), output_size, options);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
let thread_group_memory_length = output_size;
encoder.set_threadgroup_memory_length(0, thread_group_memory_length as NSUInteger);
encoder.use_resource(&ids, MTLResourceUsage::Read);
encoder.use_resource(&inputs, MTLResourceUsage::Read);
encoder.use_resource(&outputs, MTLResourceUsage::Write);
encoder.set_buffer(0, Some(&ids), 0);
encoder.set_buffer(1, Some(&inputs), 0);
encoder.set_buffer(2, Some(&outputs), 0);
let width = 16;
let thread_group_count = MTLSize {
width,
height: 1,
depth: 1,
};
let thread_group_size = MTLSize {
width: (numel as NSUInteger + width) / width,
height: 1,
depth: 1,
};
println!("{:?}", thread_group_count);
println!("{:?}", thread_group_size);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
command_buffer.commit();
command_buffer.wait_until_completed();
let expected = vec![
2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
];
let result = outputs.read_to_vec::<f32>(right.len());
println!("{:?}", result);
assert_eq!(result, expected);
}
} }