-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathUI.py
More file actions
450 lines (392 loc) · 22 KB
/
UI.py
File metadata and controls
450 lines (392 loc) · 22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
import os
import sys
import logging
import gradio as gr
import sqlite3
from dotenv import load_dotenv
import torch
import chromadb
import gc
# import logger # Assuming this was for a different logger, standard logging is used
import uuid
import datetime
from typing import Optional
# LangChain core and community imports
from langchain_community.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEmbeddings
# LangChain Google Generative AI integration
from langchain_google_genai import ChatGoogleGenerativeAI
# Importar configuración y retriever personalizado
import config_ui
from custom_retriever import ParallelEnsembleRetriever
from db_logger import log_chat_interaction, SQLiteLoggingCallbackHandler, create_tables, log_user_feedback, get_db_connection
# --- Logging configuration ---
# Standard logging to stdout; add extra handlers here if file output is desired.
logging.basicConfig(stream=sys.stdout, level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# A UI-specific logger could be used instead of the root logger:
# ui_logger = logging.getLogger(__name__)
# --- Device selection (CPU/GPU) ---
# Prefer CUDA when available; the embeddings model below is placed on this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    # cudnn autotuner: speeds up repeated same-shape workloads on GPU.
    torch.backends.cudnn.benchmark = True
logging.info(f"Usando dispositivo: {device}")
# --- Carga de la Clave API ---
def cargar_api_key():
    """Load the Google API key from the environment (.env supported).

    Terminates the process with exit code 1 when the variable named by
    ``config_ui.GOOGLE_API_KEY_ENV_VAR`` is unset or empty.
    """
    load_dotenv()
    clave = os.getenv(config_ui.GOOGLE_API_KEY_ENV_VAR)
    if clave:
        return clave
    logging.error(f"Error: Variable de entorno {config_ui.GOOGLE_API_KEY_ENV_VAR} no configurada.")
    sys.exit(1)
# --- Carga de la Plantilla de Prompt ---
def cargar_prompt_template_desde_archivo(filepath: str) -> str:
    """Return the prompt-template text read (UTF-8) from *filepath*.

    Exits the process with code 1 if the file is missing or unreadable.
    """
    try:
        with open(filepath, "r", encoding="utf-8") as fh:
            contenido = fh.read()
    except FileNotFoundError:
        logging.error(f"Error: Archivo de plantilla de prompt '{filepath}' no encontrado.")
        sys.exit(1)
    except Exception as e:
        logging.error(f"Error al leer el archivo de plantilla de prompt '{filepath}': {e}", exc_info=True)
        sys.exit(1)
    logging.info(f"Plantilla de prompt cargada desde '{filepath}'")
    return contenido
# --- Inicialización del Modelo de Embeddings Local ---
def inicializar_modelo_embeddings_local(model_name: str):
    """Load the local HuggingFace embeddings model used for query encoding.

    The model is placed on the module-level ``device`` with normalized
    embeddings. Exits the process with code 1 if loading fails.
    """
    logging.info(f"Cargando modelo de embeddings local: {model_name} en dispositivo {device} (para consultas)")
    try:
        if device == "cuda":
            logging.info("Dispositivo CUDA detectado. HuggingFaceEmbeddings intentará optimizar.")
        modelo = HuggingFaceEmbeddings(
            model_name=model_name,
            model_kwargs={'device': device, 'trust_remote_code': True},
            encode_kwargs={'normalize_embeddings': True},
        )
    except Exception as e:
        logging.error(f"Error al cargar el modelo de embeddings local '{model_name}': {e}", exc_info=True)
        sys.exit(1)
    logging.info(f"Modelo de embeddings '{model_name}' cargado exitosamente en {device}.")
    return modelo
# --- Creación del Parallel Ensemble Retriever ---
def crear_parallel_ensemble_retriever(
    embeddings_model: HuggingFaceEmbeddings,
    main_persist_directory: str,
    biblia_persist_directory: str,
    collection_names: list[str],
    k_per_collection: int,
    max_workers_retriever: int = None
):
    """Build one Chroma retriever per collection and combine them in a ParallelEnsembleRetriever.

    The "biblia" collection lives in its own persistence directory; every other
    collection is served from the main directory. Any validation/connection
    failure returns None so the caller can abort startup.

    Args:
        embeddings_model: embedding function attached to each Chroma vector store.
        main_persist_directory: on-disk ChromaDB path for non-"biblia" collections.
        biblia_persist_directory: on-disk ChromaDB path containing "biblia".
        collection_names: names of the collections to search.
        k_per_collection: number of documents each individual retriever fetches.
        max_workers_retriever: thread-pool size forwarded to ParallelEnsembleRetriever
            (None lets the ensemble choose its default).

    Returns:
        A ParallelEnsembleRetriever, or None if any store/collection is missing.
    """
    retriever_list = []
    main_chroma_client = None
    # Only open the main store when at least one non-"biblia" collection is requested.
    if any(name != "biblia" for name in collection_names):
        logging.info(f"Intentando conectar a ChromaDB principal en: '{main_persist_directory}'")
        if not os.path.exists(main_persist_directory):
            logging.error(f"Error: El directorio de persistencia principal no existe: '{main_persist_directory}'.")
            return None
        try:
            main_chroma_client = chromadb.PersistentClient(path=main_persist_directory)
            # Verify up-front that every requested (non-biblia) collection exists.
            available_collections_main = {col.name for col in main_chroma_client.list_collections()}
            main_target_collections = [name for name in collection_names if name != "biblia"]
            missing_in_main = [name for name in main_target_collections if name not in available_collections_main]
            if missing_in_main:
                logging.error(f"Error: Colecciones NO encontradas en '{main_persist_directory}': {missing_in_main}")
                return None
        except Exception as e:
            logging.error(f"Error al inicializar cliente principal de ChromaDB: {e}", exc_info=True)
            return None
    biblia_chroma_client = None
    if "biblia" in collection_names:
        logging.info(f"Intentando conectar a ChromaDB para 'biblia' en: '{biblia_persist_directory}'")
        if not os.path.exists(biblia_persist_directory):
            logging.error(f"Error: Dir de persistencia 'biblia' no existe: '{biblia_persist_directory}'.")
            return None
        try:
            biblia_chroma_client = chromadb.PersistentClient(path=biblia_persist_directory)
            available_in_biblia_db = {col.name for col in biblia_chroma_client.list_collections()}
            # NOTE(review): assumes the collection inside the biblia DB is also
            # named "biblia" — confirm against the ingestion script.
            if "biblia" not in available_in_biblia_db:
                logging.error(f"Error: Colección 'biblia' NO encontrada en '{biblia_persist_directory}'. Disp: {list(available_in_biblia_db)}")
                return None
        except Exception as e:
            logging.error(f"Error al inicializar cliente ChromaDB para 'biblia': {e}", exc_info=True)
            return None
    logging.info(f"Creando retrievers individuales (k={k_per_collection} por retriever)...")
    for collection_name in collection_names:
        try:
            # Route each collection to the client/directory it belongs to.
            current_client, current_persist_dir, internal_collection_name = (None, None, None)
            if collection_name == "biblia":
                if not biblia_chroma_client: continue
                current_client, current_persist_dir, internal_collection_name = biblia_chroma_client, biblia_persist_directory, "biblia"
            else:
                if not main_chroma_client: continue
                current_client, current_persist_dir, internal_collection_name = main_chroma_client, main_persist_directory, collection_name
            if not current_client:  # Defensive: unreachable if the connection checks above succeeded.
                logging.warning(f"Cliente no disponible para '{collection_name}', omitiendo retriever.")
                continue
            vector_store = Chroma(
                client=current_client, collection_name=internal_collection_name,
                embedding_function=embeddings_model, persist_directory=current_persist_dir
            )
            retriever = vector_store.as_retriever(search_kwargs={'k': k_per_collection})
            retriever_list.append(retriever)
            logging.debug(f" -> Retriever para '{collection_name}' creado.")
        except Exception as e:
            # A single failing retriever is skipped rather than aborting the whole
            # ensemble; consider returning None here if every collection is critical.
            logging.error(f"Error al crear retriever para '{collection_name}': {e}", exc_info=True)
    if not retriever_list:
        logging.error("Error: No se pudo crear ningún retriever individual.")
        return None
    return ParallelEnsembleRetriever(retrievers=retriever_list, max_workers=max_workers_retriever)
# --- Configuración del LLM y la Cadena RAG ---
def inicializar_llm(api_key: str, model_name: str, temperature: float):
    """Create the Gemini chat LLM; exit the process with code 1 on failure.

    System messages are converted to human messages for model compatibility.
    """
    logging.info(f"Inicializando LLM: {model_name} con temperatura {temperature}")
    try:
        return ChatGoogleGenerativeAI(
            model=model_name,
            google_api_key=api_key,
            temperature=temperature,
            convert_system_message_to_human=True,
        )
    except Exception as e:
        logging.error(f"Error al inicializar el LLM de Gemini ({model_name}): {e}")
        sys.exit(1)
def inicializar_cadena_rag(llm, retriever, prompt_template_str: str):
    """Assemble the RetrievalQA ("stuff") chain from LLM, retriever and prompt.

    Returns the chain, or None when a component is missing or construction fails.
    """
    if not llm or not retriever:
        logging.error("Faltan componentes (LLM o Retriever) para inicializar la cadena RAG.")
        return None
    prompt = PromptTemplate.from_template(prompt_template_str)
    try:
        cadena = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": prompt},
        )
    except Exception as e:
        logging.error(f"Error al crear la cadena RetrievalQA: {e}")
        return None
    logging.info("Cadena RetrievalQA creada exitosamente.")
    return cadena
