Add support for conv_transpose1d for metal backend (#1874)

* first attempt

* progress

* integrate into metal backend

* finish and get test passing

* add other dtype support

* update transpose1d dtypes supported
This commit is contained in:
Thomas Santerre
2024-03-19 03:46:58 -04:00
committed by GitHub
parent 143c481c20
commit 2a8679509e
5 changed files with 394 additions and 10 deletions

View File

@ -54,11 +54,6 @@ fn conv1d(dev: &Device) -> Result<()> {
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
);
// conv-transposes are not implemented for metal.
if dev.is_metal() {
return Ok(());
}
let w = w.transpose(0, 1)?;
// The CPU kernels applied in the contiguous and non contiguous cases are different.
for w in [w.clone(), w.contiguous()?] {