import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

# Special tokens shared by every vocabulary.
PAD = "<pad>"
BOS = "<bos>"
EOS = "<eos>"

def build_glove(word2idx, idx2word, dim=100):
    """Build a (vocab_size, dim) embedding matrix from pretrained GloVe vectors.

    Words found in the GloVe file get their pretrained vector; everything else
    (including the special tokens) gets a random vector.
    """
    word2vecs = {}
    with open(f'glove/glove.6B.{dim}d.txt') as glove_file:
        for line in glove_file:
            splat = line.split()
            word = splat.pop(0)
            if word in word2idx:
                word2vecs[word] = np.array(splat).astype(float)
    vectors = []
    for word in (idx2word[i] for i in range(len(idx2word))):
        if word in word2vecs:
            vectors.append(torch.from_numpy(word2vecs[word]).float())
        else:
            vectors.append(torch.from_numpy(np.random.normal(0, 0.5, size=(dim,))).float())
    # Row i of the result is the vector for idx2word[i].
    return torch.stack(vectors)
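
# A minimal sketch (not part of the original file) of how the matrix returned by
# build_glove is typically consumed downstream: initializing an embedding layer.
# The variable names below are illustrative assumptions, not names used elsewhere
# in this module.
#
#     wordvecs = build_glove(corpus.word2idx, corpus.idx2word, dim=100)
#     embedding = torch.nn.Embedding.from_pretrained(
#         wordvecs, freeze=False, padding_idx=corpus.word2idx[PAD])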

class Corpus(Dataset):
    """One split of a slot-filling / intent-detection dataset.

    The word, intent, and slot vocabularies are built over all three splits so
    that every Corpus instance shares the same indices; only the examples of
    `split_name` are stored.
    """

    def __init__(self, dataset, split_name, seq_len: int):
        self.seq_len = seq_len
        self.queries = []
        self.intents = []
        self.slots = []
        self.word2idx = {}
        self.idx2word = {}
        self.intent2idx = {}
        self.slot2idx = {}
        self._register(PAD)
        self._register(BOS)
        self._register(EOS)
        # Build the vocabularies from every split.
        for split in ['train', 'valid', 'test']:
            with open(f'datasets/{dataset}/{split}/label') as intent_file:
                for line in intent_file:
                    intent = line.rstrip()
                    if intent not in self.intent2idx:
                        self.intent2idx[intent] = len(self.intent2idx)
            with open(f'datasets/{dataset}/{split}/seq.in') as queries_file:
                for line in queries_file:
                    query = line.rstrip().split()
                    for word in query:
                        if word not in self.word2idx:
                            idx = len(self.word2idx)
                            self.word2idx[word] = idx
                            self.idx2word[idx] = word
            with open(f'datasets/{dataset}/{split}/seq.out') as slotses_file:
                for line in slotses_file:
                    slots = line.rstrip().split()
                    for slot in slots:
                        if slot not in self.slot2idx:
                            self.slot2idx[slot] = len(self.slot2idx)
        # Keep only the examples of the requested split.
        with open(f'datasets/{dataset}/{split_name}/label') as intent_file:
            for line in intent_file:
                self.intents.append(line.rstrip())
        with open(f'datasets/{dataset}/{split_name}/seq.in') as queries_file:
            for line in queries_file:
                self.queries.append(line.rstrip().split())
        with open(f'datasets/{dataset}/{split_name}/seq.out') as slotses_file:
            for line in slotses_file:
                self.slots.append(line.rstrip().split())
        self.idx2intent = {v: k for k, v in self.intent2idx.items()}
        self.idx2slot = {v: k for k, v in self.slot2idx.items()}

    def _register(self, word):
        """Add a word to the vocabulary if it is not already present."""
        if word in self.word2idx:
            return
        assert len(self.idx2word) == len(self.word2idx)
        idx = len(self.idx2word)
        self.idx2word[idx] = word
        self.word2idx[word] = idx

    def pad_query(self, sequence):
        """Wrap a query in BOS/EOS, then truncate or pad it to seq_len."""
        sequence = [self.word2idx[BOS]] + sequence + [self.word2idx[EOS]]
        sequence = sequence[:self.seq_len]
        return np.pad(sequence, (0, self.seq_len - len(sequence)),
                      mode='constant', constant_values=(self.word2idx[PAD],))

    def pad_slots(self, sequence):
        """Pad slot labels to seq_len; -1 marks BOS/EOS/padding positions."""
        sequence = [-1] + sequence + [-1]
        sequence = sequence[:self.seq_len]
        return np.pad(sequence, (0, self.seq_len - len(sequence)),
                      mode='constant', constant_values=(-1,))

    def __getitem__(self, i):
        query = torch.from_numpy(self.pad_query([self.word2idx[word] for word in self.queries[i]]))
        intent = torch.tensor(self.intent2idx[self.intents[i]])
        slots = torch.from_numpy(self.pad_slots([self.slot2idx[slot] for slot in self.slots[i]]))
        true_length = torch.tensor(min(len(self.queries[i]), self.seq_len))
        return (query, intent, slots, true_length,
                (self.queries[i], self.intents[i], self.slots[i]),
                (self.idx2word, self.idx2intent, self.idx2slot))

    def __len__(self):
        assert len(self.queries) == len(self.intents)
        return len(self.queries)

def load(dataset, batch_size, seq_len):
    """Build all three splits, sanity-check the shared vocabularies, and return DataLoaders."""
    train_corpus = Corpus(dataset, 'train', seq_len)
    valid_corpus = Corpus(dataset, 'valid', seq_len)
    test_corpus = Corpus(dataset, 'test', seq_len)
    # Sanity checks: every split must have built identical vocabularies.
    assert len(train_corpus.word2idx) == len(valid_corpus.word2idx) == len(test_corpus.word2idx)
    assert len(train_corpus.intent2idx) == len(valid_corpus.intent2idx) == len(test_corpus.intent2idx)
    assert len(train_corpus.slot2idx) == len(valid_corpus.slot2idx) == len(test_corpus.slot2idx)
    num_words, num_intents, num_slots = len(train_corpus.word2idx), len(train_corpus.intent2idx), len(train_corpus.slot2idx)
    wordvecs = build_glove(train_corpus.word2idx, train_corpus.idx2word)
    return (DataLoader(train_corpus, batch_size, shuffle=True),
            DataLoader(valid_corpus, batch_size, shuffle=False),
            DataLoader(test_corpus, batch_size),
            num_words, num_intents, num_slots, wordvecs)
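

if __name__ == '__main__':
    # Minimal usage sketch, assuming a dataset laid out under
    # datasets/<name>/{train,valid,test}/{label,seq.in,seq.out} and GloVe vectors
    # under glove/glove.6B.100d.txt. The dataset name 'atis' and the
    # hyperparameters below are illustrative assumptions, not values fixed by
    # this module.
    train_loader, valid_loader, test_loader, num_words, num_intents, num_slots, wordvecs = \
        load('atis', batch_size=32, seq_len=50)
    print(f'{num_words} words, {num_intents} intents, {num_slots} slot labels')
    print('GloVe matrix shape:', tuple(wordvecs.shape))  # (num_words, 100)
    query, intent, slots, true_length, raw, vocab = train_loader.dataset[0]
    print('first padded query shape:', tuple(query.shape), 'intent id:', int(intent))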