Skip to content

Commit 36e39d6

Browse files
committed
Implement rate limiting
1 parent 95cc013 commit 36e39d6

File tree

6 files changed

+443
-21
lines changed

6 files changed

+443
-21
lines changed

src/kernelbot/api/api_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
SubmissionRequest,
1919
prepare_submission,
2020
)
21+
from libkernelbot.utils import KernelBotError
2122

2223

2324
async def _handle_discord_oauth(code: str, redirect_uri: str) -> tuple[str, str]:
@@ -147,6 +148,8 @@ async def _run_submission(
147148
):
148149
try:
149150
req = prepare_submission(submission, backend)
151+
except KernelBotError as e:
152+
raise HTTPException(status_code=e.http_code, detail=str(e)) from e
150153
except Exception as e:
151154
raise HTTPException(status_code=400, detail=str(e)) from e
152155

src/kernelbot/api/main.py

Lines changed: 49 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636

3737
app = FastAPI()
3838

39+
3940
def json_serializer(obj):
4041
"""JSON serializer for objects not serializable by default json code"""
4142
if isinstance(obj, (datetime.datetime, datetime.date, datetime.time)):
@@ -255,10 +256,16 @@ async def cli_auth(auth_provider: str, code: str, state: str, db_context=Depends
255256
raise e
256257
except Exception as e:
257258
# Catch unexpected errors during OAuth handling
258-
raise HTTPException(status_code=500, detail=f"Error during {auth_provider} OAuth flow: {e}") from e
259+
raise HTTPException(
260+
status_code=500,
261+
detail=f"Error during {auth_provider} OAuth flow: {e}",
262+
) from e
259263

260264
if not user_id or not user_name:
261-
raise HTTPException(status_code=500,detail="Failed to retrieve user ID or username from provider.",)
265+
raise HTTPException(
266+
status_code=500,
267+
detail="Failed to retrieve user ID or username from provider.",
268+
)
262269

263270
try:
264271
with db_context as db:
@@ -268,7 +275,10 @@ async def cli_auth(auth_provider: str, code: str, state: str, db_context=Depends
268275
db.create_user_from_cli(user_id, user_name, cli_id, auth_provider)
269276

270277
except AttributeError as e:
271-
raise HTTPException(status_code=500, detail=f"Database interface error during update: {e}") from e
278+
raise HTTPException(
279+
status_code=500,
280+
detail=f"Database interface error during update: {e}",
281+
) from e
272282
except Exception as e:
273283
raise HTTPException(status_code=400, detail=f"Database update failed: {e}") from e
274284

@@ -280,6 +290,7 @@ async def cli_auth(auth_provider: str, code: str, state: str, db_context=Depends
280290
"is_reset": is_reset,
281291
}
282292

293+
283294
async def _stream_submission_response(
284295
submission_request: SubmissionRequest,
285296
submission_mode_enum: SubmissionMode,
@@ -298,18 +309,22 @@ async def _stream_submission_response(
298309

299310
while not task.done():
300311
elapsed_time = time.time() - start_time
301-
yield f"event: status\ndata: {json.dumps({'status': 'processing',
302-
'elapsed_time': round(elapsed_time, 2)},
303-
default=json_serializer)}\n\n"
312+
status_data = json.dumps(
313+
{"status": "processing", "elapsed_time": round(elapsed_time, 2)},
314+
default=json_serializer,
315+
)
316+
yield f"event: status\ndata: {status_data}\n\n"
304317

305318
try:
306319
await asyncio.wait_for(asyncio.shield(task), timeout=15.0)
307320
except asyncio.TimeoutError:
308321
continue
309322
except asyncio.CancelledError:
310-
yield f"event: error\ndata: {json.dumps(
311-
{'status': 'error', 'detail': 'Submission cancelled'},
312-
default=json_serializer)}\n\n"
323+
error_data = json.dumps(
324+
{"status": "error", "detail": "Submission cancelled"},
325+
default=json_serializer,
326+
)
327+
yield f"event: error\ndata: {error_data}\n\n"
313328
return
314329

