"""ASGI middleware that routes requests by Host header to container-backed services.

Requests whose host matches a configured ``ProxyRoute`` are proxied (HTTP and
WebSocket) to the route's target once the route's containers are running.
Unknown hosts get a 404 page, and ``/health`` always answers ``{"status":"ok"}``.
"""

from __future__ import annotations

import asyncio
import html
from pathlib import Path
from urllib.parse import quote

import httpx
import websockets

from config import ProxyRoute
from docker_manager import DockerManager

# Headers that must not be forwarded between proxy and backend
HOP_BY_HOP = {
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailers",
    "transfer-encoding",
    "upgrade",
}

# Characters left untouched when re-quoting an already percent-encoded request
# target: URL syntax characters plus "%" itself, so existing escapes survive.
# Spaces, non-ASCII bytes, '#', etc. are still percent-encoded.
_PATH_SAFE = "/%:@!$&'()*+,;="
_QUERY_SAFE = _PATH_SAFE + "?"

TEMPLATES = Path(__file__).parent / "templates"


def _load_template(name: str, **replacements: str) -> bytes:
    """Read ``templates/<name>`` and substitute ``{{KEY}}`` placeholders.

    Values are inserted verbatim, so callers must HTML-escape anything that
    originates from the request.
    """
    text = (TEMPLATES / name).read_text(encoding="utf-8")
    for key, value in replacements.items():
        text = text.replace("{{" + key + "}}", value)
    return text.encode()


def _request_host(scope) -> str:
    """Return the Host header of *scope* with any port removed ("" if absent)."""
    headers = dict(scope.get("headers", []))
    # latin-1 never fails to decode, unlike the default UTF-8.
    host = headers.get(b"host", b"").decode("latin-1").strip()
    if host.startswith("["):
        # IPv6 literal such as "[::1]:8080" - the port follows the bracket.
        return host.partition("]")[0] + "]"
    return host.partition(":")[0]


def _filter_headers(headers, extra: frozenset[str] = frozenset()) -> list[tuple[bytes, bytes]]:
    """Return *headers* without hop-by-hop entries, with lowercased names.

    Drops everything in ``HOP_BY_HOP``, everything in *extra*, and any header
    listed by name in a ``Connection`` header (those are hop-by-hop for this
    message too).
    """
    drop = set(HOP_BY_HOP) | extra
    for name, value in headers:
        if name.lower() == b"connection":
            drop.update(
                token.strip().lower()
                for token in value.decode("latin-1").split(",")
                if token.strip()
            )
    return [
        (name.lower(), value)
        for name, value in headers
        if name.lower().decode("latin-1") not in drop
    ]


def _upstream_url(route: ProxyRoute, scope, scheme: str) -> str:
    """Build the backend URL for *scope*, preserving the client's encoding.

    ``scope["path"]`` is already percent-decoded, so forwarding it would turn
    e.g. ``%3F`` into a query delimiter. Prefer the optional ASGI ``raw_path``
    and fall back to re-encoding the decoded path.
    """
    raw_path = scope.get("raw_path")
    if raw_path:
        # Defensive: the query belongs in query_string; drop it if present here.
        path = quote(raw_path.split(b"?", 1)[0], safe=_PATH_SAFE)
    else:
        # Decoded path: a literal "%" must be encoded, so drop it from safe.
        path = quote(scope["path"], safe=_PATH_SAFE.replace("%", ""))
    url = f"{scheme}://{route.target_host}:{route.target_port}{path}"
    query = scope.get("query_string", b"")
    if query:
        url += "?" + quote(query, safe=_QUERY_SAFE)
    return url


