-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathrecognize.py
More file actions
61 lines (44 loc) · 1.6 KB
/
recognize.py
File metadata and controls
61 lines (44 loc) · 1.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import json
import torch
from torch.autograd import Variable
import crnn
from hwr_utils import character_set, string_utils
import sys
import cv2
import numpy as np
def main():
    """Recognize the text in a single image and print the result.

    Command-line arguments (read from ``sys.argv``):
        1. Path to a JSON config file. The keys read here are
           ``character_set_path``, ``model_save_path``,
           ``network.cnn_out_size`` and ``network.input_height``.
        2. Path to the image to run through the network.

    Prints which device is used, then the raw decoded string, then the
    final decoded string.

    Raises:
        SystemExit: if fewer than two command-line arguments are given.
        IOError: if OpenCV cannot read the image file.
    """
    if len(sys.argv) < 3:
        prog = sys.argv[0] if sys.argv else "recognize.py"
        sys.exit("Usage: {} <config.json> <image>".format(prog))

    config_path = sys.argv[1]
    image_path = sys.argv[2]

    with open(config_path) as f:
        config = json.load(f)

    idx_to_char, char_to_idx = character_set.load_char_set(config['character_set_path'])

    hw = crnn.create_CRNN({
        'cnn_out_size': config['network']['cnn_out_size'],
        'num_of_channels': 3,
        # One output more than the character set size
        # (presumably the CTC blank label -- TODO confirm against crnn).
        'num_of_outputs': len(idx_to_char)+1
    })

    # Deserialize onto the CPU first so that a checkpoint saved from a GPU
    # run still loads on a CPU-only machine; the model is moved to the GPU
    # below when one is available.
    state_dict = torch.load(config['model_save_path'],
                            map_location=lambda storage, loc: storage)
    hw.load_state_dict(state_dict)

    if torch.cuda.is_available():
        hw.cuda()
        dtype = torch.cuda.FloatTensor
        print("Using GPU")
    else:
        dtype = torch.FloatTensor
        print("No GPU detected")

    hw.eval()

    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise IOError("Could not read image: {}".format(image_path))

    input_height = config['network']['input_height']
    if img.shape[0] != input_height:
        # Scale both axes by the same factor so the aspect ratio is kept
        # while the height becomes (approximately) the configured one.
        percent = float(input_height) / img.shape[0]
        img = cv2.resize(img, (0, 0), fx=percent, fy=percent, interpolation=cv2.INTER_CUBIC)

    # Move the channel axis first and rescale pixel values with /128 - 1
    # (maps 0..255 input to roughly [-1, 1)).
    img = torch.from_numpy(img.transpose(2, 0, 1).astype(np.float32)/128 - 1)

    # img[None, ...] adds a leading axis of size 1 (a batch of one image).
    # NOTE(review): `volatile` is the pre-0.4 PyTorch way of disabling
    # autograd; newer releases ignore it with a warning. Consider
    # `torch.no_grad()` once the minimum PyTorch version allows it.
    img = Variable(img[None, ...].type(dtype), requires_grad=False, volatile=True)

    preds = hw(img)

    # Swap the first two axes so that index 0 selects the single input image
    # (assumes the network output has the batch on axis 1 -- TODO confirm).
    output_batch = preds.permute(1, 0, 2)
    out = output_batch.data.cpu().numpy()

    pred, pred_raw = string_utils.naive_decode(out[0])
    pred_str = string_utils.label2str(pred, idx_to_char, False)
    pred_raw_str = string_utils.label2str(pred_raw, idx_to_char, True)

    print(pred_raw_str)
    print(pred_str)


# Run the recognizer only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()