Browse Source

Fixed input size thingy

mistress
Daniel Muckerman 4 years ago
parent
commit
44e19df4fc
1 changed files with 38 additions and 3 deletions
  1. +38
    -3
      test_query.py

+ 38
- 3
test_query.py View File

@ -33,9 +33,10 @@ def pad_query(sequence):
return sequence
query = "What's the weather like in Great Mills right now?"
query = "What's the weather like in York PA right now?"
q = query.lower().replace("'", " ").replace("?", " ").strip()
true_length = [len(q.split()), 0, 0, 0, 0, 0, 0 ,0]
# true_length = [len(q.split()), 0, 0, 0, 0, 0, 0 ,0]
true_length = [len(q.split())]
qq = torch.from_numpy(pad_query([word2idx[word] for word in q.split()]))
model = models.CNNJoint(num_words, embedding_dim, num_intent, num_slot, (filter_count,), 5, dropout, wordvecs)
@ -45,7 +46,8 @@ model.load_state_dict(torch.load('snips_joint', map_location=torch.device('cpu')
criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)
pad_tensor = torch.from_numpy(pad_query([word2idx[w] for w in []]))
batch = torch.stack([qq, pad_tensor, pad_tensor, pad_tensor, pad_tensor, pad_tensor, pad_tensor, pad_tensor])
# batch = torch.stack([qq, pad_tensor, pad_tensor, pad_tensor, pad_tensor, pad_tensor, pad_tensor, pad_tensor])
batch = torch.stack([qq])
pred_intent, pred_slots = model(batch)
@ -57,4 +59,37 @@ out_intent = intents[itnt]
print("Input: {}\nIntent: {}\nSlots: {}".format(query, out_intent, out_slots))
print("--- %s seconds ---" % (time.time() - start_time))
# Write to output file
out = ""
collected_slots = {}
active_slot_words = []
active_slot_name = None
for words, slot_preds, intent_pred in zip([q.split()], [out_slots], [out_intent]):
line = ""
for word, pred in zip(words, slot_preds):
line = line + word + " "
if pred == 'O':
if active_slot_name:
collected_slots[active_slot_name] = " ".join(active_slot_words)
active_slot_words = []
active_slot_name = None
else:
# Naive BIO handling: treat B- and I- the same...
new_slot_name = pred[2:]
if active_slot_name is None:
active_slot_words.append(word)
active_slot_name = new_slot_name
elif new_slot_name == active_slot_name:
active_slot_words.append(word)
else:
collected_slots[active_slot_name] = " ".join(active_slot_words)
active_slot_words = [word]
active_slot_name = new_slot_name
out = line.strip()
if active_slot_name:
collected_slots[active_slot_name] = " ".join(active_slot_words)
print(collected_slots)
print("--- %s seconds ---" % (time.time() - start_time))

Loading…
Cancel
Save