# --- Interfaz de Usuario con Gradio ---
def corregir_acentos(texto: str) -> str:
    """Repair stray acute-accent sequences (e.g. ``"´a"`` -> ``"á"``) in model output.

    Space-prefixed patterns are replaced first so the spurious space is removed
    together with the accent marker. Falsy input is returned unchanged.
    """
    if not texto:
        return texto
    pares = (
        (" ´a", "á"), (" ´e", "é"), (" ´i", "í"), (" ´o", "ó"), (" ´u", "ú"),
        (" ´A", "Á"), (" ´E", "É"), (" ´I", "Í"), (" ´O", "Ó"), (" ´U", "Ú"),
        ("´a", "á"), ("´e", "é"), ("´i", "í"), ("´o", "ó"), ("´u", "ú"),
        ("´A", "Á"), ("´E", "É"), ("´I", "Í"), ("´O", "Ó"), ("´U", "Ú"),
    )
    resultado = texto
    for patron, reemplazo in pares:
        resultado = resultado.replace(patron, reemplazo)
    return resultado
def _registrar_mensaje_asistente(gradio_request_id: str, contenido: str) -> Optional[int]:
    """Insert an assistant-role row into chat_interactions; return its row id, or None on DB error."""
    conn = get_db_connection()
    chat_id = None
    try:
        cursor = conn.cursor()
        cursor.execute("""
                INSERT INTO chat_interactions (gradio_request_id, timestamp, role, content)
                VALUES (?, ?, ?, ?)
            """, (gradio_request_id, datetime.datetime.now(), "assistant", contenido))
        chat_id = cursor.lastrowid
        conn.commit()
    except sqlite3.Error as e:
        # Best-effort logging: a DB failure must not break the chat response.
        logging.error(f"Error al registrar interacción del asistente y obtener ID: {e}", exc_info=True)
    finally:
        if conn:
            conn.close()
    return chat_id


