Skip to content

Commit dabb47f

Browse files
authored
Fix login issue (#180)
* fix https cookies & secure * fix websocked db hoarding
1 parent 57c9911 commit dabb47f

File tree

5 files changed

+85
-21
lines changed

5 files changed

+85
-21
lines changed

backend/app/api/auth.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,13 @@
2626

2727
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/login", auto_error=False)
2828

29+
30+
def _should_use_secure_cookie(request: Request) -> bool:
31+
if request.url.scheme == "https":
32+
return True
33+
forwarded_proto = request.headers.get("x-forwarded-proto", "").split(",", 1)[0].strip().lower()
34+
return forwarded_proto == "https"
35+
2936
def create_access_token(data: dict, expires_delta: Optional[datetime.timedelta] = None):
3037
to_encode = data.copy()
3138
if expires_delta:
@@ -168,12 +175,12 @@ async def login(
168175
max_age=max_age,
169176
httponly=True,
170177
samesite="lax",
171-
secure=not (config.DEBUG or config.TEST),
178+
secure=_should_use_secure_cookie(request),
172179
)
173180
return {"access_token": access_token, "token_type": "bearer"}
174181

175182
@api_no_auth.post("/signup")
176-
async def signup(data: UserSignup, response: Response, db: Session = Depends(get_db)):
183+
async def signup(request: Request, data: UserSignup, response: Response, db: Session = Depends(get_db)):
177184
if config.HASSIO_RUN_MODE is not None:
178185
raise HTTPException(status_code=401, detail="Signup not allowed with HASSIO_RUN_MODE")
179186

@@ -218,7 +225,7 @@ async def signup(data: UserSignup, response: Response, db: Session = Depends(get
218225
max_age=max_age,
219226
httponly=True,
220227
samesite="lax",
221-
secure=not (config.DEBUG or config.TEST),
228+
secure=_should_use_secure_cookie(request),
222229
)
223230
return {"success": True, "access_token": access_token, "token_type": "bearer"}
224231

backend/app/api/tests/test_auth.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,3 +173,25 @@ async def test_cookie_auth_on_protected_route(no_auth_client, db: Session):
173173
no_auth_client.headers.pop("Authorization", None)
174174
response = await no_auth_client.get("/api/system/metrics")
175175
assert response.status_code == HTTP_200_OK
176+
177+
178+
@pytest.mark.asyncio
179+
async def test_login_cookie_secure_flag_depends_on_request_scheme(no_auth_client, db: Session):
180+
user = User(email="securecookie@example.com")
181+
user.set_password("testpassword")
182+
db.add(user)
183+
db.commit()
184+
185+
login_data = {"username": "securecookie@example.com", "password": "testpassword"}
186+
187+
http_response = await no_auth_client.post("/api/login", data=login_data)
188+
assert http_response.status_code == HTTP_200_OK
189+
assert "Secure" not in http_response.headers.get("set-cookie", "")
190+
191+
https_response = await no_auth_client.post(
192+
"/api/login",
193+
data=login_data,
194+
headers={"x-forwarded-proto": "https"},
195+
)
196+
assert https_response.status_code == HTTP_200_OK
197+
assert "Secure" in https_response.headers.get("set-cookie", "")

backend/app/websockets.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
import json
33
from typing import List
44
from redis.asyncio import from_url as create_redis, Redis
5-
from fastapi import WebSocket, WebSocketDisconnect, Depends
6-
from sqlalchemy.orm import Session
7-
from app.database import get_db
5+
from fastapi import WebSocket, WebSocketDisconnect
6+
7+
from app.database import SessionLocal
88

99
from app.config import config
1010

@@ -69,12 +69,17 @@ async def publish_message(redis: Redis, event: str, data: dict):
6969

7070
def register_ws_routes(app):
7171
@app.websocket("/ws")
72-
async def websocket_endpoint(websocket: WebSocket, db: Session = Depends(get_db)):
72+
async def websocket_endpoint(websocket: WebSocket):
7373
# Full access in the HASSIO ingress mode
7474
if config.HASSIO_RUN_MODE != "ingress":
7575
from app.api.auth import get_current_user_from_websocket
7676

77-
user, error_reason = get_current_user_from_websocket(websocket, db)
77+
db = SessionLocal()
78+
try:
79+
user, error_reason = get_current_user_from_websocket(websocket, db)
80+
finally:
81+
db.close()
82+
7883
if user is None:
7984
await websocket.close(code=1008, reason=error_reason or "Could not validate credentials")
8085
return

backend/app/ws/agent_ws.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,10 @@
1414

1515
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect, status
1616
from starlette.websockets import WebSocketState
17-
from sqlalchemy.orm import Session
1817
from arq import ArqRedis as Redis
1918

2019
# ── project locals ──────────────────────────────────────────────────────────
21-
from app.database import get_db
20+
from app.database import SessionLocal
2221
from app.redis import get_redis, create_redis_connection # create_… gives a raw conn
2322
from app.websockets import publish_message
2423
from app.models.frame import Frame
@@ -35,6 +34,14 @@
3534
active_sockets_by_frame: dict[int, list[WebSocket]] = {}
3635
active_sockets: set[WebSocket] = set()
3736

37+
38+
async def write_log(redis: Redis, frame_id: int, type: str, line: str, ip: str | None = None):
39+
db = SessionLocal()
40+
try:
41+
await log(db, redis, frame_id, type, line, ip=ip)
42+
finally:
43+
db.close()
44+
3845
# ────────────────────────────────────────────────────────────────────────────
3946
# tiny helpers
4047
# ────────────────────────────────────────────────────────────────────────────
@@ -264,7 +271,6 @@ async def pump_commands(
264271
@router.websocket("/ws/agent")
265272
async def ws_agent_endpoint(
266273
ws: WebSocket,
267-
db: Session = Depends(get_db),
268274
redis: Redis = Depends(get_redis),
269275
):
270276
# ----- rudimentary DoS guard (per-worker) ------------------------------
@@ -286,7 +292,12 @@ async def ws_agent_endpoint(
286292
return
287293

288294
server_api_key = str(hello_msg.get("serverApiKey", "")) or ""
289-
frame = db.query(Frame).filter_by(server_api_key=server_api_key).first()
295+
db = SessionLocal()
296+
try:
297+
frame = db.query(Frame).filter_by(server_api_key=server_api_key).first()
298+
finally:
299+
db.close()
300+
290301
if frame is None:
291302
await ws.close(code=status.WS_1008_POLICY_VIOLATION, reason="unknown frame")
292303
return
@@ -340,7 +351,7 @@ async def ws_agent_endpoint(
340351
{"active_connections": await number_of_connections_for_frame(redis, frame.id),
341352
"id": frame.id}
342353
)
343-
await log(db, redis, frame.id, "agent", f'☎️ Frame "{frame.name}" connected ☎️', ip=client_ip)
354+
await write_log(redis, frame.id, "agent", f'☎️ Frame "{frame.name}" connected ☎️', ip=client_ip)
344355

345356
# =======================================================================
346357
# RECEIVE LOOP
@@ -433,7 +444,7 @@ async def ws_agent_endpoint(
433444

434445
for line in data.splitlines():
435446
if line:
436-
await log(db, redis, frame.id, stream, line, ip=client_ip)
447+
await write_log(redis, frame.id, stream, line, ip=client_ip)
437448
await redis.rpush(STREAM_KEY.format(id=pl["id"]),
438449
json.dumps({"stream": stream, "data": line}).encode())
439450
await redis.expire(STREAM_KEY.format(id=pl["id"]), 300)
@@ -463,4 +474,4 @@ async def ws_agent_endpoint(
463474
{"active_connections": await number_of_connections_for_frame(redis, frame.id),
464475
"id": frame.id}
465476
)
466-
await log(db, redis, frame.id, "agent", f'👋 Frame "{frame.name}" disconnected 👋', ip=client_ip)
477+
await write_log(redis, frame.id, "agent", f'👋 Frame "{frame.name}" disconnected 👋', ip=client_ip)

backend/app/ws/terminal_ws.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from sqlalchemy.orm import Session
77
from arq import ArqRedis as Redis
88

9-
from app.database import get_db
9+
from app.database import SessionLocal
1010
from app.redis import get_redis
1111
from app.models.frame import Frame
1212
from app.api.auth import get_current_user_from_websocket
@@ -18,22 +18,36 @@
1818
async def ssh_terminal(
1919
websocket: WebSocket,
2020
frame_id: int,
21-
db: Session = Depends(get_db),
2221
redis: Redis = Depends(get_redis),
2322
):
24-
user, error_reason = get_current_user_from_websocket(websocket, db)
23+
db: Session = SessionLocal()
24+
try:
25+
user, error_reason = get_current_user_from_websocket(websocket, db)
26+
finally:
27+
db.close()
28+
2529
if user is None:
2630
await websocket.close(code=1008, reason=error_reason or "Could not validate credentials")
2731
return
2832

2933
await websocket.accept()
3034

31-
frame = db.query(Frame).filter(Frame.id == frame_id).first()
35+
db = SessionLocal()
36+
try:
37+
frame = db.query(Frame).filter(Frame.id == frame_id).first()
38+
finally:
39+
db.close()
40+
3241
if frame is None:
3342
await websocket.close(code=1008, reason="Frame not found")
3443
return
3544

36-
ssh = await get_ssh_connection(db, redis, frame)
45+
db = SessionLocal()
46+
try:
47+
ssh = await get_ssh_connection(db, redis, frame)
48+
finally:
49+
db.close()
50+
3751
proc = await ssh.create_process(term_type="xterm", encoding="utf-8")
3852

3953
async def pipe(reader):
@@ -61,4 +75,9 @@ async def pipe(reader):
6175
with contextlib.suppress(Exception):
6276
proc.stdin.write_eof()
6377
await proc.wait_closed()
64-
await remove_ssh_connection(db, redis, ssh, frame)
78+
79+
db = SessionLocal()
80+
try:
81+
await remove_ssh_connection(db, redis, ssh, frame)
82+
finally:
83+
db.close()

0 commit comments

Comments
 (0)