-
-
Notifications
You must be signed in to change notification settings - Fork 38
Expand file tree
/
Copy pathtypings.py
More file actions
77 lines (60 loc) · 2.06 KB
/
typings.py
File metadata and controls
77 lines (60 loc) · 2.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: [email protected]
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
from .utils import mkdir
from .vis import VisTable
class EngineType(Enum):
    """Selectable engines, identified by a lowercase string value.

    Used as the type of ``RapidTableInput.engine_type``.
    """

    # Member order is significant: iterating the enum yields members in
    # definition order.
    ONNXRUNTIME = "onnxruntime"  # the ONNX Runtime engine
    TORCH = "torch"  # the PyTorch engine
class ModelType(Enum):
    """Selectable table models, each identified by a string value.

    Used as the type of ``RapidTableInput.model_type``, whose default is
    ``SLANETPLUS``.
    """

    # The _EN / _ZH suffixes presumably denote English / Chinese variants
    # of the same PP-Structure model -- TODO confirm.
    PPSTRUCTURE_EN = "ppstructure_en"
    PPSTRUCTURE_ZH = "ppstructure_zh"
    SLANETPLUS = "slanet_plus"
    UNITABLE = "unitable"
@dataclass
class RapidTableInput:
    """Configuration options, one field per setting.

    Field order is part of the positional constructor, so do not reorder.
    Mutable fields use ``default_factory`` so instances never share a dict.
    """

    # Model variant to use; SLANet-plus unless overridden.
    model_type: Optional[ModelType] = ModelType.SLANETPLUS

    # Model location: a single path (str or Path), a mapping of
    # str -> str (presumably component name -> path; TODO confirm), or
    # None to leave the choice to the consumer of this config.
    model_dir_or_path: Optional[Union[str, Path, Dict[str, str]]] = None

    # Engine to run the model with; None leaves the choice to the consumer.
    engine_type: Optional[EngineType] = None

    # Free-form engine options; schema is defined by the consumer.
    engine_cfg: Dict = field(default_factory=dict)

    # Toggle for OCR; on by default.
    use_ocr: bool = True

    # Free-form OCR options; schema is defined by the consumer.
    ocr_params: Dict = field(default_factory=dict)
@dataclass
class RapidTableOutput:
    """A batch of results plus a helper that renders them to disk.

    ``imgs``, ``pred_htmls``, ``cell_bboxes`` and ``logic_points`` are
    parallel lists: ``vis`` reads entry ``i`` of all four together.
    """

    imgs: List[np.ndarray] = field(default_factory=list)
    pred_htmls: List[str] = field(default_factory=list)
    cell_bboxes: List[np.ndarray] = field(default_factory=list)
    logic_points: List[np.ndarray] = field(default_factory=list)
    # Elapsed time for producing this output (units presumably seconds --
    # TODO confirm against the producer).
    elapse: float = 0.0

    def vis(
        self,
        save_dir: Union[str, Path],
        save_name: str,
        indexes: Optional[Tuple[int, ...]] = (0,),
    ) -> List[np.ndarray]:
        """Visualize selected results and save them under ``save_dir``.

        For each index ``idx`` the sub-directory ``save_dir/<idx>`` is
        created. ``VisTable`` is then called with the stored data for that
        index and these three output paths in that directory:
        ``<save_name>.html``, ``<save_name>_vis.jpg`` and
        ``<save_name>_col_row_vis.jpg``.

        Args:
            save_dir: Root output directory; created if missing.
            save_name: Base file name (without extension) for the outputs.
            indexes: Positions of the results to visualize. Negative values
                count from the end, as with list indexing. ``None`` means
                every stored result.

        Returns:
            The value returned by ``VisTable`` for each index, in order.

        Raises:
            IndexError: If any index is out of range for the stored results.
                All indexes are checked before anything is created or written.
        """
        # Usable length is bounded by the shortest of the parallel lists.
        count = min(
            len(self.imgs),
            len(self.pred_htmls),
            len(self.cell_bboxes),
            len(self.logic_points),
        )
        if indexes is None:
            indexes = tuple(range(count))
        else:
            # Materialize once so one-shot iterables survive validation + use.
            indexes = tuple(indexes)

        # Validate up front so a bad index cannot leave empty per-index
        # directories or a partially written batch behind.
        bad = [idx for idx in indexes if not (-count <= idx < count)]
        if bad:
            raise IndexError(
                f"indexes {bad} out of range for {count} stored result(s)"
            )

        # Named `visualizer` rather than `vis` so it does not shadow this method.
        visualizer = VisTable()
        save_dir = Path(save_dir)
        mkdir(save_dir)

        results = []
        for idx in indexes:
            save_one_dir = save_dir / str(idx)
            mkdir(save_one_dir)
            vis_img = visualizer(
                self.imgs[idx],
                self.pred_htmls[idx],
                self.cell_bboxes[idx],
                self.logic_points[idx],
                save_one_dir / f"{save_name}.html",
                save_one_dir / f"{save_name}_vis.jpg",
                save_one_dir / f"{save_name}_col_row_vis.jpg",
            )
            results.append(vis_img)
        return results