class ProxyMiddleware:
    """Proxy HTTP/WebSocket traffic to per-host backends, starting them on demand.

    Args:
        app: Inner ASGI app; receives every scope type other than http/websocket
            (e.g. lifespan).
        routes: Mapping of Host name (without port) to the route to proxy to.
        docker: Manager used to check, start and idle-time the route containers.
    """

    def __init__(self, app, routes: dict[str, ProxyRoute], docker: DockerManager):
        self.app = app
        self.routes = routes
        self.docker = docker
        # In-flight background start tasks, keyed by the route's container names.
        # Holding the reference keeps the task from being garbage-collected, and
        # lets repeated wait-page refreshes reuse one start instead of piling up.
        self._starting: dict[tuple, asyncio.Task] = {}

    async def __call__(self, scope, receive, send):
        if scope["type"] == "http":
            await self._http(scope, receive, send)
        elif scope["type"] == "websocket":
            await self._websocket(scope, receive, send)
        else:
            # Pass lifespan and other event types to the inner app
            await self.app(scope, receive, send)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _find_route(self, scope) -> ProxyRoute | None:
        """Return the route for the request's Host, or None if unknown."""
        host = _request_host(scope)
        route = self.routes.get(host)
        if route is None and host != host.lower():
            # Host names are case-insensitive; config keys are presumably lowercase.
            route = self.routes.get(host.lower())
        return route

    async def _ensure_containers(self, route: ProxyRoute):
        """Start every container of *route* (if needed) and reset its idle timer."""
        for name in route.containers:
            await self.docker.ensure_running(name, route.load_seconds)
            self.docker.reset_idle(name, route.timeout_seconds)

    async def _all_running(self, route: ProxyRoute) -> bool:
        """Return True if every container of *route* reports as running."""
        for name in route.containers:
            if not await self.docker.is_running(name):
                return False
        return True

    def _start_in_background(self, route: ProxyRoute) -> None:
        """Kick off ``_ensure_containers(route)`` unless a start is already running."""
        key = tuple(route.containers)
        if key in self._starting:
            return
        task = asyncio.create_task(self._ensure_containers(route))
        self._starting[key] = task
        task.add_done_callback(lambda finished: self._start_finished(key, finished))

    def _start_finished(self, key: tuple, task: asyncio.Task) -> None:
        """Forget a finished start task and log its failure, if any."""
        if self._starting.get(key) is task:
            del self._starting[key]
        if task.cancelled():
            return
        exc = task.exception()  # retrieving it also silences "never retrieved"
        if exc is not None:
            names = ", ".join(map(str, key))
            print(f"[http] Container start failed for {names}: {exc}")

    @staticmethod
    async def _send_response(send, status: int, body: bytes, content_type: bytes, extra_headers=()):
        """Send a complete, non-streamed HTTP response."""
        headers = [
            (b"content-type", content_type),
            *extra_headers,
            (b"content-length", str(len(body)).encode()),
        ]
        await send({"type": "http.response.start", "status": status, "headers": headers})
        await send({"type": "http.response.body", "body": body})

    @staticmethod
    async def _read_body(receive) -> bytes | None:
        """Collect the full request body; None if the client disconnected first."""
        chunks = []
        while True:
            message = await receive()
            if message["type"] == "http.disconnect":
                return None
            chunks.append(message.get("body", b""))
            if not message.get("more_body", False):
                return b"".join(chunks)

    # ------------------------------------------------------------------
    # HTTP proxy
    # ------------------------------------------------------------------

    async def _http(self, scope, receive, send):
        # Health check handled directly so it works regardless of Host header
        if scope["path"] == "/health":
            await self._send_response(send, 200, b'{"status":"ok"}', b"application/json")
            return

        route = self._find_route(scope)
        if route is None:
            # The host comes straight from the client: escape it before it
            # lands in HTML, otherwise this page is a reflected-XSS vector.
            body = _load_template("404.html", SERVICE=html.escape(_request_host(scope)))
            await self._send_response(send, 404, body, b"text/html; charset=utf-8")
            return

        # If any container is not yet running, show wait page and start it
        # in the background. The browser will auto-refresh via the Refresh header.
        if not await self._all_running(route):
            self._start_in_background(route)
            body = _load_template("wait.html")
            await self._send_response(
                send, 503, body, b"text/html; charset=utf-8", [(b"refresh", b"3")]
            )
            return

        # All containers running - reset idle timers and proxy the request.
        for name in route.containers:
            self.docker.reset_idle(name, route.timeout_seconds)

        body = await self._read_body(receive)
        if body is None:
            # Client went away mid-upload; don't forward a truncated request.
            return

        url = _upstream_url(route, scope, "http")
        # Only the upstream call is guarded: once we start responding, a second
        # http.response.start would be a protocol error.
        try:
            async with httpx.AsyncClient() as client:
                resp = await client.request(
                    method=scope["method"],
                    url=url,
                    headers=_filter_headers(scope["headers"]),
                    content=body,
                    timeout=30.0,
                )
        except Exception as e:
            print(f"[http] Upstream error for {url}: {e}")
            await self._send_response(send, 502, b"Bad Gateway", b"text/plain")
            return

        # Strip hop-by-hop headers + content-encoding/length (httpx
        # decompresses automatically, so the original values are wrong).
        resp_headers = _filter_headers(
            resp.headers.raw, frozenset({"content-encoding", "content-length"})
        )
        resp_headers.append((b"content-length", str(len(resp.content)).encode()))
        await send({"type": "http.response.start", "status": resp.status_code,
                    "headers": resp_headers})
        await send({"type": "http.response.body", "body": resp.content})

    # ------------------------------------------------------------------
    # WebSocket proxy
    # ------------------------------------------------------------------

    async def _websocket(self, scope, receive, send):
        # ASGI delivers "websocket.connect" first; consume it before replying.
        first = await receive()
        if first["type"] != "websocket.connect":
            return

        route = self._find_route(scope)
        if route is None:
            # Reject before accept
            await send({"type": "websocket.close", "code": 4004})
            return

        try:
            await self._ensure_containers(route)
        except Exception as e:
            print(f"[ws] Container error: {e}")
            await send({"type": "websocket.close", "code": 1011})
            return

        url = _upstream_url(route, scope, "ws")
        try:
            # Connect to the backend *before* accepting the client. An
            # unreachable backend then rejects the handshake instead of
            # producing an accepted-then-"normally"-closed socket, and the
            # backend's chosen subprotocol can be relayed to the client.
            backend_ws = await websockets.connect(
                url, subprotocols=scope.get("subprotocols") or None
            )
        except Exception as e:
            print(f"[ws] Connection error to {url}: {e}")
            await send({"type": "websocket.close", "code": 1011})
            return

        # NOTE(review): idle timers are reset only once, in _ensure_containers;
        # a long-lived socket may outlive route.timeout_seconds - confirm
        # whether DockerManager should be pinged while messages flow.
        try:
            accept = {"type": "websocket.accept"}
            if backend_ws.subprotocol:
                accept["subprotocol"] = backend_ws.subprotocol
            await send(accept)
            await self._relay(receive, send, backend_ws)
        except Exception as e:
            print(f"[ws] Relay error for {url}: {e}")
        finally:
            await backend_ws.close()
            try:
                await send({"type": "websocket.close", "code": 1000})
            except Exception:
                # Best effort: the client may already have disconnected.
                pass

    @staticmethod
    async def _relay(receive, send, backend_ws) -> None:
        """Pump messages both ways until either side stops, then cancel the other."""

        async def client_to_backend():
            try:
                while True:
                    msg = await receive()
                    if msg["type"] == "websocket.receive":
                        if msg.get("text") is not None:
                            await backend_ws.send(msg["text"])
                        elif msg.get("bytes") is not None:
                            await backend_ws.send(msg["bytes"])
                    elif msg["type"] == "websocket.disconnect":
                        break
            except Exception as e:
                print(f"[ws] client→backend error: {e}")

        async def backend_to_client():
            try:
                async for msg in backend_ws:
                    if isinstance(msg, str):
                        await send({"type": "websocket.send", "text": msg})
                    else:
                        await send({"type": "websocket.send", "bytes": msg})
            except Exception as e:
                print(f"[ws] backend→client error: {e}")

        tasks = [
            asyncio.create_task(client_to_backend()),
            asyncio.create_task(backend_to_client()),
        ]
        try:
            await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        finally:
            for task in tasks:
                task.cancel()  # no-op for the task that already finished
            await asyncio.gather(*tasks, return_exceptions=True)