-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlb.py
More file actions
executable file
·103 lines (73 loc) · 2.8 KB
/
lb.py
File metadata and controls
executable file
·103 lines (73 loc) · 2.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#!python3.12
import argparse
import http.server
import sys
import threading
import time
from http import HTTPStatus
import requests
from lb.utils import set_headers, log_message
servers: dict[str, bool] | None = None
DEFAULT_HEALTHCHECK_INTERVAL: int = 10
DEFAULT_PORT = 8000
class Handler(http.server.SimpleHTTPRequestHandler):
    """Forward GET requests to healthy backends in round-robin order.

    Backends come from the module-level ``servers`` mapping
    (backend URL -> healthy flag); only entries flagged ``True`` are used.

    NOTE(review): only GET is proxied. Other verbs (e.g. HEAD) fall through
    to ``SimpleHTTPRequestHandler``, which serves files from the process's
    working directory -- confirm that is intended.
    """

    # Position (in the ordered list of backends) of the backend picked last.
    # -1 means "none yet", so the very first request goes to the first backend.
    _last_used_server = -1
    # ThreadingHTTPServer runs handlers concurrently; this lock keeps the
    # read-modify-write of ``_last_used_server`` atomic.
    _lock = threading.Lock()
    # Seconds to wait for a backend before answering 502 Bad Gateway.
    backend_timeout: float = 5.0

    def log_message(self, fmt, *args):
        """Delegate request logging to the module-level ``log_message`` helper."""
        # Inside a method body the bare name resolves to the module-level
        # function, not to this method.
        log_message(self, fmt, *args)

    @staticmethod
    def _get_next_server():
        """Return the next healthy backend after the last one used.

        Scans forward from the previously used position and wraps around,
        so every backend (including the last one used) is considered once.

        Raises:
            RuntimeError: if no backend is currently flagged healthy
                (including when there are no backends at all).
        """
        with Handler._lock:
            names = list(servers)
            count = len(names)
            for offset in range(1, count + 1):
                index = (Handler._last_used_server + offset) % count
                if servers[names[index]]:
                    Handler._last_used_server = index
                    return names[index]
        raise RuntimeError("No active backend")

    def do_GET(self):
        """Proxy the request to the next healthy backend and relay its body.

        Answers 503 when no backend is healthy and 502 when the chosen
        backend cannot be reached in time, instead of dropping the
        connection with an unhandled exception.

        NOTE(review): the request path (``self.path``) is not forwarded;
        every request fetches the backend's base URL.
        """
        try:
            server = self._get_next_server()
        except RuntimeError:
            self.send_error(HTTPStatus.SERVICE_UNAVAILABLE, "No active backend")
            return
        print(f"Forwarding request to {server}\n")
        try:
            response = requests.get(server, timeout=self.backend_timeout)
        except requests.RequestException as exc:
            print(f"Request to {server} failed: {exc}\n")
            self.send_error(HTTPStatus.BAD_GATEWAY, f"Backend {server} unavailable")
            return
        # Log the raw code and reason: HTTPStatus(code) would raise
        # ValueError for status codes it does not know.
        print(f"Response from server: {response.status_code} {response.reason}: {response.text}\n")
        set_headers(self)
        self.wfile.write(response.content)
def do_healthcheck(server, timeout: float = 5.0):
    """Probe one backend and record whether it is healthy.

    The backend counts as healthy when a GET on its URL completes within
    *timeout* seconds and ``raise_for_status`` accepts the response (no 4xx
    or 5xx status). The result is written to the module-level ``servers``
    mapping.

    Args:
        server: Backend URL (a key of ``servers``).
        timeout: Seconds to wait for the backend before counting it as down.
            Without a timeout a hung backend would block the whole
            health-check loop indefinitely.
    """
    try:
        response = requests.get(server, timeout=timeout)
        response.raise_for_status()
    except Exception as exc:
        # Deliberately broad: this runs on the health-check thread, and an
        # escaping exception would kill that thread and freeze every
        # backend's status. Any failure simply marks the backend down.
        print(f"{server} failed healthcheck. Exception: {exc}")
        servers[server] = False
    else:
        servers[server] = True
