Browse Source

layers.py: rewrite

master
rafaelvalle 6 years ago
parent
commit
1ec0e5e8cd
1 changed files with 3 additions and 3 deletions
  1. +3
    -3
      layers.py

+ 3
- 3
layers.py View File

@ -10,7 +10,7 @@ class LinearNorm(torch.nn.Module):
super(LinearNorm, self).__init__() super(LinearNorm, self).__init__()
self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
torch.nn.init.xavier_uniform(
torch.nn.init.xavier_uniform_(
self.linear_layer.weight, self.linear_layer.weight,
gain=torch.nn.init.calculate_gain(w_init_gain)) gain=torch.nn.init.calculate_gain(w_init_gain))
@ -31,7 +31,7 @@ class ConvNorm(torch.nn.Module):
padding=padding, dilation=dilation, padding=padding, dilation=dilation,
bias=bias) bias=bias)
torch.nn.init.xavier_uniform(
torch.nn.init.xavier_uniform_(
self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
def forward(self, signal): def forward(self, signal):
@ -42,7 +42,7 @@ class ConvNorm(torch.nn.Module):
class TacotronSTFT(torch.nn.Module): class TacotronSTFT(torch.nn.Module):
def __init__(self, filter_length=1024, hop_length=256, win_length=1024, def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0, n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
mel_fmax=None):
mel_fmax=8000.0):
super(TacotronSTFT, self).__init__() super(TacotronSTFT, self).__init__()
self.n_mel_channels = n_mel_channels self.n_mel_channels = n_mel_channels
self.sampling_rate = sampling_rate self.sampling_rate = sampling_rate

Loading…
Cancel
Save