def _formatear_fuentes(fuentes) -> str:
    """Format source documents as unique '- Fuente: <file>, Pág: <n>' lines; 'N/A' if none."""
    lineas = []
    vistos = set()
    for doc in fuentes:
        source = os.path.basename(doc.metadata.get('source', 'N/A'))
        page = doc.metadata.get('page', -1)
        # Page metadata is 0-based when present; display it 1-based.
        page_display = page + 1 if page != -1 else 'N/A'
        identificador = f"Fuente: {source}, Pág: {page_display}"
        if identificador not in vistos:
            lineas.append(f"- {corregir_acentos(identificador)}")
            vistos.add(identificador)
    return "\n".join(lineas) if lineas else "N/A"


def responder_pregunta_con_id(pregunta: str, qa_chain: RetrievalQA, current_request_id: str) -> tuple[str, str, Optional[int]]:
    """Run the RAG chain for *pregunta* and persist the assistant's turn.

    Args:
        pregunta: user question (empty input short-circuits with a hint message).
        qa_chain: initialized RetrievalQA chain.
        current_request_id: UUID correlating DB rows and logs for this request.

    Returns:
        (answer_text, formatted_sources, assistant_chat_interaction_id) — the id
        is None when nothing was persisted (empty question, missing chain, DB error).
    """
    if not pregunta:
        return "Por favor, introduce una pregunta.", "", None
    if not qa_chain:
        return "Error: La cadena RAG no está inicializada.", "", None
    logging.info(f"[Req ID: {current_request_id}] Procesando pregunta: '{pregunta}'")
    sql_callback_handler = SQLiteLoggingCallbackHandler(gradio_request_id=current_request_id)
    try:
        respuesta = qa_chain.invoke({"query": pregunta}, config={"callbacks": [sql_callback_handler]})
        resultado_corregido = corregir_acentos(respuesta.get('result', "No se pudo obtener respuesta."))
        fuentes = respuesta.get('source_documents', [])
        logging.info(f"[Req ID: {current_request_id}] Respuesta generada: '{resultado_corregido[:100]}...'")
        assistant_chat_id = _registrar_mensaje_asistente(current_request_id, resultado_corregido)
        if assistant_chat_id is not None:
            logging.debug(f"Chat interaction (assistant) logged for {current_request_id}, ID: {assistant_chat_id}")
        return resultado_corregido, _formatear_fuentes(fuentes), assistant_chat_id
    except Exception as e:
        logging.error(f"[Req ID: {current_request_id}] Error en RAG: {e}", exc_info=True)
        # Show the user only the exception type plus the request id, never internals.
        error_msg_for_user = corregir_acentos(f"Ocurrió un error (ID: {current_request_id}): {type(e).__name__}")
        error_assistant_chat_id = _registrar_mensaje_asistente(current_request_id, f"ERROR: {str(e)}")
        return error_msg_for_user, "", error_assistant_chat_id
