import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class CNNIntent(nn.Module):
    def __init__(self, input_dim, embedding_dim, output_dim, filter_sizes, kernel_size, dropout, wordvecs=None):
        super().__init__()
        if wordvecs is not None:
            self.embedding = nn.Embedding.from_pretrained(wordvecs)
        else:
            self.embedding = nn.Embedding(input_dim, embedding_dim)
        # stack of 1D convolutions: the first reads the embedding, each later one reads the previous layer's filters
        self.convs = nn.ModuleList(
            [nn.Conv1d(filter_sizes[i - 1] if i > 0 else embedding_dim, filter_sizes[i], kernel_size)
             for i in range(len(filter_sizes))]
        )
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(filter_sizes[-1], output_dim)
        self.embedding_dim = embedding_dim
        self.filter_sizes = filter_sizes
        self.kernel_size = kernel_size
        self.unpruned_count = sum(filter_sizes)

    def forward(self, query):  # query shape: [batch, seq len]
        x = self.embedding(query).permute(0, 2, 1)  # [batch, embedding dim, seq len]
        for conv in self.convs:
            x = conv(x)
            x = torch.rrelu(x)
        x = x.permute(0, 2, 1)  # [batch, seq len, filter count]
        x, _ = torch.max(x, dim=1)  # max-pool over the sequence dimension
        return self.fc(self.dropout(x))
    def prune(self, count, norm=2):
        if not (sum(self.filter_sizes) - count > 0):  # ensure at least one filter is left after pruning
            exit(0)
        rankings = []  # list of (conv #, filter #, norm)
        for i, conv in enumerate(self.convs):
            for k, filt in enumerate(conv.weight):
                rankings.append((i, k, torch.norm(filt.view(-1), p=norm, dim=0).item()))
        rankings.sort(key=lambda x: x[2])  # smallest-norm filters are pruned first
        for ranking in rankings[:count]:
            conv_num, filter_num, _ = ranking
            # remove the filter (output channel) from this conv's weight and bias
            new_weight = torch.cat((self.convs[conv_num].weight[:filter_num],
                                    self.convs[conv_num].weight[filter_num + 1:]))
            new_bias = torch.cat((self.convs[conv_num].bias[:filter_num],
                                  self.convs[conv_num].bias[filter_num + 1:]))
            self.convs[conv_num] = nn.Conv1d(self.filter_sizes[conv_num - 1] if conv_num > 0 else self.embedding_dim,
                                             self.filter_sizes[conv_num] - 1,
                                             self.kernel_size)
            self.convs[conv_num].weight = nn.Parameter(new_weight)
            self.convs[conv_num].bias = nn.Parameter(new_bias)
            # update the input channels of the succeeding layer
            if conv_num == len(self.filter_sizes) - 1:  # prune linear
                new_weight = torch.cat((self.fc.weight[:, :filter_num], self.fc.weight[:, filter_num + 1:]), dim=1)
                new_bias = self.fc.bias
                self.fc = nn.Linear(self.fc.in_features - 1, self.fc.out_features)
                self.fc.weight = nn.Parameter(new_weight)
                self.fc.bias = nn.Parameter(new_bias)
            else:  # prune conv
                new_weight = torch.cat((self.convs[conv_num + 1].weight[:, :filter_num],
                                        self.convs[conv_num + 1].weight[:, filter_num + 1:]), dim=1)
                new_bias = self.convs[conv_num + 1].bias
                self.convs[conv_num + 1] = nn.Conv1d(self.filter_sizes[conv_num] - 1,
                                                     self.filter_sizes[conv_num + 1],
                                                     self.kernel_size)
                self.convs[conv_num + 1].weight = nn.Parameter(new_weight)
                self.convs[conv_num + 1].bias = nn.Parameter(new_bias)
            # record the reduced filter count for this layer
            self.filter_sizes = tuple([filter_size - 1 if i == conv_num else filter_size
                                       for i, filter_size in enumerate(self.filter_sizes)])
class CNNSlot(nn.Module):
    def __init__(self, input_dim, embedding_dim, output_dim, filter_sizes, kernel_size, dropout, wordvecs=None):
        super().__init__()
        if wordvecs is not None:
            self.embedding = nn.Embedding.from_pretrained(wordvecs)
        else:
            self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.convs = nn.ModuleList(
            [nn.Conv1d(filter_sizes[i - 1] if i > 0 else embedding_dim, filter_sizes[i], kernel_size)
             for i in range(len(filter_sizes))]
        )
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(filter_sizes[-1], output_dim)
        self.padding = int((kernel_size - 1) / 2)  # "same" padding so the sequence length is preserved
        self.embedding_dim = embedding_dim
        self.unpruned_count = sum(filter_sizes)
        self.filter_sizes = filter_sizes
        self.kernel_size = kernel_size

    def forward(self, query):  # query shape: [batch, seq len]
        x = self.embedding(query)  # embedded shape: [batch, seq len, embedding dim]
        x = x.permute(0, 2, 1)  # x shape: [batch, embedding dim, seq len]
        for conv in self.convs:
            x = F.pad(x, (self.padding, self.padding))
            x = conv(x)  # x shape: [batch, filter count, seq len]
            x = torch.rrelu(x)
        x = x.permute(0, 2, 1)  # x shape: [batch, seq len, filter count]
        x = self.fc(self.dropout(x))  # per-token slot logits
        return x
    def prune(self, count, norm=2):
        if not (sum(self.filter_sizes) - count > 0):  # ensure at least one filter is left after pruning
            exit(0)
        rankings = []  # list of (conv #, filter #, norm)
        for i, conv in enumerate(self.convs):
            for k, filt in enumerate(conv.weight):
                rankings.append((i, k, torch.norm(filt.view(-1), p=norm, dim=0).item()))
        rankings.sort(key=lambda x: x[2])  # smallest-norm filters are pruned first
        for ranking in rankings[:count]:
            conv_num, filter_num, _ = ranking
            # remove the filter (output channel) from this conv's weight and bias
            new_weight = torch.cat((self.convs[conv_num].weight[:filter_num],
                                    self.convs[conv_num].weight[filter_num + 1:]))
            new_bias = torch.cat((self.convs[conv_num].bias[:filter_num],
                                  self.convs[conv_num].bias[filter_num + 1:]))
            self.convs[conv_num] = nn.Conv1d(self.filter_sizes[conv_num - 1] if conv_num > 0 else self.embedding_dim,
                                             self.filter_sizes[conv_num] - 1,
                                             self.kernel_size)
            self.convs[conv_num].weight = nn.Parameter(new_weight)
            self.convs[conv_num].bias = nn.Parameter(new_bias)
            # update the input channels of the succeeding layer
            if conv_num == len(self.filter_sizes) - 1:  # prune linear
                new_weight = torch.cat((self.fc.weight[:, :filter_num], self.fc.weight[:, filter_num + 1:]), dim=1)
                new_bias = self.fc.bias
                self.fc = nn.Linear(self.fc.in_features - 1, self.fc.out_features)
                self.fc.weight = nn.Parameter(new_weight)
                self.fc.bias = nn.Parameter(new_bias)
            else:  # prune conv
                new_weight = torch.cat((self.convs[conv_num + 1].weight[:, :filter_num],
                                        self.convs[conv_num + 1].weight[:, filter_num + 1:]), dim=1)
                new_bias = self.convs[conv_num + 1].bias
                self.convs[conv_num + 1] = nn.Conv1d(self.filter_sizes[conv_num] - 1,
                                                     self.filter_sizes[conv_num + 1],
                                                     self.kernel_size)
                self.convs[conv_num + 1].weight = nn.Parameter(new_weight)
                self.convs[conv_num + 1].bias = nn.Parameter(new_bias)
            # record the reduced filter count for this layer
            self.filter_sizes = tuple([filter_size - 1 if i == conv_num else filter_size
                                       for i, filter_size in enumerate(self.filter_sizes)])
class CNNJoint(nn.Module):
    def __init__(self, input_dim, embedding_dim, intent_dim, slot_dim, filter_sizes, kernel_size, dropout, wordvecs=None):
        super().__init__()
        if wordvecs is not None:
            self.embedding = nn.Embedding.from_pretrained(wordvecs)
        else:
            self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.convs = nn.ModuleList(
            [nn.Conv1d(filter_sizes[i - 1] if i > 0 else embedding_dim, filter_sizes[i], kernel_size)
             for i in range(len(filter_sizes))]
        )
        # separate heads for intent classification and slot filling over the shared conv stack
        self.intent_dropout = nn.Dropout(dropout)
        self.intent_fc = nn.Linear(filter_sizes[-1], intent_dim)
        self.slot_dropout = nn.Dropout(dropout)
        self.slot_fc = nn.Linear(filter_sizes[-1], slot_dim)
        self.padding = int((kernel_size - 1) / 2)
        self.unpruned_count = sum(filter_sizes)
        self.embedding_dim = embedding_dim
        self.filter_sizes = filter_sizes
        self.kernel_size = kernel_size

    def forward(self, query):  # query shape: [batch, seq len]
        x = self.embedding(query).permute(0, 2, 1)  # [batch, embedding dim, seq len]
        for conv in self.convs:
            x = F.pad(x, (self.padding, self.padding))
            x = conv(x)
            x = torch.rrelu(x)
        x = x.permute(0, 2, 1)  # [batch, seq len, filter count]
        # intent head max-pools over the sequence; slot head predicts per token
        intent_pred = self.intent_fc(self.intent_dropout(torch.max(x, dim=1)[0]))
        slot_pred = self.slot_fc(self.slot_dropout(x))
        return intent_pred, slot_pred.permute(0, 2, 1)
    def prune(self, count, norm=2):
        if not (sum(self.filter_sizes) - count > 0):  # ensure at least one filter is left after pruning
            exit(0)
        rankings = []  # list of (conv #, filter #, norm)
        for i, conv in enumerate(self.convs):
            for k, filt in enumerate(conv.weight):
                rankings.append((i, k, torch.norm(filt.view(-1), p=norm, dim=0).item()))
        rankings.sort(key=lambda x: x[2])  # smallest-norm filters are pruned first
        for ranking in rankings[:count]:
            conv_num, filter_num, _ = ranking
            # remove the filter (output channel) from this conv's weight and bias
            new_weight = torch.cat((self.convs[conv_num].weight[:filter_num],
                                    self.convs[conv_num].weight[filter_num + 1:]))
            new_bias = torch.cat((self.convs[conv_num].bias[:filter_num],
                                  self.convs[conv_num].bias[filter_num + 1:]))
            self.convs[conv_num] = nn.Conv1d(self.filter_sizes[conv_num - 1] if conv_num > 0 else self.embedding_dim,
                                             self.filter_sizes[conv_num] - 1,
                                             self.kernel_size)
            self.convs[conv_num].weight = nn.Parameter(new_weight)
            self.convs[conv_num].bias = nn.Parameter(new_bias)
            # update the input channels of the succeeding layer
            if conv_num == len(self.filter_sizes) - 1:  # prune both output heads
                new_intent_weight = torch.cat((self.intent_fc.weight[:, :filter_num],
                                               self.intent_fc.weight[:, filter_num + 1:]), dim=1)
                new_intent_bias = self.intent_fc.bias
                self.intent_fc = nn.Linear(self.intent_fc.in_features - 1, self.intent_fc.out_features)
                self.intent_fc.weight = nn.Parameter(new_intent_weight)
                self.intent_fc.bias = nn.Parameter(new_intent_bias)
                new_slot_weight = torch.cat((self.slot_fc.weight[:, :filter_num],
                                             self.slot_fc.weight[:, filter_num + 1:]), dim=1)
                new_slot_bias = self.slot_fc.bias
                self.slot_fc = nn.Linear(self.slot_fc.in_features - 1, self.slot_fc.out_features)
                self.slot_fc.weight = nn.Parameter(new_slot_weight)
                self.slot_fc.bias = nn.Parameter(new_slot_bias)
            else:  # prune conv
                new_weight = torch.cat((self.convs[conv_num + 1].weight[:, :filter_num],
                                        self.convs[conv_num + 1].weight[:, filter_num + 1:]), dim=1)
                new_bias = self.convs[conv_num + 1].bias
                self.convs[conv_num + 1] = nn.Conv1d(self.filter_sizes[conv_num] - 1,
                                                     self.filter_sizes[conv_num + 1],
                                                     self.kernel_size)
                self.convs[conv_num + 1].weight = nn.Parameter(new_weight)
                self.convs[conv_num + 1].bias = nn.Parameter(new_bias)
            # record the reduced filter count for this layer
            self.filter_sizes = tuple([filter_size - 1 if i == conv_num else filter_size
                                       for i, filter_size in enumerate(self.filter_sizes)])
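

# Minimal usage sketch (not part of the original module): the vocabulary size, embedding
# dimension, filter sizes, and batch shape below are illustrative assumptions, chosen only
# to show the expected tensor shapes and how prune() shrinks the network.
if __name__ == "__main__":
    vocab_size, emb_dim, n_intents = 1000, 50, 7  # hypothetical values, not from the original code
    model = CNNIntent(vocab_size, emb_dim, n_intents,
                      filter_sizes=(64, 64), kernel_size=3, dropout=0.3)
    queries = torch.randint(0, vocab_size, (8, 20))  # [batch, seq len] of token ids
    logits = model(queries)                          # [batch, n_intents]
    model.prune(16)                                  # remove the 16 lowest-norm filters
    # the bookkeeping tuple reflects the reduced filter counts
    assert sum(model.filter_sizes) == model.unpruned_count - 16
    print(logits.shape, model.filter_sizes)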