You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

94 lines
3.2 KiB

  1. import torch
  2. import torch.nn as nn
  3. import numpy as np
  4. import models
  5. import pickle
  6. import time
  7. start_time = time.time()
  8. PAD = "<pad>"
  9. BOS = "<bos>"
  10. EOS = "<eos>"
  11. word2idx = pickle.load(open("word2idx.pkl", "rb"))
  12. wordvecs = pickle.load(open("wordvecs.pkl", "rb"))
  13. slots = pickle.load(open("slots.pkl", "rb"))
  14. intents = pickle.load(open("intents.pkl", "rb"))
  15. num_words = len(word2idx)
  16. num_intent = 7
  17. num_slot = 72
  18. filter_count = 300
  19. dropout = 0
  20. embedding_dim = 100
  21. def pad_query(sequence):
  22. sequence = [word2idx[BOS]] + sequence + [word2idx[EOS]]
  23. sequence = sequence[:50]
  24. sequence = np.pad(sequence, (0, 50 - len(sequence)), mode='constant', constant_values=(word2idx[PAD],))
  25. return sequence
  26. query = "What's the weather like in York PA right now?"
  27. q = query.lower().replace("'", " ").replace("?", " ").strip()
  28. # true_length = [len(q.split()), 0, 0, 0, 0, 0, 0 ,0]
  29. true_length = [len(q.split())]
  30. qq = torch.from_numpy(pad_query([word2idx[word] for word in q.split()]))
  31. model = models.CNNJoint(num_words, embedding_dim, num_intent, num_slot, (filter_count,), 5, dropout, wordvecs)
  32. model.eval()
  33. model.load_state_dict(torch.load('snips_joint', map_location=torch.device('cpu')))
  34. criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)
  35. pad_tensor = torch.from_numpy(pad_query([word2idx[w] for w in []]))
  36. # batch = torch.stack([qq, pad_tensor, pad_tensor, pad_tensor, pad_tensor, pad_tensor, pad_tensor, pad_tensor])
  37. batch = torch.stack([qq])
  38. pred_intent, pred_slots = model(batch)
  39. slt = [str(item) for batch_num, sublist in enumerate(pred_slots.max(1)[1].tolist()) for item in sublist[1:true_length[batch_num] + 1]]
  40. out_slots = [slots[int(c)] for c in slt]
  41. itnt = pred_intent.max(1)[1].tolist()[0]
  42. out_intent = intents[itnt]
  43. print("Input: {}\nIntent: {}\nSlots: {}".format(query, out_intent, out_slots))
  44. print("--- %s seconds ---" % (time.time() - start_time))
  45. # Write to output file
  46. out = ""
  47. collected_slots = {}
  48. active_slot_words = []
  49. active_slot_name = None
  50. for words, slot_preds, intent_pred in zip([q.split()], [out_slots], [out_intent]):
  51. line = ""
  52. for word, pred in zip(words, slot_preds):
  53. line = line + word + " "
  54. if pred == 'O':
  55. if active_slot_name:
  56. collected_slots[active_slot_name] = " ".join(active_slot_words)
  57. active_slot_words = []
  58. active_slot_name = None
  59. else:
  60. # Naive BIO handling: treat B- and I- the same...
  61. new_slot_name = pred[2:]
  62. if active_slot_name is None:
  63. active_slot_words.append(word)
  64. active_slot_name = new_slot_name
  65. elif new_slot_name == active_slot_name:
  66. active_slot_words.append(word)
  67. else:
  68. collected_slots[active_slot_name] = " ".join(active_slot_words)
  69. active_slot_words = [word]
  70. active_slot_name = new_slot_name
  71. out = line.strip()
  72. if active_slot_name:
  73. collected_slots[active_slot_name] = " ".join(active_slot_words)
  74. print(collected_slots)
  75. print("--- %s seconds ---" % (time.time() - start_time))