def crear_interfaz_gradio(qa_chain):
    """Build the Gradio Blocks UI: question input, answer/sources panels, and a
    like/dislike + comment feedback section bound to the last assistant reply.

    Args:
        qa_chain: initialized RetrievalQA chain used to answer questions.

    Returns:
        The assembled gr.Blocks interface (not yet launched).
    """
    logging.info("Creando la interfaz de Gradio...")

    def chatbot_response_wrapper(user_input: str, current_state_value_from_gradio: dict, request: gr.Request):
        """Handle a new question: log the user turn, run RAG, reset feedback UI/state."""
        current_request_id = str(uuid.uuid4())
        log_chat_interaction(gradio_request_id=current_request_id, role="user", content=user_input)
        respuesta_texto, fuentes_texto, assistant_chat_id = responder_pregunta_con_id(user_input, qa_chain, current_request_id)
        # Fresh feedback state for this interaction.
        new_state_to_store = {
            "chat_interaction_id": assistant_chat_id,
            "gradio_request_id": current_request_id,
            "last_feedback_type": None  # Reset feedback type for the new interaction
        }
        return (
            respuesta_texto,
            fuentes_texto,
            gr.update(visible=True),  # feedback_container
            gr.update(value=""),      # feedback_comment_textbox (clear for new question)
            "",                       # feedback_status_text (clear for new question)
            new_state_to_store        # last_interaction_state
        )

    def handle_feedback(feedback_type_from_button: Optional[str],
                        comment: str,
                        current_interaction_state_value: dict):
        """Persist like/dislike and/or comment feedback for the last assistant reply.

        Returns exactly (status_message, updated_state) — matching the two
        outputs wired in every .click() call below.
        """
        chat_interaction_id = current_interaction_state_value.get("chat_interaction_id")
        gradio_request_id = current_interaction_state_value.get("gradio_request_id")
        # A like/dislike button press wins; a bare comment submission reuses the
        # last recorded feedback type (which may be None — the schema allows it).
        actual_feedback_type_to_log = feedback_type_from_button or current_interaction_state_value.get("last_feedback_type")
        if not chat_interaction_id or not gradio_request_id:
            logging.warning("No hay ID de interacción para registrar feedback.")
            # BUGFIX: this path previously returned 3 values (comment, message,
            # state) while only 2 outputs are wired, so Gradio failed at runtime.
            return "Error: No se pudo registrar el feedback (sin ID de interacción).", current_interaction_state_value
        log_user_feedback(
            chat_interaction_id=chat_interaction_id,
            gradio_request_id=gradio_request_id,
            feedback_type=actual_feedback_type_to_log,  # May be None (comment-only feedback)
            comment=comment  # log_user_feedback handles stripping empty strings
        )
        feedback_message_parts = []
        if actual_feedback_type_to_log:  # Mention the rating only if one exists
            feedback_message_parts.append(f"Valoración ('{actual_feedback_type_to_log}')")
        if comment and comment.strip():
            feedback_message_parts.append("Comentario guardado.")
        if not feedback_message_parts:  # e.g. comment button pressed with empty comment and no prior rating
            feedback_message = "Feedback procesado."
        else:
            feedback_message = "¡Gracias! " + " y ".join(feedback_message_parts) + "."
        # Remember the pressed button so later comment-only submissions keep the rating.
        updated_state = current_interaction_state_value.copy()
        if feedback_type_from_button:
            updated_state["last_feedback_type"] = feedback_type_from_button
        # The comment textbox is intentionally NOT cleared here — only a new question clears it.
        return feedback_message, updated_state

    with gr.Blocks(theme=gr.themes.Soft()) as interface:
        # State carries: chat_interaction_id (assistant row id), gradio_request_id, last_feedback_type.
        last_interaction_state = gr.State(
            value={"chat_interaction_id": None, "gradio_request_id": None, "last_feedback_type": None}
        )
        gr.Markdown("# 📄Agente de AI con publicaciones de JW")
        gr.Markdown(f"_(Busca en {len(config_ui.TARGET_COLLECTION_NAMES)} colecciones. Concurrencia Gradio: {config_ui.GRADIO_CONCURRENCY_COUNT})_")
        with gr.Row():
            with gr.Column(scale=2):
                pregunta_input = gr.Textbox(lines=4, placeholder="Escribe tu pregunta sobre los documentos...", label="Tu Pregunta")
                submit_btn = gr.Button("🔍 Enviar", variant="primary")
            with gr.Column(scale=3):
                respuesta_output = gr.Textbox(lines=12, label="Respuesta", interactive=False)
                fuentes_output = gr.Textbox(lines=5, label="Fuentes (Doc, Pág)", interactive=False)
        # Hidden until the first answer arrives.
        with gr.Column(visible=False) as feedback_container:
            gr.Markdown("#### ¿Te fue útil esta respuesta?")
            with gr.Row():
                like_btn = gr.Button("👍 Me sirvió", variant="secondary", scale=1)
                dislike_btn = gr.Button("👎 No me sirvió", variant="secondary", scale=1)
            feedback_comment_textbox = gr.Textbox(
                lines=2,
                placeholder="Opcional: ¿Por qué o qué añadirías?",
                label="Comentario Adicional (pulsa 'Enviar Comentario' o 👍/👎 para guardar)",
                interactive=True
            )
            submit_comment_btn = gr.Button("💬 Enviar/Actualizar Comentario", variant="secondary")
            feedback_status_text = gr.Markdown("")
        submit_btn.click(
            chatbot_response_wrapper,
            inputs=[pregunta_input, last_interaction_state],
            outputs=[
                respuesta_output,
                fuentes_output,
                feedback_container,
                feedback_comment_textbox,
                feedback_status_text,
                last_interaction_state
            ]
        )
        # Like, dislike and comment buttons all update feedback_status_text and
        # last_interaction_state; the comment textbox is cleared only by a new question.
        like_btn.click(
            lambda comment_text, current_state_val: handle_feedback("like", comment_text, current_state_val),
            inputs=[feedback_comment_textbox, last_interaction_state],
            outputs=[feedback_status_text, last_interaction_state]
        )
        dislike_btn.click(
            lambda comment_text, current_state_val: handle_feedback("dislike", comment_text, current_state_val),
            inputs=[feedback_comment_textbox, last_interaction_state],
            outputs=[feedback_status_text, last_interaction_state]
        )
        submit_comment_btn.click(
            # Pass None as feedback_type_from_button; handle_feedback falls back to last_feedback_type.
            lambda comment_text, current_state_val: handle_feedback(None, comment_text, current_state_val),
            inputs=[feedback_comment_textbox, last_interaction_state],
            outputs=[feedback_status_text, last_interaction_state]
        )
    logging.info("Interfaz de Gradio creada.")
    return interface
