Files
candle/candle-examples/examples/parler-tts/decode.py
Laurent Mazare 58197e1896 parler-tts support (#2431)
* Start sketching parler-tts support.

* Implement the attention.

* Add the example code.

* Fix the example.

* Add the description + t5 encode it.

* More of the parler forward pass.

* Fix the positional embeddings.

* Support random sampling in generation.

* Handle EOS.

* Add the python decoder.

* Proper causality mask.
2024-08-18 20:42:08 +02:00

30 lines
1.1 KiB
Python

import torch
import torchaudio
from safetensors.torch import load_file
from parler_tts import DACModel

# Decode parler-tts audio codes (saved by the generation example into
# out.safetensors) back into a waveform with the DAC codec model, and
# write the result to out.wav.
tensors = load_file("out.safetensors")
dac_model = DACModel.from_pretrained("parler-tts/dac_44khZ_8kbps")
# Prepend batch and sample dims — presumably codes is (codebooks, seq),
# giving (1, 1, codebooks, seq); verify against the generator's output.
output_ids = tensors["codes"][None, None]
print(output_ids, "\n", output_ids.shape)
batch_size = 1
with torch.no_grad():
    output_values = []
    for sample_id in range(batch_size):
        sample = output_ids[:, sample_id]
        # A time step is valid only when every codebook entry is a real
        # code id (i.e. below codebook_size, so no pad/EOS tokens).
        sample_mask = (sample >= dac_model.config.codebook_size).sum(dim=(0, 1)) == 0
        if sample_mask.sum() > 0:
            sample = sample[:, :, sample_mask]
            sample = dac_model.decode(sample[None, ...], [None]).audio_values
            output_values.append(sample.transpose(0, 2))
        else:
            # No valid frame at all: emit a single silent sample so the
            # pad_sequence below still has something to stack.
            output_values.append(torch.zeros((1, 1, 1)).to(dac_model.device))
# Pad the per-sample waveforms to a common length, then drop the two
# trailing singleton dims to get (batch, time) PCM.
pcm = (
    torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0)
    .squeeze(-1)
    .squeeze(-1)
)
print(pcm.shape, pcm.dtype)
torchaudio.save("out.wav", pcm.cpu(), sample_rate=44100)