|
|
- import torch
- import torch.nn as nn
-
- import numpy as np
-
- import models
-
- import pickle
-
- import time
start_time = time.time()

# Special tokens. pad_query looks each of these up in word2idx, so they must
# be keys of the vocabulary loaded below.
PAD = "<pad>"
BOS = "<bos>"
EOS = "<eos>"


def _load_pickle(path):
    """Unpickle and return the object stored at *path*, closing the file.

    NOTE(review): unpickling can execute arbitrary code; only load trusted,
    locally produced artifacts here.
    """
    with open(path, "rb") as fh:
        return pickle.load(fh)


word2idx = _load_pickle("word2idx.pkl")
wordvecs = _load_pickle("wordvecs.pkl")
slots = _load_pickle("slots.pkl")
intents = _load_pickle("intents.pkl")
num_words = len(word2idx)
# NOTE(review): num_intent / num_slot are hard-coded; presumably they equal
# len(intents) / len(slots) and must match the 'snips_joint' checkpoint -
# TODO confirm before deriving them from the loaded lists.
num_intent = 7
num_slot = 72
filter_count = 300
dropout = 0
embedding_dim = 100
-
-
def pad_query(sequence, max_len=50):
    """Wrap token ids in BOS/EOS, then truncate or right-pad to a fixed length.

    Args:
        sequence: Iterable of token ids (values from ``word2idx``).
        max_len: Length of the returned array. Defaults to 50, the length
            this script has always used.

    Returns:
        A 1-D numpy array of exactly ``max_len`` ids: BOS, the input ids,
        EOS, then the PAD id repeated to fill the remaining positions.
    """
    # Unpacking (rather than list concatenation) accepts any iterable,
    # e.g. a tuple or a numpy array, without accidental broadcasting.
    ids = [word2idx[BOS], *sequence, word2idx[EOS]]
    # NOTE(review): truncation happens after EOS is appended, so inputs longer
    # than max_len - 2 tokens lose their EOS marker. Kept as-is because the
    # model was presumably trained with the same padding - confirm first.
    ids = ids[:max_len]
    return np.pad(
        ids,
        (0, max_len - len(ids)),
        mode="constant",
        constant_values=(word2idx[PAD],),
    )
-
-
query = "What's the weather like in York PA right now?"
# Normalise for tokenisation: lower-case, and turn apostrophes and question
# marks into spaces (so "what's" becomes the two tokens "what" and "s").
q = query.lower().replace("'", " ").replace("?", " ").strip()
tokens = q.split()
# One entry per utterance in the batch; used below to keep only the real
# token positions of the slot predictions.
true_length = [len(tokens)]
# NOTE(review): word2idx[word] raises KeyError for out-of-vocabulary words.
qq = torch.from_numpy(pad_query([word2idx[word] for word in tokens]))

model = models.CNNJoint(num_words, embedding_dim, num_intent, num_slot,
                        (filter_count,), 5, dropout, wordvecs)
model.load_state_dict(torch.load('snips_joint', map_location=torch.device('cpu')))
model.eval()
# NOTE(review): not referenced anywhere below; kept in case other code uses it.
criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)

# All-padding row (BOS, EOS, then PAD); not used by the single-query batch.
pad_tensor = torch.from_numpy(pad_query([]))
batch = torch.stack([qq])

# Inference only: skip autograd bookkeeping.
with torch.no_grad():
    pred_intent, pred_slots = model(batch)

# Argmax over dim 1 yields one slot id per sequence position. Position 0 holds
# BOS (pad_query prepends it), so keep positions 1..true_length per utterance.
slot_ids = pred_slots.max(1)[1].tolist()
out_slots = [
    slots[slot_id]
    for row, length in zip(slot_ids, true_length)
    for slot_id in row[1:length + 1]
]

itnt = pred_intent.max(1)[1].tolist()[0]
out_intent = intents[itnt]

print("Input: {}\nIntent: {}\nSlots: {}".format(query, out_intent, out_slots))

print("--- %s seconds ---" % (time.time() - start_time))
-
def _collect_slot_spans(words, tags):
    """Group consecutive tagged words into ``{slot_name: "word word ..."}``.

    Tags are "O" or a two-character prefix plus slot name (e.g. "B-city",
    "I-city"); the prefix is stripped. B- and I- are treated the same, so
    adjacent words with the same slot name merge into one span. If a slot
    name occurs in several separate spans, the last span wins. Extra words
    or tags beyond the shorter of the two sequences are ignored.
    """
    spans = {}
    active_name = None
    active_words = []
    for word, tag in zip(words, tags):
        name = None if tag == "O" else tag[2:]
        # Close the open span whenever the slot name changes (incl. to "O").
        if active_name is not None and name != active_name:
            spans[active_name] = " ".join(active_words)
            active_words = []
        if name is not None:
            active_words.append(word)
        active_name = name
    if active_name is not None:
        spans[active_name] = " ".join(active_words)
    return spans


# Per utterance: `out` is the space-joined input words that received a tag,
# and collected_slots maps slot name -> tagged words.
# NOTE(review): `out` is not written to any file or printed in this script.
out = ""
collected_slots = {}
for words, slot_preds in zip([q.split()], [out_slots]):
    out = " ".join(words[:len(slot_preds)])
    # Span state is local to each utterance, so spans never leak across them.
    collected_slots.update(_collect_slot_spans(words, slot_preds))

print(collected_slots)
print("--- %s seconds ---" % (time.time() - start_time))
|