Skip to content

Commit d8e357e

Browse files
hhhhsc701hefanli
andauthored
fix: 修复paddleocr模型报错(#239)
* refactor: replace the database from mysql to pgsql * refactor: replace the database from mysql to pgsql * refactor: merge the databases of DataMate and LabelStudio * refactor: merge the databases of DataMate and LabelStudio * fix: resolve the conflict * feat: 适配pgsql * fix: resolve the annotation task bug * fix: fix the system param presetting data * fix: 修复paddleocr模型报错 * fix: 修复paddleocr模型报错 * fix: 修复paddleocr模型报错 --------- Co-authored-by: uname <2986773479@qq.com>
1 parent 7ae3559 commit d8e357e

File tree

5 files changed

+21
-40
lines changed

5 files changed

+21
-40
lines changed

runtime/ops/filter/img_blurred_images_cleaner/process.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def execute(self, sample: Dict[str, Any]):
3535
data = bytes_transform.bytes_to_numpy(img_bytes)
3636
blurred_images = self._blurred_images_filter(data, file_name)
3737
sample[self.data_key] = bytes_transform.numpy_to_bytes(blurred_images, file_type)
38-
logger.info(f"fileName: file_name, method: ImagesBlurredCleaner costs {(time.time() - start):6f} s")
38+
logger.info(f"fileName: {file_name}, method: ImagesBlurredCleaner costs {(time.time() - start):6f} s")
3939
return sample
4040

4141
def _blurred_images_filter(self, image, file_name):
@@ -46,6 +46,6 @@ def _blurred_images_filter(self, image, file_name):
4646
score = cv2.Laplacian(gray, cv2.CV_64F).var()
4747
if score <= self._blurred_threshold:
4848
logger.info(f"The image blur is {self._blurred_threshold}, "
49-
f"which exceeds the threshold of score}. {file_name is filtered out.")
49+
f"which exceeds the threshold of {score}. {file_name} is filtered out.")
5050
return np.array([])
5151
return image

runtime/ops/filter/img_similar_images_cleaner/sql/sql_config.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
"query_sql": "SELECT * FROM operator_similar_img_features WHERE task_uuid = :task_uuid ORDER BY timestamp LIMIT :ge OFFSET :le",
33
"insert_sql": "INSERT INTO operator_similar_img_features (task_uuid,p_hash,des_matrix,matrix_shape,file_name,timestamp) VALUES (:task_uuid,:p_hash,:des_matrix,:matrix_shape,:file_name,:timestamp)",
44
"query_task_uuid_sql": "SELECT * FROM operator_similar_img_features WHERE task_uuid = :task_uuid",
5-
"create_tables_sql": "CREATE TABLE IF NOT EXISTS operator_similar_img_features (id SERIAL PRIMARY KEY,task_uuid VARCHAR(255),p_hash TEXT,des_matrix BLOB,matrix_shape TEXT,file_name TEXT,timestamp TIMESTAMP);"
5+
"create_tables_sql": "CREATE TABLE IF NOT EXISTS operator_similar_img_features (id SERIAL PRIMARY KEY,task_uuid VARCHAR(255),p_hash TEXT,des_matrix BYTEA,matrix_shape TEXT,file_name TEXT,timestamp TIMESTAMP);"
66
}

runtime/ops/mapper/img_direction_correct/base_model.py

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,34 +4,15 @@
44
import os
55
from pathlib import Path
66

7-
from argparse import Namespace
8-
97

108
class BaseModel:
119

12-
def __init__(self, model_type='vertical'):
10+
def __init__(self, *args, **kwargs):
1311
models_path = os.getenv("MODELS_PATH", "/home/models")
14-
args = Namespace()
15-
args.cls_image_shape = '3, 224, 224'
16-
args.cls_batch_num = 6
17-
args.cls_thresh = 0.9
18-
args.use_onnx = False
19-
args.use_gpu = False
20-
args.use_npu = False
21-
args.use_xpu = False
22-
args.use_mlu = False
23-
args.enable_mkldnn = False
24-
if model_type == 'vertical':
25-
args.cls_model_dir = str(Path(models_path, 'ch_ppocr_mobile_v2.0_cls_infer'))
26-
self.model_name = 'standard model to detect image 0 or 90 rotated'
27-
args.label_list = ['0', '90']
28-
else:
29-
args.cls_model_dir = str(Path(models_path, 'ch_ppocr_mobile_v2.0_cls_infer'))
30-
self.model_name = 'standard model to detect image 0 or 180 rotated'
31-
args.label_list = ['0', '180']
12+
model_dir = str(Path(models_path, 'PP-LCNet_x1_0_doc_ori_infer'))
3213

33-
from paddleocr.tools.infer.predict_cls import TextClassifier
34-
self.infer = TextClassifier(args)
14+
from paddleocr import DocImgOrientationClassification
15+
self.infer = DocImgOrientationClassification(model_dir=model_dir)
3516

3617
def __del__(self):
3718
del self.infer

