Metal commands refactoring (#2489)

* Split out the commands part of the metal device.

* Make most fields private.

* Move the allocator back.

* Rework the encoder provider type.
This commit is contained in:
Laurent Mazare
2024-09-21 13:18:42 +02:00
committed by GitHub
parent 5fc4f17727
commit af2104078f
3 changed files with 145 additions and 108 deletions

View File

@ -1864,33 +1864,22 @@ impl BackendDevice for MetalDevice {
fn new(ordinal: usize) -> Result<Self> {
let device = metal::Device::all().swap_remove(ordinal);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer().to_owned();
command_buffer.enqueue();
let command_buffer = Arc::new(RwLock::new(command_buffer));
let command_buffer_index = Arc::new(RwLock::new(0));
let kernels = Arc::new(Kernels::new());
let buffers = Arc::new(RwLock::new(HashMap::new()));
let use_mlx_mm = match std::env::var("CANDLE_USE_MLX_MM").as_deref() {
Ok("false") | Ok("False") | Ok("FALSE") | Ok("0") | Err(_) => false,
Ok(_) => true,
};
let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
Ok(val) => val.parse()?,
_ => 50,
};
let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
[299792458].as_ptr() as *const c_void,
4,
MTLResourceOptions::StorageModeManaged,
)));
let commands = device::Commands::new(command_queue)?;
Ok(Self {
id: DeviceId::new(),
device,
command_queue,
command_buffer,
command_buffer_index,
compute_per_buffer,
buffers,
commands: Arc::new(RwLock::new(commands)),
buffers: Arc::new(RwLock::new(HashMap::new())),
kernels,
seed,
use_mlx_mm,