"""Run the joint intent/slot CNN on a handful of example queries.

At import time this script loads the vocabulary, word vectors, slot and
intent tables from pickle files in the working directory, then calls
``predict`` on several demo sentences, which prints the predicted intent
and slot values for each.
"""
import torch
import torch.nn as nn
import numpy as np
import models
import pickle
import time

start_time = time.time()

# NOTE(review): PAD, BOS, EOS and the unknown-word key below are all the
# empty string in this file, so they resolve to the same vocabulary id.
# This may be angle-bracket tokens lost in extraction -- confirm against
# the keys actually stored in word2idx.pkl.
PAD = ""
BOS = ""
EOS = ""
UNK = ""

# Fixed number of token positions fed to the model (including BOS/EOS).
MAX_LEN = 50


def _load_pickle(path):
    """Load one pickled object from *path*, closing the file afterwards."""
    # SECURITY: pickle can execute arbitrary code while loading; only point
    # this at files you produced yourself.
    with open(path, "rb") as handle:
        return pickle.load(handle)


word2idx = _load_pickle("word2idx.pkl")
wordvecs = _load_pickle("wordvecs.pkl")
slots = _load_pickle("slots.pkl")
slot_filters = _load_pickle("slot_filters.pkl")
intents = _load_pickle("intents.pkl")

num_words = len(word2idx)
num_intent = len(intents)
num_slot = len(slots)
filter_count = 300
dropout = 0
embedding_dim = 100

# Lazily-built model shared by every predict() call; see _get_model().
_model = None


def pad_query(sequence):
    """Wrap *sequence* in BOS/EOS ids and force it to length ``MAX_LEN``.

    Args:
        sequence: list of vocabulary ids for the query words.

    Returns:
        A NumPy array of exactly ``MAX_LEN`` ids: BOS, the words, EOS,
        truncated to ``MAX_LEN`` and right-padded with the PAD id.
    """
    sequence = [word2idx[BOS]] + sequence + [word2idx[EOS]]
    sequence = sequence[:MAX_LEN]
    # After truncation len(sequence) <= MAX_LEN, so the pad width is >= 0.
    return np.pad(
        sequence,
        (0, MAX_LEN - len(sequence)),
        mode='constant',
        constant_values=(word2idx[PAD],),
    )


def _get_model():
    """Return the CNNJoint model, building and loading weights on first use.

    The construction and checkpoint read do not depend on the query, so
    they are done once instead of on every ``predict`` call.
    """
    global _model
    if _model is None:
        # The literal 5 is presumably the convolution kernel size --
        # TODO confirm against models.CNNJoint.
        model = models.CNNJoint(num_words, embedding_dim, num_intent,
                                num_slot, (filter_count,), 5, dropout,
                                wordvecs)
        model.eval()
        model.load_state_dict(
            torch.load('snips_joint', map_location=torch.device('cpu')))
        _model = model
    return _model


def _collect_slots(words, slot_preds):
    """Group consecutive words carrying the same slot name.

    Args:
        words: the query tokens.
        slot_preds: one slot tag per token; ``'O'`` means "no slot", any
            other tag has its first two characters dropped to get the name.

    Returns:
        Dict mapping slot name to the space-joined words of its run. A
        later run with the same name overwrites an earlier one.
    """
    collected_slots = {}
    active_slot_words = []
    active_slot_name = None
    for word, pred in zip(words, slot_preds):
        if pred == 'O':
            # An outside tag closes whatever slot was being accumulated.
            if active_slot_name:
                collected_slots[active_slot_name] = " ".join(active_slot_words)
                active_slot_words = []
                active_slot_name = None
            continue
        # Naive BIO handling: treat B- and I- the same...
        new_slot_name = pred[2:]
        if active_slot_name is None:
            active_slot_words.append(word)
            active_slot_name = new_slot_name
        elif new_slot_name == active_slot_name:
            active_slot_words.append(word)
        else:
            # A different slot name starts: flush the previous run.
            collected_slots[active_slot_name] = " ".join(active_slot_words)
            active_slot_words = [word]
            active_slot_name = new_slot_name
    if active_slot_name:
        collected_slots[active_slot_name] = " ".join(active_slot_words)
    return collected_slots


def predict(query):
    """Predict and print the intent and slots for one query string.

    The query is lower-cased, apostrophes and question marks are replaced
    by spaces, and it is split on whitespace. Words missing from
    ``word2idx`` map to the ``UNK`` entry.

    Args:
        query: the raw user utterance.

    Returns:
        None. Results are printed: the input, the predicted intent, the
        per-word slot tags, the collected slot dict, and the seconds
        elapsed since module import (twice).
    """
    q = query.lower().replace("'", " ").replace("?", " ").strip()
    words = q.split()
    word_ids = [word2idx[word] if word in word2idx else word2idx[UNK]
                for word in words]
    batch = torch.stack([torch.from_numpy(pad_query(word_ids))])

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        pred_intent, pred_slots = _get_model()(batch)

    out_intent = intents[pred_intent.max(1)[1].tolist()[0]]

    if out_intent in slot_filters:
        # 0/1 mask over slot labels allowed for the predicted intent,
        # repeated across all MAX_LEN positions.
        # NOTE(review): a multiplicative mask only suppresses a label if
        # the allowed scores are positive -- confirm the model's output.
        allowed = [1 if x in slot_filters[out_intent] else 0 for x in slots]
        mask = torch.stack(
            [torch.FloatTensor([allowed]).repeat(MAX_LEN, 1).transpose(0, 1)])
        pred_slots = torch.mul(pred_slots, mask)

    # Position 0 is BOS, so the word tags occupy positions 1..len(words).
    slot_ids = pred_slots.max(1)[1].tolist()[0][1:len(words) + 1]
    out_slots = [slots[idx] for idx in slot_ids]

    print("Input: {}\nIntent: {}\nSlots: {}".format(query, out_intent, out_slots))
    print("--- %s seconds ---" % (time.time() - start_time))

    collected_slots = _collect_slots(words, out_slots)
    print(collected_slots)
    print("--- %s seconds ---" % (time.time() - start_time))


# Demo queries; these run whenever the module is executed or imported.
predict("What's the weather like in York PA right now?")
predict("How's the weather in York PA right now?")
predict("What's the weather like in Great Mills right now?")
predict("What will the weather be like in Frederick Maryland tomorrow?")
predict("Play some jazz")
predict("Play some daft punk")
predict("Play some hatsune miku")