You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

145 lines
6.3 KiB

4 years ago
import argparse
import copy
import time
from itertools import chain

import torch
import torch.nn as nn

import dataset
import models
import util
  10. if __name__ == "__main__":
  11. parser = argparse.ArgumentParser()
  12. parser.add_argument('--name', required=True)
  13. parser.add_argument('--filename', required=True)
  14. parser.add_argument('--epochs', default=50)
  15. parser.add_argument('--seed', type=int, default=None)
  16. parser.add_argument('--patience', type=int, default=5)
  17. parser.add_argument('--dropout', type=float, default=0.5)
  18. parser.add_argument('--alpha', type=float, default=0.2)
  19. args = parser.parse_args()
  20. if 'atis' in args.name:
  21. args.dataset = 'atis'
  22. elif 'snips' in args.name:
  23. args.dataset = 'snips'
  24. if 'intent' in args.name:
  25. args.model = 'intent'
  26. elif 'slot' in args.name:
  27. args.model = 'slot'
  28. elif 'joint' in args.name:
  29. args.model = 'joint'
  30. print(f"seed {util.rep(args.seed)}")
  31. cuda = torch.cuda.is_available()
  32. train, valid, test, num_words, num_intent, num_slot, wordvecs = dataset.load(args.dataset, batch_size=8, seq_len=50)
  33. open(args.filename, 'w').close() # clear the file
  34. f = open(args.filename, "a")
  35. for filter_count in chain(range(300, 10, -5), range(10, 0, -1)):
  36. if args.model == 'intent':
  37. model = models.CNNIntent(num_words, 100, num_intent, (filter_count,), 5, args.dropout, wordvecs)
  38. elif args.model == 'slot':
  39. model = models.CNNSlot(num_words, 100, num_slot, (filter_count,), 5, args.dropout, wordvecs)
  40. elif args.model == 'joint':
  41. model = models.CNNJoint(num_words, 100, num_intent, num_slot, (filter_count,), 5, args.dropout, wordvecs)
  42. teacher = util.load_model(args.model, num_words, num_intent, num_slot, args.dropout, wordvecs)
  43. teacher.load_state_dict(torch.load(args.name))
  44. criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)
  45. distill_criterion = nn.KLDivLoss(reduction='batchmean')
  46. optimizer = torch.optim.Adam(model.parameters())
  47. if cuda:
  48. model = model.cuda()
  49. teacher = teacher.cuda()
  50. best_valid_loss = float('inf')
  51. last_epoch_to_improve = 0
  52. best_model = model
  53. model_filename = f"models/{args.dataset}_{args.model}"
  54. if args.model == 'intent':
  55. for epoch in range(args.epochs):
  56. start_time = time.time()
  57. train_loss, train_acc = util.distill_intent(teacher, model, 1.0, train, distill_criterion, optimizer, cuda)
  58. valid_loss, valid_acc = util.valid_intent(model, valid, criterion, cuda)
  59. end_time = time.time()
  60. elapsed_time = end_time - start_time
  61. print(f"Epoch {epoch + 1:03} took {elapsed_time:.3f} seconds")
  62. print(f"\tTrain Loss: {train_loss:.5f}, Acc: {train_acc:.5f}")
  63. print(f"\tValid Loss: {valid_loss:.5f}, Acc: {valid_acc:.5f}")
  64. if valid_loss < best_valid_loss:
  65. last_epoch_to_improve = epoch
  66. best_valid_loss = valid_loss
  67. best_model = copy.deepcopy(model)
  68. print("\tNew best valid loss!")
  69. if last_epoch_to_improve + args.patience < epoch:
  70. break
  71. _, test_acc = util.valid_intent(best_model, test, criterion, cuda)
  72. print(f"Test Acc: {test_acc:.5f}")
  73. print(f"{sum(best_model.filter_sizes)}, {test_acc:.5f}", file=f, flush=True)
  74. elif args.model == 'slot':
  75. for epoch in range(args.epochs):
  76. start_time = time.time()
  77. train_loss, train_f1 = util.distill_slot(teacher, model, 1.0, train, distill_criterion, optimizer, cuda)
  78. valid_loss, valid_f1 = util.valid_slot(model, valid, criterion, cuda)
  79. end_time = time.time()
  80. elapsed_time = end_time - start_time
  81. print(f"Epoch {epoch + 1:03} took {elapsed_time:.3f} seconds")
  82. print(f"\tTrain Loss: {train_loss:.5f}, F1: {train_f1:.5f}")
  83. print(f"\tValid Loss: {valid_loss:.5f}, F1: {valid_f1:.5f}")
  84. if valid_loss < best_valid_loss:
  85. last_epoch_to_improve = epoch
  86. best_valid_loss = valid_loss
  87. best_model = copy.deepcopy(model)
  88. print("\tNew best valid loss!")
  89. if last_epoch_to_improve + args.patience < epoch:
  90. break
  91. _, test_f1 = util.valid_slot(best_model, test, criterion, cuda)
  92. print(f"Test F1: {test_f1:.5f}")
  93. print(f"{sum(best_model.filter_sizes)}, {test_f1:.5f}", file=f, flush=True)
  94. elif args.model == 'joint':
  95. for epoch in range(args.epochs):
  96. start_time = time.time()
  97. train_loss, (intent_train_loss, intent_train_acc), (slot_train_loss, slot_train_f1) = util.distill_joint(teacher, model, 1.0, train, distill_criterion, optimizer, cuda, args.alpha)
  98. valid_loss, (intent_valid_loss, intent_valid_acc), (slot_valid_loss, slot_valid_f1) = util.valid_joint(model, valid, criterion, cuda, args.alpha)
  99. end_time = time.time()
  100. elapsed_time = end_time - start_time
  101. print(f"Epoch {epoch + 1:03} took {elapsed_time:.3f} seconds")
  102. print(f"\tTrain Loss: {train_loss:.5f}, (Intent Loss: {intent_train_loss:.5f}, Acc: {intent_train_acc:.5f}), (Slot Loss: {slot_train_loss:.5f}, F1: {slot_train_f1:.5f})")
  103. print(f"\tValid Loss: {valid_loss:.5f}, (Intent Loss: {intent_valid_loss:.5f}, Acc: {intent_valid_acc:.5f}), (Slot Loss: {slot_valid_loss:.5f}, F1: {slot_valid_f1:.5f})")
  104. if valid_loss < best_valid_loss:
  105. last_epoch_to_improve = epoch
  106. best_valid_loss = valid_loss
  107. best_model = copy.deepcopy(model)
  108. print("\tNew best valid loss!")
  109. if last_epoch_to_improve + args.patience < epoch:
  110. break
  111. _, (_, intent_test_acc), (_, slot_test_f1) = util.valid_joint(best_model, test, criterion, cuda, args.alpha)
  112. print(f"Test Intent Acc: {intent_test_acc:.5f}, Slot F1: {slot_test_f1:.5f}")
  113. print(f"{sum(best_model.filter_sizes)}, {intent_test_acc:.5f}, {slot_test_f1:.5f}", file=f, flush=True)