mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
DiffNeXt/unet (#859)
* DiffNeXt/unet * Start adding the vae. * VAE residual block. * VAE forward pass. * Add pixel shuffling. * Actually use pixel shuffling.
This commit is contained in:
@ -444,6 +444,18 @@ impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim> Dims for (D1, D2, D3, D4, D5)
|
||||
}
|
||||
}
|
||||
|
||||
impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim, D5: Dim, D6: Dim> Dims for (D1, D2, D3, D4, D5, D6) {
|
||||
fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
|
||||
let d0 = self.0.to_index(shape, op)?;
|
||||
let d1 = self.1.to_index(shape, op)?;
|
||||
let d2 = self.2.to_index(shape, op)?;
|
||||
let d3 = self.3.to_index(shape, op)?;
|
||||
let d4 = self.4.to_index(shape, op)?;
|
||||
let d5 = self.5.to_index(shape, op)?;
|
||||
Ok(vec![d0, d1, d2, d3, d4, d5])
|
||||
}
|
||||
}
|
||||
|
||||
extract_dims!(dims0, 0, |_: &[usize]| (), ());
|
||||
extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
|
||||
extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
|
||||
|
Reference in New Issue
Block a user