88import gc
99import torch
1010import os
11- from fastapi import FastAPI , Request
11+ from fastapi import FastAPI , Request , HTTPException , Depends , BackgroundTasks
1212from fastapi .middleware .cors import CORSMiddleware
1313from contextlib import contextmanager
1414from colorama import Fore , Style
15+ from typing import Optional , Dict , Any , List
1516
1617# Try to import FastAPICache, but don't fail if not available
1718try :
@@ -29,7 +30,7 @@ def init(backend, **kwargs):
2930
3031from .. import __version__
3132from ..logger import get_logger
32- from ..logger .logger import log_request , log_model_loaded , log_model_unloaded , get_request_count
33+ from ..logger .logger import log_request , log_model_loaded , log_model_unloaded , get_request_count , set_server_status
3334from ..model_manager import ModelManager
3435from ..config import (
3536 ENABLE_CORS ,
@@ -38,8 +39,11 @@ def init(backend, **kwargs):
3839 ENABLE_COMPRESSION ,
3940 QUANTIZATION_TYPE ,
4041 SERVER_PORT ,
42+ DEFAULT_MAX_LENGTH ,
43+ get_env_var ,
4144)
4245from ..cli .config import get_config_value
46+ from ..utils .system import get_system_resources
4347
4448# Get the logger
4549logger = get_logger ("locallab.app" )
@@ -77,10 +81,29 @@ def init(backend, **kwargs):
7781app .include_router (generate_router )
7882app .include_router (system_router )
7983
84+ # Startup event triggered flag
85+ startup_event_triggered = False
8086
87+ # Application startup event to ensure banners are displayed
8188@app .on_event ("startup" )
8289async def startup_event ():
83- """Initialization tasks when the server starts"""
90+ """Event that is triggered when the application starts up"""
91+ global startup_event_triggered
92+
93+ # Only log once
94+ if startup_event_triggered :
95+ return
96+
97+ logger .info ("FastAPI application startup event triggered" )
98+ startup_event_triggered = True
99+
100+ # Wait a short time to ensure logs are processed
101+ await asyncio .sleep (0.5 )
102+
103+ # Log a special message that our callback handler will detect
104+ root_logger = logging .getLogger ()
105+ root_logger .info ("Application startup complete - banner display trigger" )
106+
84107 logger .info (f"{ Fore .CYAN } Starting LocalLab server...{ Style .RESET_ALL } " )
85108
86109 # Get HuggingFace token and set it in environment if available
@@ -158,7 +181,8 @@ async def shutdown_event():
158181 model_manager .current_model = None
159182
160183 # Clean up memory
161- torch .cuda .empty_cache ()
184+ if torch .cuda .is_available ():
185+ torch .cuda .empty_cache ()
162186 gc .collect ()
163187
164188 # Log model unloading
@@ -169,7 +193,37 @@ async def shutdown_event():
169193 except Exception as e :
170194 logger .error (f"Error during shutdown cleanup: { str (e )} " )
171195
196+ # Clean up any pending tasks
197+ try :
198+ tasks = [t for t in asyncio .all_tasks ()
199+ if t is not asyncio .current_task () and not t .done ()]
200+ if tasks :
201+ logger .debug (f"Cancelling { len (tasks )} remaining tasks" )
202+ for task in tasks :
203+ task .cancel ()
204+ await asyncio .gather (* tasks , return_exceptions = True )
205+ except Exception as e :
206+ logger .warning (f"Error cleaning up tasks: { str (e )} " )
207+
208+ # Set server status to stopped
209+ set_server_status ("stopped" )
210+
172211 logger .info (f"{ Fore .GREEN } Server shutdown complete{ Style .RESET_ALL } " )
212+
213+ # Force exit if needed to clean up any hanging resources
214+ import threading
def force_exit():
    """Force the process to terminate shortly after shutdown has been logged.

    Intended as a thread target: waits 3 seconds, logs the forced exit, then
    sends SIGTERM to the current process. If sending the signal fails, the
    process is ended immediately via ``os._exit(0)``.
    """
    import time
    import os
    import signal

    time.sleep(3)  # Give a little time for clean shutdown
    logger.info("Forcing exit after shutdown to ensure clean termination")
    try:
        os.kill(os.getpid(), signal.SIGTERM)
    except Exception:
        # Best-effort fallback when the signal cannot be delivered. Catch
        # Exception rather than using a bare except so SystemExit and
        # KeyboardInterrupt are not swallowed here.
        os._exit(0)
225+
226+ threading .Thread (target = force_exit , daemon = True ).start ()
173227
174228
175229@app .middleware ("http" )
0 commit comments