from __future__ import annotations

import asyncio
import html
from pathlib import Path

import httpx
import websockets

from config import ProxyRoute
from docker_manager import DockerManager

# Hop-by-hop headers (RFC 2616 §13.5.1 / RFC 7230 §6.1), stored as bytes so raw
# ASGI/httpx header names can be compared without decoding each one.
_HOP_BY_HOP = frozenset({
    b"connection",
    b"keep-alive",
    b"proxy-authenticate",
    b"proxy-authorization",
    b"te",
    b"trailers",
    b"transfer-encoding",
    b"upgrade",
})

# Upstream response headers that must not be copied to the client: hop-by-hop
# headers plus content-encoding/content-length, because httpx transparently
# decodes the body, so the upstream values would describe bytes we do not send.
_RESPONSE_SKIP = _HOP_BY_HOP | {b"content-encoding", b"content-length"}

_TEMPLATES_DIR = Path(__file__).parent / "templates"
_template_cache: dict[str, str] = {}  # filename -> raw text, cached on first disk read


def _load_template(name: str, **replacements: str) -> bytes:
    """Return template *name* encoded as UTF-8 with ``{{KEY}}`` placeholders filled.

    The raw file text is read from ``templates/`` once and cached. Each keyword
    argument replaces every ``{{KEY}}`` occurrence with its value verbatim, so
    callers must HTML-escape any untrusted value themselves.
    """
    if name not in _template_cache:
        _template_cache[name] = (_TEMPLATES_DIR / name).read_text(encoding="utf-8")
    text = _template_cache[name]
    for key, value in replacements.items():
        text = text.replace("{{" + key + "}}", value)
    return text.encode()


def _host_from_scope(scope) -> str:
    """Return the request's Host header without its port ("" when absent).

    Undecodable bytes are replaced rather than raising, and bracketed IPv6
    literals such as ``[::1]:8080`` keep their brackets (``[::1]``).
    """
    raw = dict(scope.get("headers", [])).get(b"host", b"").decode("utf-8", errors="replace")
    if raw.startswith("["):
        end = raw.find("]")
        if end != -1:
            return raw[: end + 1]
    return raw.split(":", 1)[0]


def _upstream_url(scheme: str, route: ProxyRoute, scope) -> str:
    """Build the upstream URL for *scope* on ``route.target_host:route.target_port``.

    Prefers the undecoded ``raw_path`` so percent-escapes such as ``%2F`` or
    ``%3F`` reach the upstream intact; ``scope["path"]`` is already
    percent-decoded and is only used as a fallback.
    """
    raw_path = scope.get("raw_path")
    try:
        path = raw_path.decode() if raw_path else scope["path"]
    except UnicodeDecodeError:
        path = scope["path"]
    url = f"{scheme}://{route.target_host}:{route.target_port}{path}"
    query = scope.get("query_string", b"").decode()
    if query:
        url += f"?{query}"
    return url