def do_healthchecks(healthcheck_interval):
    """Health-check every backend, sleep, and repeat forever.

    Args:
        healthcheck_interval: Seconds to sleep between rounds of checks.
    """
    while True:
        print("Healthchecking...")
        # A plain loop: the checks run only for their side effects, so there
        # is no point building a throwaway list of None results. list()
        # snapshots the keys before iterating.
        for server in list(servers):
            do_healthcheck(server)
        time.sleep(healthcheck_interval)
def run_healthcheck_thread(healthcheck_interval: int) -> threading.Thread:
    """Start the background health-check loop and return its thread.

    Args:
        healthcheck_interval: Seconds between health-check rounds; must be > 0.

    Returns:
        The started thread.

    Raises:
        ValueError: if healthcheck_interval is not positive.
    """
    # Explicit check instead of ``assert``, which is stripped under ``python -O``.
    if healthcheck_interval <= 0:
        raise ValueError(
            f"healthcheck_interval must be positive, got {healthcheck_interval}"
        )
    # daemon=True: do_healthchecks never returns, so a non-daemon thread
    # would keep the process alive after the HTTP server stops (e.g. Ctrl+C).
    thread = threading.Thread(
        target=do_healthchecks,
        args=(healthcheck_interval,),
        name="healthcheck",
        daemon=True,
    )
    thread.start()
    return thread
def run_server(port: int, host: str = ""):
    """Serve the load balancer on *host*:*port* until interrupted.

    Requests are handled concurrently (one thread per request) by
    ``Handler``. Ctrl+C stops the server cleanly.

    Args:
        port: TCP port to listen on, 1-65535.
        host: Interface address to bind; the default ``""`` binds all
            interfaces.

    Raises:
        ValueError: if port is outside 1-65535.
    """
    # Explicit check instead of ``assert``, which is stripped under ``python -O``.
    if not 0 < port <= 65535:
        raise ValueError(f"port must be in 1-65535, got {port}")
    try:
        with http.server.ThreadingHTTPServer((host, port), Handler) as httpd:
            print(f"Load-balancer server listening on port {port}...")
            httpd.serve_forever()
    except KeyboardInterrupt:
        print("\nServer stopped by user.")
def _positive_int(text):
    """argparse ``type=`` converter: parse *text* as an integer greater than 0."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid integer: {text!r}") from None
    if value <= 0:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def parse_args(args):
    """Parse the load balancer's command-line arguments.

    Non-positive ``--port`` / ``--healthcheck-interval`` values are rejected
    here with a usage error, rather than failing later at startup.

    Args:
        args: Argument strings, e.g. ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``port``, ``healthcheck_interval`` and
        ``servers`` (a non-empty list of backend URLs).
    """
    parser = argparse.ArgumentParser(description="Round-robin HTTP load balancer.")
    parser.add_argument(
        "--port",
        type=_positive_int,
        default=DEFAULT_PORT,
        help="port to listen on (default: %(default)s)",
    )
    parser.add_argument(
        "--healthcheck-interval",
        "-T",
        type=_positive_int,
        default=DEFAULT_HEALTHCHECK_INTERVAL,
        help="seconds between backend health checks (default: %(default)s)",
    )
    parser.add_argument("servers", nargs="+", type=str, help="backend URLs to balance across")
    return parser.parse_args(args)
if __name__ == "__main__":
    cli_args = parse_args(sys.argv[1:])
    # Every backend starts out presumed healthy; the health-check thread
    # corrects the flags on its first round.
    servers = dict.fromkeys(cli_args.servers, True)
    run_healthcheck_thread(cli_args.healthcheck_interval)
    # Blocks until the server is interrupted.
    run_server(cli_args.port)