This repository was archived by the owner on Jan 13, 2020. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy path deploy.py
More file actions
36 lines (29 loc) · 1.27 KB
/
deploy.py
File metadata and controls
36 lines (29 loc) · 1.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import sys
from Helpers import utils
import os
from Models import dmn_basic
import time
def main():
    """Answer one bAbI-style question with a pretrained DMN model and print the ranking.

    The question is read from the command line: every argument after the
    script name, joined with single spaces. A quoted question therefore
    behaves exactly as before, and an unquoted multi-word question is no
    longer truncated to its first word.

    For every vocabulary index below ``dmn.answer_size``, the answer word
    and its score are printed, highest score first. The total wall-clock
    time is printed last.

    Exits with a usage message (``SystemExit``) when no question is given.
    """
    if len(sys.argv) < 2:
        # Without this check, sys.argv[1] raised a bare IndexError.
        sys.exit('usage: python deploy.py <question>')

    start = time.time()
    query = ' '.join(sys.argv[1:])

    # Resolve both the story file and the saved model state relative to this
    # script. Previously the state path was resolved relative to the current
    # working directory, so running from another directory failed to find it.
    base_dir = os.path.dirname(os.path.realpath(__file__))
    story_path = os.path.join(base_dir, 'data.bak', 'corpus', 'babi1.txt')
    state_path = os.path.join(
        base_dir, 'states', 'dmn_basic',
        'dmn_basic.mh5.n40.bs10.babi1.epoch2.test1.20454.state')

    glove = utils.load_glove()
    # NOTE(review): init_babi_deploy presumably builds the model input from
    # the story file plus the query — confirm against Helpers.utils.
    quest = utils.init_babi_deploy(story_path, query)

    # These hyperparameters must match the ones the saved state was trained
    # with (the state file name encodes mh5 = memory_hops 5, n40 = dim 40).
    dmn = dmn_basic.DMN_basic(babi_train_raw=quest, babi_test_raw=[],
                              word2vec=glove, word_vector_size=50, dim=40,
                              mode='deploy', answer_module='feedforward',
                              input_mask_mode="sentence", memory_hops=5, l2=0,
                              normalize_attention=False, answer_vec='index',
                              debug=False)
    dmn.load_state(state_path)

    # NOTE(review): step_deploy()[0][0] is assumed to be a 1-D score vector
    # indexed like dmn.ivocab — confirm against Models.dmn_basic.
    prediction = dmn.step_deploy()[0][0]

    # Rank all indices by descending score and keep only valid answer slots.
    for ind in prediction.argsort()[::-1]:
        if ind < dmn.answer_size:
            print(dmn.ivocab[ind], prediction[ind])
    print('Time taken:', time.time() - start)
# Run the deployment query only when executed as a script, not on import.
if __name__ == '__main__':
    main()