# --- Función Principal de Ejecución ---
def main():
    """Wire up all components (API key, prompt, embeddings, retriever, LLM, chain)
    and launch the Gradio interface. Exits with code 1 if any piece fails.
    """
    # db_logger creates its tables on import, so no explicit create_tables() call here.
    google_api_key = cargar_api_key()
    prompt_template_str = cargar_prompt_template_desde_archivo(config_ui.PROMPT_TEMPLATE_FILE)
    embeddings = inicializar_modelo_embeddings_local(config_ui.EMBEDDING_MODEL_NAME)
    parallel_retriever = crear_parallel_ensemble_retriever(
        embeddings_model=embeddings,
        main_persist_directory=config_ui.CHROMA_PERSIST_DIR,
        biblia_persist_directory=config_ui.BIBLIA_CHROMA_PERSIST_DIR,
        collection_names=config_ui.TARGET_COLLECTION_NAMES,
        k_per_collection=config_ui.SEARCH_K_PER_COLLECTION,
        max_workers_retriever=config_ui.RETRIEVER_MAX_WORKERS,
    )
    if not parallel_retriever:
        sys.exit(1)
    llm = inicializar_llm(
        api_key=google_api_key,
        model_name=config_ui.LLM_MODEL_NAME,
        temperature=config_ui.TEMPERATURE,
    )
    qa_chain = inicializar_cadena_rag(llm, parallel_retriever, prompt_template_str)
    if not qa_chain:
        sys.exit(1)
    interface = crear_interfaz_gradio(qa_chain)
    logging.info(f"Lanzando interfaz Gradio con concurrencia: {config_ui.GRADIO_CONCURRENCY_COUNT}")
    interface.queue(default_concurrency_limit=config_ui.GRADIO_CONCURRENCY_COUNT).launch(
        share=True, server_name="0.0.0.0"
    )


if __name__ == "__main__":
    main()