runtime/ops/mapper/img_direction_correct/process.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __init__(self, *args, **kwargs):
2424
self.img_resize = 1000
2525
self.limit_size = 30000
2626
self.use_model = True
27-
self.vertical_model, self.standard_model = self.get_model(*args, **kwargs)
27+
self.model = self.get_model(*args, **kwargs)
2828

2929
@staticmethod
3030
def _detect_angle(img):
@@ -60,15 +60,17 @@ def _detect_direction(image, file_name, model):
6060
Returns: 旋转后的图片
6161
"""
6262
# cls_res为模型预测结果,格式应当类似于: [('90', 0.9815167)]
63-
_, cls_res, _ = model.infer([image])
64-
rotate_angle = int(cls_res[0][0])
65-
pro = float(cls_res[0][1])
63+
cls_res = model.infer.predict([image])[0]
64+
rotate_angle = int(cls_res.get("class_ids", np.array([0], dtype='int32')).item())
65+
pro = float(cls_res.get("scores", np.array([0], dtype='int32')).item())
6666
logger.info(
6767
f"fileName: {file_name}, model {model.model_name} detect result is {rotate_angle} with confidence {pro}")
6868
if rotate_angle == 90 and pro > 0.89:
6969
return cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
7070
if rotate_angle == 180 and pro > 0.89:
71-
return cv2.rotate(image, 1)
71+
return cv2.rotate(image, cv2.ROTATE_180)
72+
if rotate_angle == 270 and pro > 0.89:
73+
return cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
7274
return image
7375

7476
@staticmethod
@@ -93,7 +95,7 @@ def _rotate_bound(image, angle):
9395
return dst_img
9496

9597
def init_model(self, *args, **kwargs):
96-
return BaseModel(model_type='vertical'), BaseModel(model_type='standard')
98+
return BaseModel(*args, **kwargs)
9799

98100
def execute(self, sample: Dict[str, Any]):
99101
start = time.time()
@@ -103,12 +105,12 @@ def execute(self, sample: Dict[str, Any]):
103105
img_bytes = sample[self.data_key]
104106
if img_bytes:
105107
data = bytes_transform.bytes_to_numpy(img_bytes)
106-
correct_data = self._img_direction_correct(data, file_name, self.vertical_model, self.standard_model)
108+
correct_data = self._img_direction_correct(data, file_name, self.model)
107109
sample[self.data_key] = bytes_transform.numpy_to_bytes(correct_data, file_type)
108110
logger.info(f"fileName: {file_name}, method: ImgDirectionCorrect costs {time.time() - start:6f} s")
109111
return sample
110112

111-
def _img_direction_correct(self, img, file_name, vertical_model, standard_model):
113+
def _img_direction_correct(self, img, file_name, standard_model):
112114
height, width = img.shape[:2]
113115
if max(height, width) > self.limit_size:
114116
logger.info(
@@ -119,8 +121,6 @@ def _img_direction_correct(self, img, file_name, vertical_model, standard_model)
119121
angle = self._detect_angle(detect_angle_img)
120122
# 将图片处理为 0, 90, 180, 270旋转角度的图片
121123
rotated_img = self._rotate_bound(img, angle)
122-
# 水平垂直方向识别:二分类模型,检测图片方向角为 0, 90, 将其处理为 0和180二分类图片
123-
rotated_img = self._detect_direction(rotated_img, file_name, vertical_model)
124124
# 0-180方向识别:二分类模型,检测图片方向角为 0, 180, 将其处理为 0和180二分类图片
125125
rotated_img = self._detect_direction(rotated_img, file_name, standard_model)
126126
return rotated_img

scripts/images/runtime/Dockerfile

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@ RUN --mount=type=cache,target=/var/cache/apt \
66
&& apt install -y libgl1 libglib2.0-0 vim libmagic1 libreoffice dos2unix swig poppler-utils tesseract-ocr
77

88
RUN mkdir -p /home/models \
9-
&& wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar \
10-
&& tar -xf ch_ppocr_mobile_v2.0_cls_infer.tar -C /home/models \
11-
&& rm -f ch_*.tar
9+
&& wget https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0.0/PP-LCNet_x1_0_doc_ori_infer.tar \
10+
&& tar -xf PP-LCNet_x1_0_doc_ori_infer.tar -C /home/models \
11+
&& rm -f PP_*.tar
1212

1313
COPY runtime/python-executor /opt/runtime
1414
COPY runtime/ops /opt/runtime/datamate/ops
@@ -22,7 +22,7 @@ ENV UV_INDEX_STRATEGY=unsafe-best-match
2222
WORKDIR /opt/runtime
2323

2424
RUN --mount=type=cache,target=/root/.cache/uv \
25-
uv pip install -e .[all] --system \
25+
uv pip install -e . --system \
2626
&& uv pip install -r /opt/runtime/datamate/ops/pyproject.toml --system \
2727
&& python -m spacy download zh_core_web_sm \
2828
&& python -c "import nltk; nltk.download('punkt_tab'); nltk.download('averaged_perceptron_tagger_eng')" \

0 commit comments

Comments
 (0)