You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

541 lines
20 KiB

  1. import torch
  2. from torch.autograd import Variable
  3. from torch import nn
  4. from torch.nn import functional as F
  5. from layers import ConvNorm, LinearNorm
  6. from utils import to_gpu, get_mask_from_lengths
  7. from fp16_optimizer import fp32_to_fp16, fp16_to_fp32
  8. class LocationLayer(nn.Module):
  9. def __init__(self, attention_n_filters, attention_kernel_size,
  10. attention_dim):
  11. super(LocationLayer, self).__init__()
  12. padding = int((attention_kernel_size - 1) / 2)
  13. self.location_conv = ConvNorm(2, attention_n_filters,
  14. kernel_size=attention_kernel_size,
  15. padding=padding, bias=False, stride=1,
  16. dilation=1)
  17. self.location_dense = LinearNorm(attention_n_filters, attention_dim,
  18. bias=False, w_init_gain='tanh')
  19. def forward(self, attention_weights_cat):
  20. processed_attention = self.location_conv(attention_weights_cat)
  21. processed_attention = processed_attention.transpose(1, 2)
  22. processed_attention = self.location_dense(processed_attention)
  23. return processed_attention
  24. class Attention(nn.Module):
  25. def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
  26. attention_location_n_filters, attention_location_kernel_size):
  27. super(Attention, self).__init__()
  28. self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
  29. bias=False, w_init_gain='tanh')
  30. self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
  31. w_init_gain='tanh')
  32. self.v = LinearNorm(attention_dim, 1, bias=False)
  33. self.location_layer = LocationLayer(attention_location_n_filters,
  34. attention_location_kernel_size,
  35. attention_dim)
  36. self.score_mask_value = -float("inf")
  37. def get_alignment_energies(self, query, processed_memory,
  38. attention_weights_cat):
  39. """
  40. PARAMS
  41. ------
  42. query: decoder output (batch, n_mel_channels * n_frames_per_step)
  43. processed_memory: processed encoder outputs (B, T_in, attention_dim)
  44. attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)
  45. RETURNS
  46. -------
  47. alignment (batch, max_time)
  48. """
  49. processed_query = self.query_layer(query.unsqueeze(1))
  50. processed_attention_weights = self.location_layer(attention_weights_cat)
  51. energies = self.v(F.tanh(
  52. processed_query + processed_attention_weights + processed_memory))
  53. energies = energies.squeeze(-1)
  54. return energies
  55. def forward(self, attention_hidden_state, memory, processed_memory,
  56. attention_weights_cat, mask):
  57. """
  58. PARAMS
  59. ------
  60. attention_hidden_state: attention rnn last output
  61. memory: encoder outputs
  62. processed_memory: processed encoder outputs
  63. attention_weights_cat: previous and cummulative attention weights
  64. mask: binary mask for padded data
  65. """
  66. alignment = self.get_alignment_energies(
  67. attention_hidden_state, processed_memory, attention_weights_cat)
  68. if mask is not None:
  69. alignment.data.masked_fill_(mask, self.score_mask_value)
  70. attention_weights = F.softmax(alignment, dim=1)
  71. attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
  72. attention_context = attention_context.squeeze(1)
  73. return attention_context, attention_weights
  74. class Prenet(nn.Module):
  75. def __init__(self, in_dim, sizes):
  76. super(Prenet, self).__init__()
  77. in_sizes = [in_dim] + sizes[:-1]
  78. self.layers = nn.ModuleList(
  79. [LinearNorm(in_size, out_size, bias=False)
  80. for (in_size, out_size) in zip(in_sizes, sizes)])
  81. def forward(self, x):
  82. for linear in self.layers:
  83. x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
  84. return x
  85. class Postnet(nn.Module):
  86. """Postnet
  87. - Five 1-d convolution with 512 channels and kernel size 5
  88. """
  89. def __init__(self, hparams):
  90. super(Postnet, self).__init__()
  91. self.dropout = nn.Dropout(0.5)
  92. self.convolutions = nn.ModuleList()
  93. self.convolutions.append(
  94. nn.Sequential(
  95. ConvNorm(hparams.n_mel_channels, hparams.postnet_embedding_dim,
  96. kernel_size=hparams.postnet_kernel_size, stride=1,
  97. padding=int((hparams.postnet_kernel_size - 1) / 2),
  98. dilation=1, w_init_gain='tanh'),
  99. nn.BatchNorm1d(hparams.postnet_embedding_dim))
  100. )
  101. for i in range(1, hparams.postnet_n_convolutions - 1):
  102. self.convolutions.append(
  103. nn.Sequential(
  104. ConvNorm(hparams.postnet_embedding_dim,
  105. hparams.postnet_embedding_dim,
  106. kernel_size=hparams.postnet_kernel_size, stride=1,
  107. padding=int((hparams.postnet_kernel_size - 1) / 2),
  108. dilation=1, w_init_gain='tanh'),
  109. nn.BatchNorm1d(hparams.postnet_embedding_dim))
  110. )
  111. self.convolutions.append(
  112. nn.Sequential(
  113. ConvNorm(hparams.postnet_embedding_dim, hparams.n_mel_channels,
  114. kernel_size=hparams.postnet_kernel_size, stride=1,
  115. padding=int((hparams.postnet_kernel_size - 1) / 2),
  116. dilation=1, w_init_gain='linear'),
  117. nn.BatchNorm1d(hparams.n_mel_channels))
  118. )
  119. def forward(self, x):
  120. for i in range(len(self.convolutions) - 1):
  121. x = self.dropout(F.tanh(self.convolutions[i](x)))
  122. x = self.dropout(self.convolutions[-1](x))
  123. return x
  124. class Encoder(nn.Module):
  125. """Encoder module:
  126. - Three 1-d convolution banks
  127. - Bidirectional LSTM
  128. """
  129. def __init__(self, hparams):
  130. super(Encoder, self).__init__()
  131. self.dropout = nn.Dropout(0.5)
  132. convolutions = []
  133. for _ in range(hparams.encoder_n_convolutions):
  134. conv_layer = nn.Sequential(
  135. ConvNorm(hparams.encoder_embedding_dim,
  136. hparams.encoder_embedding_dim,
  137. kernel_size=hparams.encoder_kernel_size, stride=1,
  138. padding=int((hparams.encoder_kernel_size - 1) / 2),
  139. dilation=1, w_init_gain='relu'),
  140. nn.BatchNorm1d(hparams.encoder_embedding_dim))
  141. convolutions.append(conv_layer)
  142. self.convolutions = nn.ModuleList(convolutions)
  143. self.lstm = nn.LSTM(hparams.encoder_embedding_dim,
  144. int(hparams.encoder_embedding_dim / 2), 1,
  145. batch_first=True, bidirectional=True)
  146. def forward(self, x, input_lengths):
  147. for conv in self.convolutions:
  148. x = self.dropout(F.relu(conv(x)))
  149. x = x.transpose(1, 2)
  150. # pytorch tensor are not reversible, hence the conversion
  151. input_lengths = input_lengths.cpu().numpy()
  152. x = nn.utils.rnn.pack_padded_sequence(
  153. x, input_lengths, batch_first=True)
  154. self.lstm.flatten_parameters()
  155. outputs, _ = self.lstm(x)
  156. outputs, _ = nn.utils.rnn.pad_packed_sequence(
  157. outputs, batch_first=True)
  158. return outputs
  159. def inference(self, x):
  160. for conv in self.convolutions:
  161. x = self.dropout(F.relu(conv(x)))
  162. x = x.transpose(1, 2)
  163. self.lstm.flatten_parameters()
  164. outputs, _ = self.lstm(x)
  165. return outputs
  166. class Decoder(nn.Module):
  167. def __init__(self, hparams):
  168. super(Decoder, self).__init__()
  169. self.n_mel_channels = hparams.n_mel_channels
  170. self.n_frames_per_step = hparams.n_frames_per_step
  171. self.encoder_embedding_dim = hparams.encoder_embedding_dim
  172. self.attention_rnn_dim = hparams.attention_rnn_dim
  173. self.decoder_rnn_dim = hparams.decoder_rnn_dim
  174. self.prenet_dim = hparams.prenet_dim
  175. self.max_decoder_steps = hparams.max_decoder_steps
  176. self.gate_threshold = hparams.gate_threshold
  177. self.prenet = Prenet(
  178. hparams.n_mel_channels * hparams.n_frames_per_step,
  179. [hparams.prenet_dim, hparams.prenet_dim])
  180. self.attention_rnn = nn.LSTMCell(
  181. hparams.prenet_dim + hparams.encoder_embedding_dim,
  182. hparams.attention_rnn_dim)
  183. self.attention_layer = Attention(
  184. hparams.attention_rnn_dim, hparams.encoder_embedding_dim,
  185. hparams.attention_dim, hparams.attention_location_n_filters,
  186. hparams.attention_location_kernel_size)
  187. self.decoder_rnn = nn.LSTMCell(
  188. hparams.attention_rnn_dim + hparams.encoder_embedding_dim,
  189. hparams.decoder_rnn_dim, 1)
  190. self.linear_projection = LinearNorm(
  191. hparams.decoder_rnn_dim + hparams.encoder_embedding_dim,
  192. hparams.n_mel_channels*hparams.n_frames_per_step)
  193. self.gate_layer = LinearNorm(
  194. hparams.decoder_rnn_dim + hparams.encoder_embedding_dim, 1,
  195. bias=True, w_init_gain='sigmoid')
  196. def get_go_frame(self, memory):
  197. """ Gets all zeros frames to use as first decoder input
  198. PARAMS
  199. ------
  200. memory: decoder outputs
  201. RETURNS
  202. -------
  203. decoder_input: all zeros frames
  204. """
  205. B = memory.size(0)
  206. decoder_input = Variable(memory.data.new(
  207. B, self.n_mel_channels * self.n_frames_per_step).zero_())
  208. return decoder_input
  209. def initialize_decoder_states(self, memory, mask):
  210. """ Initializes attention rnn states, decoder rnn states, attention
  211. weights, attention cumulative weights, attention context, stores memory
  212. and stores processed memory
  213. PARAMS
  214. ------
  215. memory: Encoder outputs
  216. mask: Mask for padded data if training, expects None for inference
  217. """
  218. B = memory.size(0)
  219. MAX_TIME = memory.size(1)
  220. self.attention_hidden = Variable(memory.data.new(
  221. B, self.attention_rnn_dim).zero_())
  222. self.attention_cell = Variable(memory.data.new(
  223. B, self.attention_rnn_dim).zero_())
  224. self.decoder_hidden = Variable(memory.data.new(
  225. B, self.decoder_rnn_dim).zero_())
  226. self.decoder_cell = Variable(memory.data.new(
  227. B, self.decoder_rnn_dim).zero_())
  228. self.attention_weights = Variable(memory.data.new(
  229. B, MAX_TIME).zero_())
  230. self.attention_weights_cum = Variable(memory.data.new(
  231. B, MAX_TIME).zero_())
  232. self.attention_context = Variable(memory.data.new(
  233. B, self.encoder_embedding_dim).zero_())
  234. self.memory = memory
  235. self.processed_memory = self.attention_layer.memory_layer(memory)
  236. self.mask = mask
  237. def parse_decoder_inputs(self, decoder_inputs):
  238. """ Prepares decoder inputs, i.e. mel outputs
  239. PARAMS
  240. ------
  241. decoder_inputs: inputs used for teacher-forced training, i.e. mel-specs
  242. RETURNS
  243. -------
  244. inputs: processed decoder inputs
  245. """
  246. # (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels)
  247. decoder_inputs = decoder_inputs.transpose(1, 2)
  248. decoder_inputs = decoder_inputs.view(
  249. decoder_inputs.size(0),
  250. int(decoder_inputs.size(1)/self.n_frames_per_step), -1)
  251. # (B, T_out, n_mel_channels) -> (T_out, B, n_mel_channels)
  252. decoder_inputs = decoder_inputs.transpose(0, 1)
  253. return decoder_inputs
  254. def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments):
  255. """ Prepares decoder outputs for output
  256. PARAMS
  257. ------
  258. mel_outputs:
  259. gate_outputs: gate output energies
  260. alignments:
  261. RETURNS
  262. -------
  263. mel_outputs:
  264. gate_outpust: gate output energies
  265. alignments:
  266. """
  267. # (T_out, B) -> (B, T_out)
  268. alignments = torch.stack(alignments).transpose(0, 1)
  269. # (T_out, B) -> (B, T_out)
  270. gate_outputs = torch.stack(gate_outputs).transpose(0, 1)
  271. gate_outputs = gate_outputs.contiguous()
  272. # (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels)
  273. mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
  274. # decouple frames per step
  275. mel_outputs = mel_outputs.view(
  276. mel_outputs.size(0), -1, self.n_mel_channels)
  277. # (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out)
  278. mel_outputs = mel_outputs.transpose(1, 2)
  279. return mel_outputs, gate_outputs, alignments
  280. def decode(self, decoder_input):
  281. """ Decoder step using stored states, attention and memory
  282. PARAMS
  283. ------
  284. decoder_input: previous mel output
  285. RETURNS
  286. -------
  287. mel_output:
  288. gate_output: gate output energies
  289. attention_weights:
  290. """
  291. decoder_input = self.prenet(decoder_input)
  292. cell_input = torch.cat((decoder_input, self.attention_context), -1)
  293. self.attention_hidden, self.attention_cell = self.attention_rnn(
  294. cell_input, (self.attention_hidden, self.attention_cell))
  295. attention_weights_cat = torch.cat(
  296. (self.attention_weights.unsqueeze(1),
  297. self.attention_weights_cum.unsqueeze(1)), dim=1)
  298. self.attention_context, self.attention_weights = self.attention_layer(
  299. self.attention_hidden, self.memory, self.processed_memory,
  300. attention_weights_cat, self.mask)
  301. self.attention_weights_cum += self.attention_weights
  302. decoder_input = torch.cat(
  303. (self.attention_hidden, self.attention_context), -1)
  304. self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
  305. decoder_input, (self.decoder_hidden, self.decoder_cell))
  306. decoder_hidden_attention_context = torch.cat(
  307. (self.decoder_hidden, self.attention_context), dim=1)
  308. decoder_output = self.linear_projection(
  309. decoder_hidden_attention_context)
  310. gate_prediction = self.gate_layer(decoder_hidden_attention_context)
  311. return decoder_output, gate_prediction, self.attention_weights
  312. def forward(self, memory, decoder_inputs, memory_lengths):
  313. """ Decoder forward pass for training
  314. PARAMS
  315. ------
  316. memory: Encoder outputs
  317. decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs
  318. memory_lengths: Encoder output lengths for attention masking.
  319. RETURNS
  320. -------
  321. mel_outputs: mel outputs from the decoder
  322. gate_outputs: gate outputs from the decoder
  323. alignments: sequence of attention weights from the decoder
  324. """
  325. decoder_input = self.get_go_frame(memory)
  326. decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
  327. self.initialize_decoder_states(
  328. memory, mask=~get_mask_from_lengths(memory_lengths))
  329. mel_outputs, gate_outputs, alignments = [], [], []
  330. while len(mel_outputs) < decoder_inputs.size(0):
  331. mel_output, gate_output, attention_weights = self.decode(
  332. decoder_input)
  333. mel_outputs += [mel_output.squeeze(1)]
  334. gate_outputs += [gate_output.squeeze()]
  335. alignments += [attention_weights]
  336. decoder_input = decoder_inputs[len(mel_outputs) - 1]
  337. mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
  338. mel_outputs, gate_outputs, alignments)
  339. return mel_outputs, gate_outputs, alignments
  340. def inference(self, memory):
  341. """ Decoder inference
  342. PARAMS
  343. ------
  344. memory: Encoder outputs
  345. RETURNS
  346. -------
  347. mel_outputs: mel outputs from the decoder
  348. gate_outputs: gate outputs from the decoder
  349. alignments: sequence of attention weights from the decoder
  350. """
  351. decoder_input = self.get_go_frame(memory)
  352. self.initialize_decoder_states(memory, mask=None)
  353. mel_outputs, gate_outputs, alignments = [], [], []
  354. while True:
  355. mel_output, gate_output, alignment = self.decode(decoder_input)
  356. mel_outputs += [mel_output.squeeze(1)]
  357. gate_outputs += [gate_output.squeeze()]
  358. alignments += [alignment]
  359. if F.sigmoid(gate_output.data) > self.gate_threshold:
  360. break
  361. elif len(mel_outputs) == self.max_decoder_steps:
  362. print("Warning! Reached max decoder steps")
  363. break
  364. decoder_input = mel_output
  365. mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
  366. mel_outputs, gate_outputs, alignments)
  367. return mel_outputs, gate_outputs, alignments
  368. class Tacotron2(nn.Module):
  369. def __init__(self, hparams):
  370. super(Tacotron2, self).__init__()
  371. self.mask_padding = hparams.mask_padding
  372. self.fp16_run = hparams.fp16_run
  373. self.n_mel_channels = hparams.n_mel_channels
  374. self.n_frames_per_step = hparams.n_frames_per_step
  375. self.embedding = nn.Embedding(
  376. hparams.n_symbols, hparams.symbols_embedding_dim)
  377. self.encoder = Encoder(hparams)
  378. self.decoder = Decoder(hparams)
  379. self.postnet = Postnet(hparams)
  380. def parse_batch(self, batch):
  381. text_padded, input_lengths, mel_padded, gate_padded, \
  382. output_lengths = batch
  383. text_padded = to_gpu(text_padded).long()
  384. input_lengths = to_gpu(input_lengths).long()
  385. max_len = torch.max(input_lengths.data).item()
  386. mel_padded = to_gpu(mel_padded).float()
  387. gate_padded = to_gpu(gate_padded).float()
  388. output_lengths = to_gpu(output_lengths).long()
  389. return (
  390. (text_padded, input_lengths, mel_padded, max_len, output_lengths),
  391. (mel_padded, gate_padded))
  392. def parse_input(self, inputs):
  393. inputs = fp32_to_fp16(inputs) if self.fp16_run else inputs
  394. return inputs
  395. def parse_output(self, outputs, output_lengths=None):
  396. if self.mask_padding and output_lengths is not None:
  397. mask = ~get_mask_from_lengths(output_lengths+1) # +1 <stop> token
  398. mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
  399. mask = mask.permute(1, 0, 2)
  400. outputs[0].data.masked_fill_(mask, 0.0)
  401. outputs[1].data.masked_fill_(mask, 0.0)
  402. outputs[2].data.masked_fill_(mask[:, 0, :], 1e3) # gate energies
  403. outputs = fp16_to_fp32(outputs) if self.fp16_run else outputs
  404. return outputs
  405. def forward(self, inputs):
  406. inputs, input_lengths, targets, max_len, \
  407. output_lengths = self.parse_input(inputs)
  408. input_lengths, output_lengths = input_lengths.data, output_lengths.data
  409. embedded_inputs = self.embedding(inputs).transpose(1, 2)
  410. encoder_outputs = self.encoder(embedded_inputs, input_lengths)
  411. mel_outputs, gate_outputs, alignments = self.decoder(
  412. encoder_outputs, targets, memory_lengths=input_lengths)
  413. mel_outputs_postnet = self.postnet(mel_outputs)
  414. mel_outputs_postnet = mel_outputs + mel_outputs_postnet
  415. # DataParallel expects equal sized inputs/outputs, hence padding
  416. if input_lengths is not None:
  417. alignments = alignments.unsqueeze(0)
  418. alignments = nn.functional.pad(
  419. alignments,
  420. (0, max_len - alignments.size(3), 0, 0),
  421. "constant", 0)
  422. alignments = alignments.squeeze()
  423. return self.parse_output(
  424. [mel_outputs, mel_outputs_postnet, gate_outputs, alignments],
  425. output_lengths)
  426. def inference(self, inputs):
  427. inputs = self.parse_input(inputs)
  428. embedded_inputs = self.embedding(inputs).transpose(1, 2)
  429. encoder_outputs = self.encoder.inference(embedded_inputs)
  430. mel_outputs, gate_outputs, alignments = self.decoder.inference(
  431. encoder_outputs)
  432. mel_outputs_postnet = self.postnet(mel_outputs)
  433. mel_outputs_postnet = mel_outputs + mel_outputs_postnet
  434. outputs = self.parse_output(
  435. [mel_outputs, mel_outputs_postnet, gate_outputs, alignments])
  436. return outputs