@@ -205,11 +205,23 @@ def __init__(self, config):
205205 self ._serve_task = None
206206 self ._socket = None
207207 self .app = config .app
208+ self .callback_triggered = False
208209
209210 async def start (self ):
210211 self .started = True
211212 logger .info ("Started SimpleTCPServer as fallback" )
212213
214+ # Trigger callback if server has one and it hasn't been triggered yet
215+ if (hasattr (self , 'server' ) and self .server and
216+ hasattr (self .server , 'on_startup_callback' ) and
217+ not self .callback_triggered and
218+ not (hasattr (self .server , 'callback_triggered' ) and self .server .callback_triggered )):
219+ logger .info ("Executing startup callback from SimpleTCPServer.start" )
220+ self .server .on_startup_callback ()
221+ self .callback_triggered = True
222+ if hasattr (self .server , 'callback_triggered' ):
223+ self .server .callback_triggered = True
224+
213225 if not self ._serve_task :
214226 self ._serve_task = asyncio .create_task (self ._run_server ())
215227
@@ -428,6 +440,7 @@ def __init__(self, config):
428440 super ().__init__ (config )
429441 self .servers = [] # Initialize servers list
430442 self .should_exit = False
443+ self .callback_triggered = False # Flag to track if callback has been triggered
431444
432445 def install_signal_handlers (self ):
433446 def handle_exit (signum , frame ):
@@ -444,21 +457,43 @@ async def startup(self, sockets=None):
444457 try :
445458 await super ().startup (sockets = sockets )
446459 logger .info ("Using uvicorn's built-in Server implementation" )
460+
461+ # Execute callback after successful startup
462+ # This is critical to show the running banner
463+ if hasattr (self , 'on_startup_callback' ) and not self .callback_triggered :
464+ logger .info ("Executing server startup callback" )
465+ self .on_startup_callback ()
466+ self .callback_triggered = True
467+
447468 except Exception as e :
448469 logger .error (f"Error during server startup: { str (e )} " )
449470 logger .debug (f"Server startup error details: { traceback .format_exc ()} " )
450471 self .servers = []
472+
473+ # Create SimpleTCPServer as fallback
451474 if sockets :
452475 for socket in sockets :
453476 server = SimpleTCPServer (config = self .config )
454477 server .server = self
455478 await server .start ()
456479 self .servers .append (server )
480+
481+ # Make sure callback is executed for the fallback server too
482+ if hasattr (self , 'on_startup_callback' ) and not self .callback_triggered :
483+ logger .info ("Executing server startup callback (fallback server)" )
484+ self .on_startup_callback ()
485+ self .callback_triggered = True
457486 else :
458487 server = SimpleTCPServer (config = self .config )
459488 server .server = self
460489 await server .start ()
461490 self .servers .append (server )
491+
492+ # Make sure callback is executed for the fallback server too
493+ if hasattr (self , 'on_startup_callback' ) and not self .callback_triggered :
494+ logger .info ("Executing server startup callback (fallback server)" )
495+ self .on_startup_callback ()
496+ self .callback_triggered = True
462497
463498 async def shutdown (self , sockets = None ):
464499 logger .debug ("Starting server shutdown process" )
@@ -562,15 +597,17 @@ def start_server(use_ngrok: bool = None, port: int = None, ngrok_auth_token: Opt
562597 logger .error (f"{ Fore .YELLOW } Please ensure all dependencies are installed: pip install -e .{ Style .RESET_ALL } " )
563598 raise
564599
565- # Create a function to display the Running banner when the server is ready
566- startup_complete = False # Flag to track if startup has been completed
600+ # Flag to track if startup has been completed
601+ startup_complete = [ False ] # Using a list as a mutable reference
567602
568603 def on_startup ():
569- nonlocal startup_complete
570- if startup_complete :
604+ # Use the mutable reference to track startup
605+ if startup_complete [ 0 ] :
571606 return
572607
573608 try :
609+ logger .info ("Server startup callback triggered" )
610+
574611 # Set server status to running
575612 set_server_status ("running" )
576613
@@ -614,12 +651,14 @@ def on_startup():
614651 logger .debug (f"Footer display error details: { traceback .format_exc ()} " )
615652
616653 # Set flag to indicate startup is complete
617- startup_complete = True
654+ startup_complete [0 ] = True
655+ logger .info ("Server startup display completed successfully" )
656+
618657 except Exception as e :
619658 logger .error (f"Error during server startup display: { str (e )} " )
620659 logger .debug (f"Startup display error details: { traceback .format_exc ()} " )
621660 # Still mark startup as complete to avoid repeated attempts
622- startup_complete = True
661+ startup_complete [ 0 ] = True
623662 # Ensure server status is set to running even if display fails
624663 set_server_status ("running" )
625664
@@ -639,21 +678,21 @@ def on_startup():
639678
640679 # Define the callback for Colab
641680 async def on_startup_async ():
642- # This will only run once due to the flag in on_startup
643- on_startup ()
681+ # This is an async callback that uvicorn might call
682+ if not startup_complete [0 ]:
683+ on_startup ()
644684
645685 config = uvicorn .Config (
646686 app ,
647687 host = "0.0.0.0" , # Bind to all interfaces in Colab
648688 port = port ,
649689 reload = False ,
650690 log_level = "info" ,
651- # Use an async callback function, not a list
652- callback_notify = on_startup_async
691+ callback_notify = [on_startup_async ] # Use a list for the callback
653692 )
654693
655694 server = ServerWithCallback (config )
656- server .on_startup_callback = on_startup # Set the callback
695+ server .on_startup_callback = on_startup # Also set the direct callback
657696
658697 # Use the appropriate event loop method based on Python version
659698 try :
@@ -664,7 +703,8 @@ async def on_startup_async():
664703 if "'Server' object has no attribute 'start'" in str (e ):
665704 # If we get the 'start' attribute error, use our SimpleTCPServer directly
666705 logger .warning ("Falling back to direct SimpleTCPServer implementation" )
667- direct_server = SimpleTCPServer (config = self .config )
706+ direct_server = SimpleTCPServer (config = config ) # Pass the config directly
707+ direct_server .server = server # Set reference to the server for callbacks
668708 asyncio .run (direct_server .serve ())
669709 else :
670710 raise
@@ -679,7 +719,8 @@ async def on_startup_async():
679719 if "'Server' object has no attribute 'start'" in str (e ):
680720 # If we get the 'start' attribute error, use our SimpleTCPServer directly
681721 logger .warning ("Falling back to direct SimpleTCPServer implementation" )
682- direct_server = SimpleTCPServer (config = self .config )
722+ direct_server = SimpleTCPServer (config = config ) # Pass the config directly
723+ direct_server .server = server # Set reference to the server for callbacks
683724 loop .run_until_complete (direct_server .serve ())
684725 else :
685726 raise
@@ -698,12 +739,11 @@ async def on_startup_async():
698739 reload = False ,
699740 workers = 1 ,
700741 log_level = "info" ,
701- # This won't be used directly, as we call on_startup in the ServerWithCallback class
702- callback_notify = None
742+ callback_notify = [lambda : on_startup ()] # Use a lambda to prevent immediate execution
703743 )
704744
705745 server = ServerWithCallback (config )
706- server .on_startup_callback = on_startup # Set the callback
746+ server .on_startup_callback = on_startup # Set the callback directly
707747
708748 # Use asyncio.run which is more reliable
709749 try :
@@ -714,7 +754,8 @@ async def on_startup_async():
714754 if "'Server' object has no attribute 'start'" in str (e ):
715755 # If we get the 'start' attribute error, use our SimpleTCPServer directly
716756 logger .warning ("Falling back to direct SimpleTCPServer implementation" )
717- direct_server = SimpleTCPServer (config = self .config )
757+ direct_server = SimpleTCPServer (config = config ) # Pass the config directly
758+ direct_server .server = server # Set reference to the server for callbacks
718759 asyncio .run (direct_server .serve ())
719760 else :
720761 raise
@@ -729,13 +770,20 @@ async def on_startup_async():
729770 if "'Server' object has no attribute 'start'" in str (e ):
730771 # If we get the 'start' attribute error, use our SimpleTCPServer directly
731772 logger .warning ("Falling back to direct SimpleTCPServer implementation" )
732- direct_server = SimpleTCPServer (config = self .config )
773+ direct_server = SimpleTCPServer (config = config ) # Pass the config directly
774+ direct_server .server = server # Set reference to the server for callbacks
733775 loop .run_until_complete (direct_server .serve ())
734776 else :
735777 raise
736778 else :
737779 # Re-raise other errors
738780 raise
781+
782+ # If we reach here and startup hasn't completed yet, call it manually as a fallback
783+ if not startup_complete [0 ]:
784+ logger .warning ("Server started but startup callback wasn't triggered. Calling manually..." )
785+ on_startup ()
786+
739787 except Exception as e :
740788 logger .error (f"Server startup failed: { str (e )} " )
741789 logger .error (traceback .format_exc ())
0 commit comments