1010from app .db .models .dataset_management import DatasetFiles
1111from app .db .models .knowledge_gen import RagFile , RagKnowledgeBase
1212from app .db .models .model_config import ModelConfig
13- from app .db .session import get_db
13+ from app .db .session import get_db , AsyncSessionLocal
1414from app .module .shared .common .document_loaders import load_documents
1515from .graph_rag import (
1616 DEFAULT_WORKING_DIR ,
@@ -59,12 +59,22 @@ async def init_graph_rag(self, knowledge_base_id: str):
5959 kb_working_dir = os .path .join (DEFAULT_WORKING_DIR , kb .name )
6060 self .rag = await initialize_rag (llm_callable , embedding_callable , kb_working_dir )
6161
62+ await self ._schedule_file_processing (knowledge_base_id )
63+
64+ return {"status" : "initialized" , "knowledge_base_id" : knowledge_base_id }
65+
66+ async def _schedule_file_processing (self , knowledge_base_id : str ):
6267 if self .background_tasks is not None :
63- self .background_tasks .add_task (self ._process_pending_files , knowledge_base_id )
68+ self .background_tasks .add_task (self ._process_with_fresh_session , knowledge_base_id , self . rag )
6469 else :
65- asyncio .create_task (self ._process_pending_files (knowledge_base_id ))
70+ asyncio .create_task (self ._process_with_fresh_session (knowledge_base_id , self . rag ))
6671
67- return {"status" : "initialized" , "knowledge_base_id" : knowledge_base_id }
72+ @staticmethod
73+ async def _process_with_fresh_session (knowledge_base_id : str , rag_instance ):
74+ async with AsyncSessionLocal () as session :
75+ service = RAGService (session )
76+ service .rag = rag_instance
77+ await service ._process_pending_files (knowledge_base_id )
6878
6979 async def _process_pending_files (self , knowledge_base_id : str ):
7080 rag_files = await self .get_unprocessed_files (knowledge_base_id )
@@ -77,16 +87,17 @@ async def _process_pending_files(self, knowledge_base_id: str):
7787
7888 async def _process_single_file (self , rag_file : RagFile ):
7989 try :
90+ await self ._mark_file_status (rag_file , "PROCESSING" )
8091 dataset_file = await self ._get_dataset_file (rag_file .file_id )
8192 documents = load_documents (dataset_file .file_path )
8293 for doc in documents :
8394 logger .info (f"Processing document { doc .page_content } " )
8495 await self .rag .ainsert (input = doc .page_content , file_paths = [dataset_file .file_path ])
8596 except Exception : # noqa: BLE001
8697 logger .exception ("Failed to process rag file %s" , rag_file .id )
87- await self ._mark_file_failed (rag_file )
98+ await self ._mark_file_status (rag_file , "PROCESS_FAILED" )
8899 return
89- await self ._mark_file_processed (rag_file )
100+ await self ._mark_file_status (rag_file , "PROCESSED" )
90101
91102 async def _get_dataset_file (self , file_id : str ) -> DatasetFiles :
92103 result = await self .db .execute (
@@ -97,14 +108,8 @@ async def _get_dataset_file(self, file_id: str) -> DatasetFiles:
97108 raise ValueError (f"Dataset file with ID { file_id } not found." )
98109 return dataset_file
99110
100- async def _mark_file_processed (self , rag_file : RagFile ):
101- rag_file .status = "PROCESSED"
102- self .db .add (rag_file )
103- await self .db .commit ()
104- await self .db .refresh (rag_file )
105-
106- async def _mark_file_failed (self , rag_file : RagFile ):
107- rag_file .status = "PROCESS_FAILED"
111+ async def _mark_file_status (self , rag_file : RagFile , status : str ):
112+ rag_file .status = status
108113 self .db .add (rag_file )
109114 await self .db .commit ()
110115 await self .db .refresh (rag_file )
@@ -127,7 +132,6 @@ async def _get_model_config(self, model_id: Optional[str]):
127132 raise ValueError (f"Model config with ID { model_id } not found." )
128133 return model
129134
130-
131135 async def query_rag (self , query : str , knowledge_base_id : str ) -> str :
132136 if not self .rag :
133137 await self .init_graph_rag (knowledge_base_id )
0 commit comments