Skip to content

Commit 45b634e

Browse files
committed
feat: enhance RAG service with async file processing and improve logging
1 parent fbd7157 commit 45b634e

File tree

2 files changed

+20
-16
lines changed

2 files changed

+20
-16
lines changed

runtime/datamate-python/app/module/rag/service/graph_rag.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from lightrag.llm.openai import openai_embed, openai_complete_if_cache
99
from lightrag.utils import setup_logger, EmbeddingFunc, get_env_value
1010

11-
setup_logger("lightrag", level="DEBUG")
11+
setup_logger("lightrag", level="INFO")
1212
DEFAULT_WORKING_DIR = os.path.join(os.getcwd(), "rag_storage")
1313

1414

runtime/datamate-python/app/module/rag/service/rag_service.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from app.db.models.dataset_management import DatasetFiles
1111
from app.db.models.knowledge_gen import RagFile, RagKnowledgeBase
1212
from app.db.models.model_config import ModelConfig
13-
from app.db.session import get_db
13+
from app.db.session import get_db, AsyncSessionLocal
1414
from app.module.shared.common.document_loaders import load_documents
1515
from .graph_rag import (
1616
DEFAULT_WORKING_DIR,
@@ -59,12 +59,22 @@ async def init_graph_rag(self, knowledge_base_id: str):
5959
kb_working_dir = os.path.join(DEFAULT_WORKING_DIR, kb.name)
6060
self.rag = await initialize_rag(llm_callable, embedding_callable, kb_working_dir)
6161

62+
await self._schedule_file_processing(knowledge_base_id)
63+
64+
return {"status": "initialized", "knowledge_base_id": knowledge_base_id}
65+
66+
async def _schedule_file_processing(self, knowledge_base_id: str):
6267
if self.background_tasks is not None:
63-
self.background_tasks.add_task(self._process_pending_files, knowledge_base_id)
68+
self.background_tasks.add_task(self._process_with_fresh_session, knowledge_base_id, self.rag)
6469
else:
65-
asyncio.create_task(self._process_pending_files(knowledge_base_id))
70+
asyncio.create_task(self._process_with_fresh_session(knowledge_base_id, self.rag))
6671

67-
return {"status": "initialized", "knowledge_base_id": knowledge_base_id}
72+
@staticmethod
73+
async def _process_with_fresh_session(knowledge_base_id: str, rag_instance):
74+
async with AsyncSessionLocal() as session:
75+
service = RAGService(session)
76+
service.rag = rag_instance
77+
await service._process_pending_files(knowledge_base_id)
6878

6979
async def _process_pending_files(self, knowledge_base_id: str):
7080
rag_files = await self.get_unprocessed_files(knowledge_base_id)
@@ -77,16 +87,17 @@ async def _process_pending_files(self, knowledge_base_id: str):
7787

7888
async def _process_single_file(self, rag_file: RagFile):
7989
try:
90+
await self._mark_file_status(rag_file, "PROCESSING")
8091
dataset_file = await self._get_dataset_file(rag_file.file_id)
8192
documents = load_documents(dataset_file.file_path)
8293
for doc in documents:
8394
logger.info(f"Processing document {doc.page_content}")
8495
await self.rag.ainsert(input=doc.page_content, file_paths=[dataset_file.file_path])
8596
except Exception: # noqa: BLE001
8697
logger.exception("Failed to process rag file %s", rag_file.id)
87-
await self._mark_file_failed(rag_file)
98+
await self._mark_file_status(rag_file, "PROCESS_FAILED")
8899
return
89-
await self._mark_file_processed(rag_file)
100+
await self._mark_file_status(rag_file, "PROCESSED")
90101

91102
async def _get_dataset_file(self, file_id: str) -> DatasetFiles:
92103
result = await self.db.execute(
@@ -97,14 +108,8 @@ async def _get_dataset_file(self, file_id: str) -> DatasetFiles:
97108
raise ValueError(f"Dataset file with ID {file_id} not found.")
98109
return dataset_file
99110

100-
async def _mark_file_processed(self, rag_file: RagFile):
101-
rag_file.status = "PROCESSED"
102-
self.db.add(rag_file)
103-
await self.db.commit()
104-
await self.db.refresh(rag_file)
105-
106-
async def _mark_file_failed(self, rag_file: RagFile):
107-
rag_file.status = "PROCESS_FAILED"
111+
async def _mark_file_status(self, rag_file: RagFile, status: str):
112+
rag_file.status = status
108113
self.db.add(rag_file)
109114
await self.db.commit()
110115
await self.db.refresh(rag_file)
@@ -127,7 +132,6 @@ async def _get_model_config(self, model_id: Optional[str]):
127132
raise ValueError(f"Model config with ID {model_id} not found.")
128133
return model
129134

130-
131135
async def query_rag(self, query: str, knowledge_base_id: str) -> str:
132136
if not self.rag:
133137
await self.init_graph_rag(knowledge_base_id)

0 commit comments

Comments
 (0)