You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

152 lines
5.6 KiB

4 years ago
import argparse
import copy
import os
import time

import torch
import torch.nn as nn

import dataset
import models
import util
  9. if __name__ == "__main__":
  10. parser = argparse.ArgumentParser()
  11. parser.add_argument('--name', required=True)
  12. parser.add_argument('--zeroshot', action='store_true')
  13. parser.add_argument('--epochs', default=50)
  14. parser.add_argument('--seed', type=int, default=None)
  15. parser.add_argument('--patience', type=int, default=5)
  16. parser.add_argument('--dropout', type=float, default=0.5)
  17. parser.add_argument('--alpha', type=float, default=0.2)
  18. parser.add_argument('--l', type=int, default=2)
  19. parser.add_argument('--filename')
  20. args = parser.parse_args()
  21. if 'atis' in args.name:
  22. args.dataset = 'atis'
  23. elif 'snips' in args.name:
  24. args.dataset = 'snips'
  25. if 'intent' in args.name:
  26. args.model = 'intent'
  27. elif 'slot' in args.name:
  28. args.model = 'slot'
  29. elif 'joint' in args.name:
  30. args.model = 'joint'
  31. print(f"seed {util.rep(args.seed)}")
  32. cuda = torch.cuda.is_available()
  33. train, valid, test, num_words, num_intent, num_slot, wordvecs = dataset.load(args.dataset, batch_size=8, seq_len=50)
  34. model = util.load_model(args.model, num_words, num_intent, num_slot, args.dropout, wordvecs)
  35. model.load_state_dict(torch.load(args.name))
  36. criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)
  37. if cuda:
  38. model = model.cuda()
  39. if len(args.name.split('/')) > 1:
  40. nameprefix = args.name.split('/')[-1]
  41. else:
  42. nameprefix = args.name
  43. filename = args.filename if args.filename else f"results/{nameprefix}_{'zeroshot' if args.zeroshot else 'retrain'}_l{args.l}_alpha{args.alpha}.csv"
  44. if args.model == 'intent':
  45. open(filename, 'w').close() # clear the file
  46. f = open(filename, "a")
  47. while sum(model.filter_sizes) > 0:
  48. _, test_acc = util.valid_intent(model, test, criterion, cuda)
  49. print(f"{sum(model.filter_sizes)}, {test_acc:.5f}", file=f, flush=True)
  50. if sum(model.filter_sizes) > 10:
  51. model.prune(5, args.l)
  52. else:
  53. model.prune(1, args.l)
  54. if not args.zeroshot:
  55. optimizer = torch.optim.Adam(model.parameters())
  56. best_epoch = 0
  57. best_valid_loss, _ = util.valid_intent(model, valid, criterion, cuda)
  58. best_model = copy.deepcopy(model)
  59. epoch = 1
  60. while epoch <= best_epoch + args.patience:
  61. train_loss, train_acc = util.train_intent(model, train, criterion, optimizer, cuda)
  62. valid_loss, valid_acc = util.valid_intent(model, valid, criterion, cuda)
  63. if valid_loss < best_valid_loss:
  64. best_valid_loss = valid_loss
  65. best_epoch = epoch
  66. best_model = copy.deepcopy(model)
  67. epoch += 1
  68. model = best_model
  69. elif args.model == 'slot':
  70. open(filename, 'w').close() # clear the file
  71. f = open(filename, "a")
  72. while sum(model.filter_sizes) > 0:
  73. _, test_f1 = util.valid_slot(model, test, criterion, cuda)
  74. print(f"{sum(model.filter_sizes)}, {test_f1:.5f}", file=f, flush=True)
  75. if sum(model.filter_sizes) > 10:
  76. model.prune(5, args.l)
  77. else:
  78. model.prune(1, args.l)
  79. if not args.zeroshot:
  80. optimizer = torch.optim.Adam(model.parameters())
  81. best_epoch = 0
  82. best_valid_loss, _ = util.valid_slot(model, valid, criterion, cuda)
  83. best_model = copy.deepcopy(model)
  84. epoch = 1
  85. while epoch <= best_epoch + args.patience:
  86. train_loss, train_f1 = util.train_slot(model, train, criterion, optimizer, cuda)
  87. valid_loss, valid_f1 = util.valid_slot(model, valid, criterion, cuda)
  88. if valid_loss < best_valid_loss:
  89. best_valid_loss = valid_loss
  90. best_epoch = epoch
  91. best_model = copy.deepcopy(model)
  92. epoch += 1
  93. model = best_model
  94. elif args.model == 'joint':
  95. open(filename, 'w').close() # clear the file
  96. f = open(filename, "a")
  97. while sum(model.filter_sizes) > 0:
  98. _, (_, test_acc), (_, test_f1) = util.valid_joint(model, test, criterion, cuda, args.alpha)
  99. print(f"{sum(model.filter_sizes)}, {test_acc:.5f}, {test_f1:.5f}", file=f, flush=True)
  100. if sum(model.filter_sizes) > 10:
  101. model.prune(5, args.l)
  102. else:
  103. model.prune(1, args.l)
  104. if not args.zeroshot:
  105. optimizer = torch.optim.Adam(model.parameters())
  106. best_epoch = 0
  107. best_valid_loss, (_, _), (_, _) = util.valid_joint(model, valid, criterion, cuda, args.alpha)
  108. best_model = copy.deepcopy(model)
  109. epoch = 1
  110. while epoch <= best_epoch + args.patience:
  111. train_loss, (_, _), (_, _) = util.train_joint(model, train, criterion, optimizer, cuda, args.alpha)
  112. valid_loss, (_, _), (_, _) = util.valid_joint(model, valid, criterion, cuda, args.alpha)
  113. if valid_loss < best_valid_loss:
  114. best_valid_loss = valid_loss
  115. best_epoch = epoch
  116. best_model = copy.deepcopy(model)
  117. epoch += 1
  118. model = best_model