class ProxyMiddleware:
    """ASGI middleware that proxies requests by Host header to container-backed upstreams.

    HTTP and WebSocket scopes whose Host matches a key of *routes* are forwarded
    to ``route.target_host:route.target_port``; the route's containers are
    started on demand through *docker*. Any other scope type is passed through
    to the wrapped *app* unchanged.
    """

    def __init__(self, app, routes: dict[str, ProxyRoute], docker: DockerManager,
                 http_client: httpx.AsyncClient):
        self.app = app
        self.routes = routes
        self.docker = docker
        self._client = http_client
        # Background container-start tasks keyed by id(route). Holding a strong
        # reference stops the event loop from garbage-collecting a running task,
        # and lets repeated wait-page refreshes reuse an in-flight start.
        self._start_tasks: dict[int, asyncio.Task] = {}

    async def __call__(self, scope, receive, send):
        if scope["type"] == "http":
            await self._http(scope, receive, send)
        elif scope["type"] == "websocket":
            await self._websocket(scope, receive, send)
        else:
            await self.app(scope, receive, send)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _find_route(self, scope) -> ProxyRoute | None:
        """Return the route registered for the request's Host (port stripped), if any."""
        return self.routes.get(_host_from_scope(scope))

    async def _ensure_containers(self, route: ProxyRoute):
        """Ensure every container of *route* is running, then reset its idle timer.

        Passes ``route.load_seconds`` to ``DockerManager.ensure_running`` and
        ``route.timeout_seconds`` to ``DockerManager.reset_idle``; exceptions
        from either propagate to the caller.
        """
        for name in route.containers:
            await self.docker.ensure_running(name, route.load_seconds)
            self.docker.reset_idle(name, route.timeout_seconds)

    def _start_in_background(self, route: ProxyRoute) -> None:
        """Launch ``_ensure_containers(route)`` as a task unless one is already running."""
        key = id(route)
        existing = self._start_tasks.get(key)
        if existing is not None and not existing.done():
            return
        task = asyncio.create_task(self._ensure_containers(route))
        self._start_tasks[key] = task
        task.add_done_callback(lambda t, key=key: self._on_start_done(key, t))

    def _on_start_done(self, key: int, task: asyncio.Task) -> None:
        """Forget a finished start task and log its failure, if any."""
        if self._start_tasks.get(key) is task:
            del self._start_tasks[key]
        if task.cancelled():
            return
        exc = task.exception()  # retrieving it also silences "never retrieved" warnings
        if exc is not None:
            print(f"[http] Container start error: {exc}")

    @staticmethod
    async def _send_simple(send, status: int, body: bytes, content_type: bytes,
                           extra_headers: tuple[tuple[bytes, bytes], ...] = ()):
        """Send a complete, non-streamed HTTP response with an explicit content-length."""
        headers = [
            (b"content-type", content_type),
            *extra_headers,
            (b"content-length", str(len(body)).encode()),
        ]
        await send({"type": "http.response.start", "status": status, "headers": headers})
        await send({"type": "http.response.body", "body": body})

    # ------------------------------------------------------------------
    # HTTP proxy
    # ------------------------------------------------------------------

    async def _http(self, scope, receive, send):
        """Serve /health, a 404 page, a 503 wait page, or proxy the request upstream."""
        if scope["path"] == "/health":
            await self._send_simple(send, 200, b'{"status":"ok"}', b"application/json")
            return

        route = self._find_route(scope)
        if not route:
            # The Host header is client-controlled: escape it before it lands in HTML.
            host = html.escape(_host_from_scope(scope))
            body = _load_template("404.html", SERVICE=host)
            await self._send_simple(send, 404, body, b"text/html; charset=utf-8")
            return

        # If any container is not yet running, show the wait page and start the
        # route in the background. The browser auto-refreshes via the Refresh header.
        for name in route.containers:
            if not await self.docker.is_running(name):
                self._start_in_background(route)
                body = _load_template("wait.html")
                await self._send_simple(send, 503, body, b"text/html; charset=utf-8",
                                        ((b"refresh", b"3"),))
                return

        # All containers running – reset idle timers and proxy the request.
        for name in route.containers:
            self.docker.reset_idle(name, route.timeout_seconds)

        url = _upstream_url("http", route, scope)
        # bytes comparison – no .decode() needed per header
        req_headers = [(k, v) for k, v in scope["headers"] if k.lower() not in _HOP_BY_HOP]

        chunks: list[bytes] = []
        while True:
            message = await receive()
            if message["type"] == "http.disconnect":
                return  # client left mid-upload; don't forward a truncated body
            chunks.append(message.get("body", b""))
            if not message.get("more_body", False):
                break
        body = b"".join(chunks)

        # Only the upstream call is guarded: once http.response.start has been
        # sent, a second (502) response.start would violate the ASGI protocol.
        try:
            resp = await self._client.request(
                method=scope["method"],
                url=url,
                headers=req_headers,
                content=body,
            )
        except Exception as e:
            print(f"[http] Upstream error for {url}: {e}")
            await self._send_simple(send, 502, b"Bad Gateway", b"text/plain")
            return

        content = resp.content
        resp_headers = [
            (k.lower(), v) for k, v in resp.headers.raw if k.lower() not in _RESPONSE_SKIP
        ]
        resp_headers.append((b"content-length", str(len(content)).encode()))
        await send({"type": "http.response.start", "status": resp.status_code,
                    "headers": resp_headers})
        await send({"type": "http.response.body", "body": content})

    # ------------------------------------------------------------------
    # WebSocket proxy
    # ------------------------------------------------------------------

    async def _websocket(self, scope, receive, send):
        """Start the route's containers, then relay frames between client and backend.

        Closes with 4004 for an unknown Host and with 1011 when the containers
        cannot be started or the backend connection fails; otherwise closes
        with 1000 once either side finishes.
        """
        route = self._find_route(scope)
        if not route:
            await send({"type": "websocket.close", "code": 4004})
            return

        try:
            await self._ensure_containers(route)
        except Exception as e:
            print(f"[ws] Container error: {e}")
            await send({"type": "websocket.close", "code": 1011})
            return

        url = _upstream_url("ws", route, scope)
        await send({"type": "websocket.accept"})

        client_gone = False  # set once the client disconnects; no close frame afterwards

        async def client_to_backend(backend_ws):
            nonlocal client_gone
            try:
                while True:
                    msg = await receive()
                    if msg["type"] == "websocket.receive":
                        if msg.get("text") is not None:
                            await backend_ws.send(msg["text"])
                        elif msg.get("bytes") is not None:
                            await backend_ws.send(msg["bytes"])
                    elif msg["type"] == "websocket.disconnect":
                        client_gone = True
                        break
            except Exception as e:
                print(f"[ws] client→backend error: {e}")

        async def backend_to_client(backend_ws):
            try:
                async for msg in backend_ws:
                    if isinstance(msg, str):
                        await send({"type": "websocket.send", "text": msg})
                    else:
                        await send({"type": "websocket.send", "bytes": msg})
            except Exception as e:
                print(f"[ws] backend→client error: {e}")

        close_code = 1000
        try:
            async with websockets.connect(url) as backend_ws:
                pumps = [
                    asyncio.create_task(client_to_backend(backend_ws)),
                    asyncio.create_task(backend_to_client(backend_ws)),
                ]
                try:
                    await asyncio.wait(pumps, return_when=asyncio.FIRST_COMPLETED)
                finally:
                    # Runs on normal exit *and* if this handler is cancelled,
                    # so neither pump outlives the connection.
                    for task in pumps:
                        task.cancel()
                    await asyncio.gather(*pumps, return_exceptions=True)
        except Exception as e:
            print(f"[ws] Connection error to {url}: {e}")
            close_code = 1011
        finally:
            if not client_gone:
                try:
                    await send({"type": "websocket.close", "code": close_code})
                except Exception:
                    pass  # best effort: the client may already be gone