{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tacotron 2 inference code\n",
    "Edit the variables **checkpoint_path**, **waveglow_path** and **text** to match your setup, then run all cells to plot the mel outputs and alignments and to synthesize audio from the generated mel-spectrogram using WaveGlow.\n",
    "\n",
    "Requirements:\n",
    "- a CUDA-capable GPU (the input sequence and WaveGlow are moved to the GPU with `.cuda()`);\n",
    "- run the notebook from the repository root: the project modules are imported directly and `waveglow/` is added to `sys.path` as a relative path."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Import libraries and set up matplotlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "import IPython.display as ipd\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "# Make modules in the (relative) waveglow/ directory importable.\n",
    "sys.path.append('waveglow/')\n",
    "\n",
    "from hparams import create_hparams\n",
    "from model import Tacotron2\n",
    "from layers import TacotronSTFT\n",
    "from audio_processing import griffin_lim\n",
    "from train import load_model\n",
    "from text import text_to_sequence\n",
    "\n",
    "# Enable the inline backend only *after* the project imports: importing them\n",
    "# runs matplotlib.use('Agg') in plotting_utils.py, which can otherwise\n",
    "# replace the inline backend so that figures are not shown.\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_data(data, figsize=(16, 4), titles=None):\n",
    "    \"\"\"Plot 2-D arrays side by side as images, one panel per array.\n",
    "\n",
    "    Args:\n",
    "        data: sequence of 2-D arrays (e.g. mel-spectrograms, alignments).\n",
    "        figsize: figure size in inches, passed to plt.subplots.\n",
    "        titles: optional sequence of panel titles, one per array in data.\n",
    "\n",
    "    Returns:\n",
    "        The matplotlib Figure holding the panels.\n",
    "    \"\"\"\n",
    "    # squeeze=False keeps `axes` 2-D even when there is a single panel.\n",
    "    fig, axes = plt.subplots(1, len(data), figsize=figsize, squeeze=False)\n",
    "    if titles is None:\n",
    "        titles = [None] * len(data)\n",
    "    for ax, panel, title in zip(axes[0], data, titles):\n",
    "        # 'lower' draws row 0 at the bottom; the old value 'bottom' is\n",
    "        # rejected by recent matplotlib releases.\n",
    "        ax.imshow(panel, aspect='auto', origin='lower', interpolation='none')\n",
    "        if title:\n",
    "            ax.set_title(title)\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Set up hparams"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hparams = create_hparams(\"distributed_run=False,mask_padding=False\")\n",
    "# Audio / STFT settings; presumably these must match the values the\n",
    "# checkpoint was trained with.\n",
    "hparams.sampling_rate = 22050\n",
    "hparams.filter_length = 1024\n",
    "hparams.hop_length = 256\n",
    "hparams.win_length = 1024"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load model from checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "checkpoint_path = \"tacotron2_statedict\"\n",
    "\n",
    "model = load_model(hparams)\n",
    "# Unwrap a (Distributed)DataParallel-style wrapper if one is present.\n",
    "model = getattr(model, 'module', model)\n",
    "\n",
    "# NOTE: torch.load unpickles the file; only load checkpoints you trust.\n",
    "# map_location='cpu' avoids depending on the GPU the checkpoint was saved\n",
    "# from; load_state_dict copies the weights onto the model's own device.\n",
    "state_dict = torch.load(checkpoint_path, map_location='cpu')['state_dict']\n",
    "\n",
    "# Strip only a leading 'module.' prefix (added when a wrapped model is saved).\n",
    "prefix = 'module.'\n",
    "state_dict = {(k[len(prefix):] if k.startswith(prefix) else k): v\n",
    "              for k, v in state_dict.items()}\n",
    "model.load_state_dict(state_dict)\n",
    "_ = model.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load WaveGlow for mel2audio synthesis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "waveglow_path = 'waveglow_old.pt'\n",
    "# NOTE: this unpickles a full model object; only load files you trust.\n",
    "waveglow = torch.load(waveglow_path)['model']\n",
    "# Inference only: keep the vocoder on the GPU (where the mels will be) and\n",
    "# in eval mode.\n",
    "waveglow = waveglow.cuda().eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Prepare text input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "text = \"Waveglow is really awesome!\"\n",
    "# [None, :] adds a batch dimension: shape (1, n_symbols).\n",
    "sequence = np.array(text_to_sequence(text, ['english_cleaners']))[None, :]\n",
    "sequence = torch.from_numpy(sequence).cuda().long()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Decode text input and plot results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Inference only: no autograd graph is needed.\n",
    "with torch.no_grad():\n",
    "    mel_outputs, mel_outputs_postnet, _, alignments = model.inference(sequence)\n",
    "\n",
    "plot_data((mel_outputs.cpu().numpy()[0],\n",
    "           mel_outputs_postnet.cpu().numpy()[0],\n",
    "           alignments.cpu().numpy()[0].T),\n",
    "          titles=('Mel output', 'Mel output (postnet)', 'Alignment'));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Synthesize audio from spectrogram using WaveGlow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    # sigma presumably scales the latent noise WaveGlow samples from.\n",
    "    audio = waveglow.infer(mel_outputs_postnet, sigma=0.666)\n",
    "ipd.Audio(audio[0].cpu().numpy(), rate=hparams.sampling_rate)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}