315330
result, reports = await task
@@ -343,6 +358,7 @@ async def _stream_submission_response(
343358
except asyncio.CancelledError:
344359
pass
345360

361+
346362
@app.post("/{leaderboard_name}/{gpu_type}/{submission_mode}")
347363
async def run_submission( # noqa: C901
348364
leaderboard_name: str,
@@ -381,13 +397,13 @@ async def run_submission( # noqa: C901
381397
)
382398
return StreamingResponse(generator, media_type="text/event-stream")
383399

400+
384401
async def enqueue_background_job(
385402
req: ProcessedSubmissionRequest,
386403
mode: SubmissionMode,
387404
backend: KernelBackend,
388405
manager: BackgroundSubmissionManager,
389406
):
390-
391407
# pre-create the submission for api returns
392408
with backend.db as db:
393409
sub_id = db.create_submission(
@@ -401,7 +417,8 @@ async def enqueue_background_job(
401417
job_id = db.upsert_submission_job_status(sub_id, "initial", None)
402418
# put submission request in queue
403419
await manager.enqueue(req, mode, sub_id)
404-
return sub_id,job_id
420+
return sub_id, job_id
421+
405422

406423
@app.post("/submission/{leaderboard_name}/{gpu_type}/{submission_mode}")
407424
async def run_submission_async(
@@ -425,37 +442,49 @@ async def run_submission_async(
425442
Raises:
426443
HTTPException: If the kernelbot is not initialized, or header/input is invalid.
427444
Returns:
428-
JSONResponse: A JSON response containing job_id and and submission_id for the client to poll for status.
445+
JSONResponse: A JSON response containing job_id and submission_id.
446+
The client can poll for status using these ids.
429447
"""
430448
try:
431-
432449
await simple_rate_limit()
433-
logger.info(f"Received submission request for {leaderboard_name} {gpu_type} {submission_mode}")
434-
450+
logger.info(
451+
"Received submission request for %s %s %s",
452+
leaderboard_name,
453+
gpu_type,
454+
submission_mode,
455+
)
435456

436457
# throw error if submission request is invalid
437458
try:
438459
submission_request, submission_mode_enum = await to_submit_info(
439-
user_info, submission_mode, file, leaderboard_name, gpu_type, db_context
460+
user_info, submission_mode, file, leaderboard_name, gpu_type, db_context
440461
)
441462

442463
req = prepare_submission(submission_request, backend_instance)
443464

465+
except KernelBotError as e:
466+
raise HTTPException(status_code=e.http_code, detail=str(e)) from e
444467
except Exception as e:
445-
raise HTTPException(status_code=400, detail=f"failed to prepare submission request: {str(e)}") from e
468+
raise HTTPException(
469+
status_code=400,
470+
detail=f"failed to prepare submission request: {str(e)}",
471+
) from e
446472

447473
# prepare submission request before the submission is started
448474
if not req.gpus or len(req.gpus) != 1:
449475
raise HTTPException(status_code=400, detail="Invalid GPU type")
450476

451477
# put submission request to background manager to run in background
452-
sub_id,job_status_id = await enqueue_background_job(
478+
sub_id, job_status_id = await enqueue_background_job(
453479
req, submission_mode_enum, backend_instance, background_submission_manager
454480
)
455481

456482
return JSONResponse(
457483
status_code=202,
458-
content={"details":{"id": sub_id, "job_status_id": job_status_id}, "status": "accepted"},
484+
content={
485+
"details": {"id": sub_id, "job_status_id": job_status_id},
486+
"status": "accepted",
487+
},
459488
)
460489
# Preserve FastAPI HTTPException as-is
461490
except HTTPException:
@@ -470,6 +499,7 @@ async def run_submission_async(
470499
logger.error(f"Unexpected error in api submissoin: {e}")
471500
raise HTTPException(status_code=500, detail="Internal server error") from e
472501

502+
473503
@app.get("/leaderboards")
474504
async def get_leaderboards(db_context=Depends(get_db)):
475505
"""An endpoint that returns all leaderboards.

src/kernelbot/cogs/admin_cog.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,16 @@ def __init__(self, bot: "ClusterBot"):
122122
name="set-forum-ids", description="Sets forum IDs"
123123
)(self.set_forum_ids)
124124

125+
self.set_submission_rate_limit = bot.admin_group.command(
126+
name="set-submission-rate-limit",
127+
description="Set default or per-user submission rate limit (submissions/minute).",
128+
)(self.set_submission_rate_limit)
129+
130+
self.get_submission_rate_limit = bot.admin_group.command(
131+
name="get-submission-rate-limit",
132+
description="Get default or per-user submission rate limit (submissions/minute).",
133+
)(self.get_submission_rate_limit)
134+
125135
self._scheduled_cleanup_temp_users.start()
126136

127137
# --------------------------------------------------------------------------
@@ -512,6 +522,153 @@ async def start(self, interaction: discord.Interaction):
512522
interaction, "Bot will accept submissions again!", ephemeral=True
513523
)
514524

525+
def _parse_user_id_arg(self, user_id: str) -> str:
526+
"""Accepts a raw id or a discord mention and returns the id string."""
527+
s = (user_id or "").strip()
528+
if s.startswith("<@") and s.endswith(">"):
529+
s = s[2:-1].strip()
530+
if s.startswith("!"):
531+
s = s[1:].strip()
532+
return s
533+
534+
def _format_rate(self, rate: float | None) -> str:
535+
if rate is None:
536+
return "unlimited"
537+
r = float(rate)
538+
if r == 0:
539+
return "blocked"
540+
return f"{r:g}/min"
541+
542+
@app_commands.describe(
543+
rate_per_minute="Rate in submissions/minute. Use 'none' for unlimited; 'default' clears a user override.", # noqa: E501
544+
user_id="Optional user id or mention. If omitted, sets the default.",
545+
)
546+
@with_error_handling
547+
async def set_submission_rate_limit(
548+
self,
549+
interaction: discord.Interaction,
550+
rate_per_minute: str,
551+
user_id: Optional[str] = None,
552+
):
553+
is_admin = await self.admin_check(interaction)
554+
if not is_admin:
555+
await send_discord_message(
556+
interaction,
557+
"You need to be Admin to use this command.",
558+
ephemeral=True,
559+
)
560+
return
561+
562+
rate_s = (rate_per_minute or "").strip().lower()
563+
if rate_s in {"none", "unlimited", "off"}:
564+
parsed_rate: float | None = None
565+
clear_override = False
566+
elif rate_s == "default":
567+
parsed_rate = None
568+
clear_override = True
569+
else:
570+
try:
571+
parsed_rate = float(rate_s)
572+
except ValueError:
573+
await send_discord_message(
574+
interaction,
575+
"Invalid rate. Use a number like `1` or `0.5`, or `none`.",
576+
ephemeral=True,
577+
)
578+
return
579+
if parsed_rate < 0:
580+
await send_discord_message(
581+
interaction,
582+
"Invalid rate. Must be >= 0 (or `none`).",
583+
ephemeral=True,
584+
)
585+
return
586+
clear_override = False
587+
588+
with self.bot.leaderboard_db as db:
589+
if user_id is None or user_id.strip() == "":
590+
if clear_override:
591+
await send_discord_message(
592+
interaction,
593+
"For default limit, use a number or `none` (not `default`).",
594+
ephemeral=True,
595+
)
596+
return
597+
db.set_default_submission_rate_limit(parsed_rate)
598+
await send_discord_message(
599+
interaction,
600+
f"Default submission rate limit set to **{self._format_rate(parsed_rate)}**.",
601+
ephemeral=True,
602+
)
603+
return
604+
605+
uid = self._parse_user_id_arg(user_id)
606+
if uid == "":
607+
await send_discord_message(
608+
interaction,
609+
"Invalid user id.",
610+
ephemeral=True,
611+
)
612+
return
613+
614+
if clear_override:
615+
db.clear_user_submission_rate_limit(uid)
616+
await send_discord_message(
617+
interaction,
618+
f"Cleared per-user submission rate limit override for `{uid}` (default now applies).", # noqa: E501
619+
ephemeral=True,
620+
)
621+
return
622+
623+
db.set_user_submission_rate_limit(uid, parsed_rate)
624+
await send_discord_message(
625+
interaction,
626+
f"Submission rate limit for `{uid}` set to **{self._format_rate(parsed_rate)}**.",
627+
ephemeral=True,
628+
)
629+
630+
@app_commands.describe(
631+
user_id="Optional user id or mention. If omitted, shows the default.",
632+
)
633+
@with_error_handling
634+
async def get_submission_rate_limit(
635+
self,
636+
interaction: discord.Interaction,
637+
user_id: Optional[str] = None,
638+
):
639+
is_admin = await self.admin_check(interaction)
640+
if not is_admin:
641+
await send_discord_message(
642+
interaction,
643+
"You need to be Admin to use this command.",
644+
ephemeral=True,
645+
)
646+
return
647+
648+
with self.bot.leaderboard_db as db:
649+
if user_id is None or user_id.strip() == "":
650+
default_rate, capacity = db.get_default_submission_rate_limit()
651+
msg = (
652+
f"Default submission rate limit: **{self._format_rate(default_rate)}**\n"
653+
f"Bucket capacity: **{capacity:g}**"
654+
)
655+
await send_discord_message(interaction, msg, ephemeral=True)
656+
return
657+
658+
uid = self._parse_user_id_arg(user_id)
659+
effective, has_override, user_rate, default_rate, capacity = (
660+
db.get_submission_rate_limits(uid)
661+
)
662+
override_text = "default" if not has_override else self._format_rate(user_rate)
663+
msg = (
664+
f"User `{uid}` submission rate limit:\n"
665+
f"- Effective: **{self._format_rate(effective)}**\n"
666+
f"- User override: **{override_text}**\n"
667+
f"- Default: **{self._format_rate(default_rate)}**\n"
668+
f"- Bucket capacity: **{capacity:g}**"
669+
)
670+
await send_discord_message(interaction, msg, ephemeral=True)
671+
515672
@app_commands.describe(
516673
problem_set="Which problem set to load.",
517674
repository_name="Name of the repository to load problems from (in format: user/repo)",

0 commit comments

Comments
 (0)