@@ -89,23 +89,23 @@ def init(backend, **kwargs):
 async def startup_event():
     """Event that is triggered when the application starts up"""
     global startup_event_triggered
-
+
     # Only log once
     if startup_event_triggered:
         return
-
+
     logger.info("FastAPI application startup event triggered")
     startup_event_triggered = True
-
+
     # Wait a short time to ensure logs are processed
     await asyncio.sleep(0.5)
-
+
     # Log a special message that our callback handler will detect
     root_logger = logging.getLogger()
     root_logger.info("Application startup complete - banner display trigger")
-
+
     logger.info(f"{Fore.CYAN}Starting LocalLab server...{Style.RESET_ALL}")
-
+
     # Get HuggingFace token and set it in environment if available
     from ..config import get_hf_token
     hf_token = get_hf_token(interactive=False)
@@ -114,43 +114,43 @@ async def startup_event():
         logger.info(f"{Fore.GREEN}HuggingFace token loaded from configuration{Style.RESET_ALL}")
     else:
         logger.warning(f"{Fore.YELLOW}No HuggingFace token found. Some models may not be accessible.{Style.RESET_ALL}")
-
+
     # Check if ngrok should be enabled
     from ..cli.config import get_config_value
     use_ngrok = get_config_value("use_ngrok", False)
     if use_ngrok:
         from ..utils.networking import setup_ngrok
         port = int(os.environ.get("LOCALLAB_PORT", SERVER_PORT))  # Use SERVER_PORT as fallback
-
+
         # Handle ngrok setup synchronously since it's not async
         ngrok_url = setup_ngrok(port)
         if ngrok_url:
             logger.info(f"{Fore.GREEN}Ngrok tunnel established successfully{Style.RESET_ALL}")
         else:
             logger.warning("Failed to establish ngrok tunnel. Server will run locally only.")
-
+
     # Initialize cache if available
     if FASTAPI_CACHE_AVAILABLE:
         FastAPICache.init(InMemoryBackend(), prefix="locallab-cache")
         logger.info("FastAPICache initialized")
     else:
         logger.warning("FastAPICache not available, caching disabled")
-
+
     # Check for model specified in environment variables or CLI config
     model_to_load = (
-        os.environ.get("HUGGINGFACE_MODEL") or
-        get_config_value("model_id") or
+        os.environ.get("HUGGINGFACE_MODEL") or
+        get_config_value("model_id") or
         DEFAULT_MODEL
     )
-
+
     # Log model configuration
     logger.info(f"{Fore.CYAN}Model configuration:{Style.RESET_ALL}")
     logger.info(f" - Model to load: {model_to_load}")
     logger.info(f" - Quantization: {'Enabled - ' + os.environ.get('LOCALLAB_QUANTIZATION_TYPE', QUANTIZATION_TYPE) if os.environ.get('LOCALLAB_ENABLE_QUANTIZATION', '').lower() == 'true' else 'Disabled'}")
     logger.info(f" - Attention slicing: {'Enabled' if os.environ.get('LOCALLAB_ENABLE_ATTENTION_SLICING', '').lower() == 'true' else 'Disabled'}")
     logger.info(f" - Flash attention: {'Enabled' if os.environ.get('LOCALLAB_ENABLE_FLASH_ATTENTION', '').lower() == 'true' else 'Disabled'}")
     logger.info(f" - Better transformer: {'Enabled' if os.environ.get('LOCALLAB_ENABLE_BETTERTRANSFORMER', '').lower() == 'true' else 'Disabled'}")
-
+
     # Start loading the model in background if specified
     if model_to_load:
         try:
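
Context for the hunk above: the model is picked by chaining three sources with `or`, so the first truthy value wins, namely the HUGGINGFACE_MODEL environment variable, then the saved CLI config, then the package default. A minimal standalone sketch of that precedence (the `get_config_value` stub and default model name below are illustrative placeholders, not the project's real config module):

    import os

    # Hypothetical stand-ins for LocalLab's config helpers (illustrative only)
    DEFAULT_MODEL = "microsoft/phi-2"       # assumed default, not necessarily the real one
    _saved_cli_config = {"model_id": None}  # pretend nothing was configured via the CLI

    def get_config_value(key, default=None):
        # In the real code this reads the persisted CLI config; here it is a dict lookup.
        return _saved_cli_config.get(key, default)

    # First truthy source wins; None and empty strings fall through to the next option.
    model_to_load = (
        os.environ.get("HUGGINGFACE_MODEL")
        or get_config_value("model_id")
        or DEFAULT_MODEL
    )
    print(model_to_load)

Setting HUGGINGFACE_MODEL in the environment therefore overrides both the CLI config value and the default.
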
@@ -166,66 +166,89 @@ async def startup_event():
 async def shutdown_event():
     """Cleanup tasks when the server shuts down"""
     logger.info(f"{Fore.YELLOW}Shutting down server...{Style.RESET_ALL}")
-
+
     # Unload model to free GPU memory
     try:
         # Get current model ID before unloading
         current_model = model_manager.current_model
-
+
         # Unload the model
         if hasattr(model_manager, 'unload_model'):
             model_manager.unload_model()
         else:
             # Fallback if unload_model method doesn't exist
             model_manager.model = None
             model_manager.current_model = None
-
+
         # Clean up memory
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
         gc.collect()
-
+
         # Log model unloading
         if current_model:
             log_model_unloaded(current_model)
-
+
         logger.info("Model unloaded and memory freed")
     except Exception as e:
         logger.error(f"Error during shutdown cleanup: {str(e)}")
-
+
     # Clean up any pending tasks
     try:
-        tasks = [t for t in asyncio.all_tasks()
-                 if t is not asyncio.current_task() and not t.done()]
+        # Get all tasks except the current one
+        current_task = asyncio.current_task()
+        tasks = [t for t in asyncio.all_tasks()
+                 if t is not current_task and not t.done()]
+
         if tasks:
             logger.debug(f"Cancelling {len(tasks)} remaining tasks")
+
+            # Cancel all tasks
             for task in tasks:
                 task.cancel()
-            await asyncio.gather(*tasks, return_exceptions=True)
+
+            # Wait for tasks to complete with a timeout
+            try:
+                # Use wait_for with a timeout to avoid hanging
+                await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=3.0)
+                logger.debug("All tasks cancelled successfully")
+            except asyncio.TimeoutError:
+                logger.warning("Timeout waiting for tasks to cancel")
+            except asyncio.CancelledError:
+                # This is expected during shutdown
+                logger.debug("Task cancellation was itself cancelled - this is normal during shutdown")
+            except Exception as e:
+                logger.warning(f"Error during task cancellation: {str(e)}")
     except Exception as e:
         logger.warning(f"Error cleaning up tasks: {str(e)}")
-
+
     # Set server status to stopped
     set_server_status("stopped")
-
+
     logger.info(f"{Fore.GREEN}Server shutdown complete{Style.RESET_ALL}")
-
+
     # Only force exit if this is a true shutdown initiated by SIGINT/SIGTERM
     # Check if this was triggered by an actual signal
     if hasattr(shutdown_event, 'force_exit_required') and shutdown_event.force_exit_required:
         import threading
         def force_exit():
             import time
             import os
-            import signal
             time.sleep(3)  # Give a little time for clean shutdown
-            logger.info("Forcing exit after shutdown to ensure clean termination")
-            try:
-                os._exit(0)  # Direct exit instead of sending another signal
-            except:
-                pass
-
-        threading.Thread(target=force_exit, daemon=True).start()
+
+            # Check if we need to force exit
+            if hasattr(shutdown_event, 'force_exit_required') and shutdown_event.force_exit_required:
+                logger.info("Forcing exit after shutdown to ensure clean termination")
+                try:
+                    # Reset the flag to avoid multiple exit attempts
+                    shutdown_event.force_exit_required = False
+                    os._exit(0)  # Direct exit instead of sending another signal
+                except:
+                    pass
+
+        # Start a daemon thread that will force exit if needed
+        exit_thread = threading.Thread(target=force_exit, daemon=True)
+        exit_thread.start()
 
 # Initialize the flag (default to not forcing exit)
 shutdown_event.force_exit_required = False
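
The main change in this hunk is the task cleanup: instead of awaiting `asyncio.gather` with no bound, pending tasks are cancelled and the wait is capped with `asyncio.wait_for`, so a task that is slow to honour cancellation can no longer hang shutdown. A condensed, runnable sketch of the same pattern outside the server (the `slow_job` coroutine and the 3-second timeout are illustrative, not part of the project):

    import asyncio

    async def slow_job():
        # Simulates a background task that would otherwise outlive shutdown
        try:
            await asyncio.sleep(60)
        except asyncio.CancelledError:
            await asyncio.sleep(0.1)  # brief cleanup before giving up
            raise

    async def shutdown_demo():
        asyncio.create_task(slow_job())
        await asyncio.sleep(0)  # give the task a chance to start

        # Collect everything except the task running this coroutine
        current_task = asyncio.current_task()
        tasks = [t for t in asyncio.all_tasks() if t is not current_task and not t.done()]

        for task in tasks:
            task.cancel()

        try:
            # Bound the wait so a misbehaving task cannot stall shutdown forever
            await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=3.0)
            print("all tasks cancelled")
        except asyncio.TimeoutError:
            print("timed out waiting for tasks to cancel")

    asyncio.run(shutdown_demo())
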
@@ -234,40 +257,40 @@ def force_exit():
 async def add_process_time_header(request: Request, call_next):
     """Middleware to track request processing time"""
     start_time = time.time()
-
+
     # Extract path and some basic params for logging
     path = request.url.path
     method = request.method
     client = request.client.host if request.client else "unknown"
-
+
     # Skip detailed logging for health check endpoints to reduce noise
     is_health_check = path.endswith("/health") or path.endswith("/startup-status")
-
+
     if not is_health_check:
         log_request(f"{method} {path}", {"client": client})
-
+
     # Process the request
     response = await call_next(request)
-
+
     # Calculate processing time
     process_time = time.time() - start_time
     response.headers["X-Process-Time"] = f"{process_time:.4f}"
-
+
     # Add request stats to response headers
     response.headers["X-Request-Count"] = str(get_request_count())
-
+
     # Log slow requests for performance monitoring (if not a health check)
     if process_time > 1.0 and not is_health_check:
         logger.warning(f"Slow request: {method} {path} took {process_time:.2f}s")
-
+
     return response
 
 
 async def load_model_in_background(model_id: str):
     """Load the model asynchronously in the background"""
     logger.info(f"Loading model {model_id} in background...")
     start_time = time.time()
-
+
     try:
         # Ensure HF token is set before loading model
         from ..config import get_hf_token
@@ -277,13 +300,13 @@ async def load_model_in_background(model_id: str):
             logger.debug("Using HuggingFace token from configuration")
         else:
             logger.warning("No HuggingFace token found. Some models may not be accessible.")
-
+
         # Wait for the model to load
         await model_manager.load_model(model_id)
-
+
         # Calculate load time
         load_time = time.time() - start_time
-
+
         # We don't need to call log_model_loaded here since it's already done in the model_manager
         logger.info(f"{Fore.GREEN}Model {model_id} loaded successfully in {load_time:.2f} seconds!{Style.RESET_ALL}")
     except Exception as e:
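
For completeness, the middleware touched in the last two hunks times each request around `call_next` and reports the duration in an `X-Process-Time` header, warning when a request takes longer than one second. A stripped-down sketch of that pattern on a bare FastAPI app (the app, the `/health` route, and the plain `print` logging are illustrative, not LocalLab's actual setup):

    import time
    from fastapi import FastAPI, Request

    app = FastAPI()

    @app.middleware("http")
    async def add_process_time_header(request: Request, call_next):
        # Time the downstream handler and expose the duration as a response header
        start_time = time.time()
        response = await call_next(request)
        process_time = time.time() - start_time
        response.headers["X-Process-Time"] = f"{process_time:.4f}"
        # Flag unusually slow requests so they stand out
        if process_time > 1.0:
            print(f"Slow request: {request.method} {request.url.path} took {process_time:.2f}s")
        return response

    @app.get("/health")
    async def health():
        return {"status": "ok"}

Assuming the sketch is saved as example.py and served with `uvicorn example:app`, a `curl -i http://localhost:8000/health` response should include the `X-Process-Time` header.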