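"""CNN models for intent classification and slot filling (separate and joint),
with magnitude-based filter pruning: filters are ranked by weight norm, the
smallest are removed, and the consuming layer is shrunk to match."""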
import torch
import torch.nn as nn
import torch.nn.functional as F


class CNNIntent(nn.Module):
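    """Intent classifier: a stack of Conv1d layers over word embeddings,
    max-pooled over time and fed to a linear output layer."""
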
    def __init__(self, input_dim, embedding_dim, output_dim, filter_sizes, kernel_size, dropout, wordvecs=None):
        super().__init__()

        if wordvecs is not None:
            # assumes wordvecs.size(1) == embedding_dim
            self.embedding = nn.Embedding.from_pretrained(wordvecs)
        else:
            self.embedding = nn.Embedding(input_dim, embedding_dim)

        # filter_sizes[i] is the number of filters (output channels) of conv i;
        # conv 0 reads the embedding channels, each later conv reads the previous layer
        self.convs = nn.ModuleList(
            [nn.Conv1d(filter_sizes[i - 1] if i > 0 else embedding_dim, filter_sizes[i], kernel_size)
             for i in range(len(filter_sizes))]
        )
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(filter_sizes[-1], output_dim)

        self.embedding_dim = embedding_dim
        self.filter_sizes = filter_sizes
        self.kernel_size = kernel_size
        self.unpruned_count = sum(filter_sizes)

    def forward(self, query):  # query: [batch, seq len]
        x = self.embedding(query).permute(0, 2, 1)  # [batch, embedding dim, seq len]
        for conv in self.convs:
            x = conv(x)
            x = torch.rrelu(x)
        x = x.permute(0, 2, 1)  # [batch, seq len, filter count]
        x, _ = torch.max(x, dim=1)  # max over time: [batch, filter count]
        return self.fc(self.dropout(x))

    def prune(self, count, norm=2):
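        """Magnitude pruning: rank every filter in every conv layer by the
        p-norm of its weights, remove the `count` smallest, and shrink the
        input of whichever layer consumes the pruned layer's output."""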
        if sum(self.filter_sizes) - count <= 0:  # ensure at least one filter is left
            raise ValueError("cannot prune away all remaining filters")

        rankings = []  # list of (conv index, filter index, norm)
        for i, conv in enumerate(self.convs):
            for k, filt in enumerate(conv.weight):
                rankings.append((i, k, torch.norm(filt, p=norm).item()))
        rankings.sort(key=lambda r: r[2])

        # Remove higher filter indices first within each conv so that the
        # indices selected above remain valid after earlier removals.
        for conv_num, filter_num, _ in sorted(rankings[:count], key=lambda r: (r[0], -r[1])):
            # remove the filter's weights and bias from its conv layer
            new_weight = torch.cat((self.convs[conv_num].weight[:filter_num],
                                    self.convs[conv_num].weight[filter_num + 1:]))
            new_bias = torch.cat((self.convs[conv_num].bias[:filter_num],
                                  self.convs[conv_num].bias[filter_num + 1:]))

            self.convs[conv_num] = nn.Conv1d(self.filter_sizes[conv_num - 1] if conv_num > 0 else self.embedding_dim,
                                             self.filter_sizes[conv_num] - 1,
                                             self.kernel_size)
            self.convs[conv_num].weight = nn.Parameter(new_weight)
            self.convs[conv_num].bias = nn.Parameter(new_bias)

            # drop the matching input channel in the succeeding layer
            if conv_num == len(self.filter_sizes) - 1:  # last conv feeds the linear layer
                new_weight = torch.cat((self.fc.weight[:, :filter_num], self.fc.weight[:, filter_num + 1:]), dim=1)
                new_bias = self.fc.bias

                self.fc = nn.Linear(self.fc.in_features - 1, self.fc.out_features)
                self.fc.weight = nn.Parameter(new_weight)
                self.fc.bias = nn.Parameter(new_bias)
            else:  # otherwise it feeds the next conv
                new_weight = torch.cat((self.convs[conv_num + 1].weight[:, :filter_num],
                                        self.convs[conv_num + 1].weight[:, filter_num + 1:]), dim=1)
                new_bias = self.convs[conv_num + 1].bias

                self.convs[conv_num + 1] = nn.Conv1d(self.filter_sizes[conv_num] - 1,
                                                     self.filter_sizes[conv_num + 1],
                                                     self.kernel_size)
                self.convs[conv_num + 1].weight = nn.Parameter(new_weight)
                self.convs[conv_num + 1].bias = nn.Parameter(new_bias)

            self.filter_sizes = tuple(size - 1 if i == conv_num else size for i, size in enumerate(self.filter_sizes))


class CNNSlot(nn.Module):
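    """Slot filler: the same conv stack, but with same-length padding so the
    linear layer can emit slot logits for every token."""
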
    def __init__(self, input_dim, embedding_dim, output_dim, filter_sizes, kernel_size, dropout, wordvecs=None):
        super().__init__()

        if wordvecs is not None:
            self.embedding = nn.Embedding.from_pretrained(wordvecs)
        else:
            self.embedding = nn.Embedding(input_dim, embedding_dim)

        self.convs = nn.ModuleList(
            [nn.Conv1d(filter_sizes[i - 1] if i > 0 else embedding_dim, filter_sizes[i], kernel_size)
             for i in range(len(filter_sizes))]
        )
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(filter_sizes[-1], output_dim)

        # same-length padding for odd kernel sizes, so every token keeps a slot prediction
        self.padding = (kernel_size - 1) // 2

        self.embedding_dim = embedding_dim
        self.unpruned_count = sum(filter_sizes)
        self.filter_sizes = filter_sizes
        self.kernel_size = kernel_size

    def forward(self, query):  # query: [batch, seq len]
        x = self.embedding(query)  # [batch, seq len, embedding dim]
        x = x.permute(0, 2, 1)  # [batch, embedding dim, seq len]
        for conv in self.convs:
            x = F.pad(x, (self.padding, self.padding))  # pad the time dim to preserve seq len
            x = conv(x)
            x = torch.rrelu(x)  # [batch, filter count, seq len]
        x = x.permute(0, 2, 1)  # [batch, seq len, filter count]
        x = self.fc(self.dropout(x))  # [batch, seq len, output dim]
        return x

    def prune(self, count, norm=2):
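        """Magnitude pruning; see CNNIntent.prune."""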
        if sum(self.filter_sizes) - count <= 0:  # ensure at least one filter is left
            raise ValueError("cannot prune away all remaining filters")

        rankings = []  # list of (conv index, filter index, norm)
        for i, conv in enumerate(self.convs):
            for k, filt in enumerate(conv.weight):
                rankings.append((i, k, torch.norm(filt, p=norm).item()))
        rankings.sort(key=lambda r: r[2])

        # Remove higher filter indices first within each conv so that the
        # indices selected above remain valid after earlier removals.
        for conv_num, filter_num, _ in sorted(rankings[:count], key=lambda r: (r[0], -r[1])):
            # remove the filter's weights and bias from its conv layer
            new_weight = torch.cat((self.convs[conv_num].weight[:filter_num],
                                    self.convs[conv_num].weight[filter_num + 1:]))
            new_bias = torch.cat((self.convs[conv_num].bias[:filter_num],
                                  self.convs[conv_num].bias[filter_num + 1:]))

            self.convs[conv_num] = nn.Conv1d(self.filter_sizes[conv_num - 1] if conv_num > 0 else self.embedding_dim,
                                             self.filter_sizes[conv_num] - 1,
                                             self.kernel_size)
            self.convs[conv_num].weight = nn.Parameter(new_weight)
            self.convs[conv_num].bias = nn.Parameter(new_bias)

            # drop the matching input channel in the succeeding layer
            if conv_num == len(self.filter_sizes) - 1:  # last conv feeds the linear layer
                new_weight = torch.cat((self.fc.weight[:, :filter_num], self.fc.weight[:, filter_num + 1:]), dim=1)
                new_bias = self.fc.bias

                self.fc = nn.Linear(self.fc.in_features - 1, self.fc.out_features)
                self.fc.weight = nn.Parameter(new_weight)
                self.fc.bias = nn.Parameter(new_bias)
            else:  # otherwise it feeds the next conv
                new_weight = torch.cat((self.convs[conv_num + 1].weight[:, :filter_num],
                                        self.convs[conv_num + 1].weight[:, filter_num + 1:]), dim=1)
                new_bias = self.convs[conv_num + 1].bias

                self.convs[conv_num + 1] = nn.Conv1d(self.filter_sizes[conv_num] - 1,
                                                     self.filter_sizes[conv_num + 1],
                                                     self.kernel_size)
                self.convs[conv_num + 1].weight = nn.Parameter(new_weight)
                self.convs[conv_num + 1].bias = nn.Parameter(new_bias)

            self.filter_sizes = tuple(size - 1 if i == conv_num else size for i, size in enumerate(self.filter_sizes))


class CNNJoint(nn.Module):
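    """Joint model: a shared conv stack feeds both an intent head (max-pooled
    over time) and a per-token slot head."""
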
    def __init__(self, input_dim, embedding_dim, intent_dim, slot_dim, filter_sizes, kernel_size, dropout, wordvecs=None):
        super().__init__()

        if wordvecs is not None:
            self.embedding = nn.Embedding.from_pretrained(wordvecs)
        else:
            self.embedding = nn.Embedding(input_dim, embedding_dim)

        self.convs = nn.ModuleList(
            [nn.Conv1d(filter_sizes[i - 1] if i > 0 else embedding_dim, filter_sizes[i], kernel_size)
             for i in range(len(filter_sizes))]
        )

        self.intent_dropout = nn.Dropout(dropout)
        self.intent_fc = nn.Linear(filter_sizes[-1], intent_dim)
        self.slot_dropout = nn.Dropout(dropout)
        self.slot_fc = nn.Linear(filter_sizes[-1], slot_dim)

        # same-length padding for odd kernel sizes
        self.padding = (kernel_size - 1) // 2

        self.unpruned_count = sum(filter_sizes)
        self.embedding_dim = embedding_dim
        self.filter_sizes = filter_sizes
        self.kernel_size = kernel_size

    def forward(self, query):  # query: [batch, seq len]
        x = self.embedding(query).permute(0, 2, 1)  # [batch, embedding dim, seq len]
        for conv in self.convs:
            x = F.pad(x, (self.padding, self.padding))  # preserve seq len
            x = conv(x)
            x = torch.rrelu(x)
        x = x.permute(0, 2, 1)  # [batch, seq len, filter count]

        intent_pred = self.intent_fc(self.intent_dropout(torch.max(x, dim=1)[0]))  # [batch, intent dim]
        slot_pred = self.slot_fc(self.slot_dropout(x))  # [batch, seq len, slot dim]

        return intent_pred, slot_pred.permute(0, 2, 1)  # slots as [batch, slot dim, seq len]

    def prune(self, count, norm=2):
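        """Magnitude pruning; see CNNIntent.prune."""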
        if sum(self.filter_sizes) - count <= 0:  # ensure at least one filter is left
            raise ValueError("cannot prune away all remaining filters")

        rankings = []  # list of (conv index, filter index, norm)
        for i, conv in enumerate(self.convs):
            for k, filt in enumerate(conv.weight):
                rankings.append((i, k, torch.norm(filt, p=norm).item()))
        rankings.sort(key=lambda r: r[2])

        # Remove higher filter indices first within each conv so that the
        # indices selected above remain valid after earlier removals.
        for conv_num, filter_num, _ in sorted(rankings[:count], key=lambda r: (r[0], -r[1])):
            # remove the filter's weights and bias from its conv layer
            new_weight = torch.cat((self.convs[conv_num].weight[:filter_num],
                                    self.convs[conv_num].weight[filter_num + 1:]))
            new_bias = torch.cat((self.convs[conv_num].bias[:filter_num],
                                  self.convs[conv_num].bias[filter_num + 1:]))

            self.convs[conv_num] = nn.Conv1d(self.filter_sizes[conv_num - 1] if conv_num > 0 else self.embedding_dim,
                                             self.filter_sizes[conv_num] - 1,
                                             self.kernel_size)
            self.convs[conv_num].weight = nn.Parameter(new_weight)
            self.convs[conv_num].bias = nn.Parameter(new_bias)

            # drop the matching input channel in the succeeding layer
            if conv_num == len(self.filter_sizes) - 1:  # last conv feeds both output heads
                new_intent_weight = torch.cat((self.intent_fc.weight[:, :filter_num],
                                               self.intent_fc.weight[:, filter_num + 1:]), dim=1)
                new_intent_bias = self.intent_fc.bias

                self.intent_fc = nn.Linear(self.intent_fc.in_features - 1, self.intent_fc.out_features)
                self.intent_fc.weight = nn.Parameter(new_intent_weight)
                self.intent_fc.bias = nn.Parameter(new_intent_bias)

                new_slot_weight = torch.cat((self.slot_fc.weight[:, :filter_num],
                                             self.slot_fc.weight[:, filter_num + 1:]), dim=1)
                new_slot_bias = self.slot_fc.bias

                self.slot_fc = nn.Linear(self.slot_fc.in_features - 1, self.slot_fc.out_features)
                self.slot_fc.weight = nn.Parameter(new_slot_weight)
                self.slot_fc.bias = nn.Parameter(new_slot_bias)
            else:  # otherwise it feeds the next conv
                new_weight = torch.cat((self.convs[conv_num + 1].weight[:, :filter_num],
                                        self.convs[conv_num + 1].weight[:, filter_num + 1:]), dim=1)
                new_bias = self.convs[conv_num + 1].bias

                self.convs[conv_num + 1] = nn.Conv1d(self.filter_sizes[conv_num] - 1,
                                                     self.filter_sizes[conv_num + 1],
                                                     self.kernel_size)
                self.convs[conv_num + 1].weight = nn.Parameter(new_weight)
                self.convs[conv_num + 1].bias = nn.Parameter(new_bias)

            self.filter_sizes = tuple(size - 1 if i == conv_num else size for i, size in enumerate(self.filter_sizes))
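

# Minimal smoke test (a sketch): the vocabulary size, dimensions, and filter
# counts below are made-up illustration values, not settings from any
# particular experiment.
if __name__ == "__main__":
    model = CNNJoint(input_dim=1000, embedding_dim=64, intent_dim=5, slot_dim=10,
                     filter_sizes=(32, 32), kernel_size=3, dropout=0.5)
    batch = torch.randint(0, 1000, (4, 12))  # 4 queries of 12 token ids each

    intent_pred, slot_pred = model(batch)
    print(intent_pred.shape, slot_pred.shape)  # [4, 5] and [4, 10, 12]

    model.prune(8)  # drop the 8 filters with the smallest L2 norms
    intent_pred, slot_pred = model(batch)
    print(intent_pred.shape, slot_pred.shape)  # same shapes, smaller model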