215 lines
8.4 KiB
Python
215 lines
8.4 KiB
Python
from __future__ import annotations
|
||
|
||
import asyncio
|
||
from pathlib import Path
|
||
|
||
import httpx
|
||
import websockets
|
||
|
||
from config import ProxyRoute
|
||
from docker_manager import DockerManager
|
||
|
||
# Hop-by-hop header names (lowercase). They describe a single transport
# connection, so the proxy drops them instead of relaying them between the
# client and the backend, in either direction.
HOP_BY_HOP = {
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailers",
    "transfer-encoding",
    "upgrade",
}

# Directory holding the HTML pages (404 / wait) rendered by the proxy; it sits
# next to this module.
TEMPLATES = Path(__file__).parent.joinpath("templates")
|
||
|
||
|
||
def _load_template(name: str, **replacements: str) -> bytes:
    """Read template *name* from TEMPLATES and fill its ``{{KEY}}`` placeholders.

    Every keyword argument ``KEY=value`` replaces each ``{{KEY}}`` in the
    template with ``value`` HTML-escaped. The templates are served as HTML and
    callers pass request data (e.g. the client's Host header), so unescaped
    values would allow HTML/script injection. Substitution is a single pass:
    a value that itself contains ``{{OTHER}}`` is not expanded again.
    Placeholders without a matching keyword are left untouched.

    Args:
        name: Template file name, relative to TEMPLATES.
        **replacements: Placeholder name -> plain-text value.

    Returns:
        The rendered template encoded as UTF-8.

    Raises:
        OSError: If the template file cannot be read.
        UnicodeDecodeError: If the template is not valid UTF-8.
    """
    import html
    import re

    text = (TEMPLATES / name).read_text(encoding="utf-8")
    if replacements:
        rendered = {
            "{{" + key + "}}": html.escape(value)
            for key, value in replacements.items()
        }
        # Longest placeholder first so a shorter one can never win a match
        # that a longer placeholder also covers.
        pattern = re.compile(
            "|".join(
                re.escape(placeholder)
                for placeholder in sorted(rendered, key=len, reverse=True)
            )
        )
        text = pattern.sub(lambda match: rendered[match.group(0)], text)
    return text.encode("utf-8")
|
||
|
||
|
||
class ProxyMiddleware:
    """ASGI middleware that proxies requests to Docker-backed services by Host.

    HTTP and WebSocket scopes are forwarded to the backend of the route
    registered for the request's host (``route.target_host:route.target_port``).
    Every other scope type (e.g. lifespan) is passed to the wrapped ``app``.
    The route's containers are started on demand through ``DockerManager``,
    and their idle timers are reset whenever traffic is proxied.
    """

    def __init__(self, app, routes: dict[str, ProxyRoute], docker: DockerManager):
        """Store the inner app, the host -> route table and the Docker manager.

        Args:
            app: Inner ASGI application receiving non-HTTP/WebSocket scopes.
            routes: Mapping of bare host name (no port) to its proxy route.
            docker: Manager used to query, start and idle-track containers.
        """
        self.app = app
        self.routes = routes
        self.docker = docker
        # Background container-start tasks keyed by id(route). Holding a strong
        # reference stops the event loop from garbage-collecting a task
        # mid-run, and lets repeated wait-page hits reuse a start in progress.
        # The pending coroutine references its route, so the id stays unique
        # for as long as the entry exists.
        self._starting: dict[int, asyncio.Task] = {}

    async def __call__(self, scope, receive, send):
        """ASGI entry point: dispatch on the scope type."""
        if scope["type"] == "http":
            await self._http(scope, receive, send)
        elif scope["type"] == "websocket":
            await self._websocket(scope, receive, send)
        else:
            # Pass lifespan and other event types to the inner app
            await self.app(scope, receive, send)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _host(scope) -> str:
        """Return the request's Host header without its port ("" if absent).

        Invalid UTF-8 is replaced rather than raised on, so a malformed Host
        simply matches no route instead of crashing the request.
        """
        headers = dict(scope.get("headers", []))
        host = headers.get(b"host", b"").decode("utf-8", errors="replace")
        if host.startswith("["):
            # Bracketed IPv6 literal, e.g. "[::1]:8000" -> "[::1]".
            end = host.find("]")
            return host if end == -1 else host[: end + 1]
        return host.split(":", 1)[0]

    def _find_route(self, scope) -> ProxyRoute | None:
        """Return the route registered for the request's host, or None."""
        return self.routes.get(self._host(scope))

    @staticmethod
    def _url_text(data: bytes) -> str:
        """Turn raw path/query bytes into URL text without changing the wire bytes.

        Printable ASCII (including existing ``%XX`` escapes) passes through, and
        every other byte becomes ``%XX``. For valid UTF-8 this matches what
        decoding and letting the HTTP client percent-encode would send. It
        also never raises on bytes that are not valid UTF-8.
        """
        return "".join(
            chr(byte) if 0x21 <= byte <= 0x7E else f"%{byte:02X}" for byte in data
        )

    @classmethod
    def _backend_url(cls, scope, route: ProxyRoute, scheme: str) -> str:
        """Build ``scheme://target_host:target_port/path[?query]`` for *scope*.

        The ASGI ``raw_path`` (still percent-encoded) is preferred over
        ``path``, which the server has already percent-decoded. Re-sending the
        decoded form would turn ``%2F`` into a path separator, or ``%3F`` /
        ``%23`` into a query or fragment delimiter, at the backend.
        """
        raw_path = scope.get("raw_path")
        if raw_path:
            # raw_path should be the path component only; drop a query part
            # defensively in case a server includes it.
            path = cls._url_text(raw_path.split(b"?", 1)[0])
        else:
            path = scope["path"]
        url = f"{scheme}://{route.target_host}:{route.target_port}{path}"
        query = cls._url_text(scope.get("query_string", b""))
        if query:
            url += f"?{query}"
        return url

    @staticmethod
    def _hop_by_hop(headers) -> set[str]:
        """Return the lowercase header names that must not cross the proxy.

        This is HOP_BY_HOP plus every name listed in a ``Connection`` header
        of *headers* (an iterable of ``(name, value)`` byte pairs). Such names
        are connection-specific too, per RFC 9110 §7.6.1.
        """
        names = set(HOP_BY_HOP)
        for key, value in headers:
            if key.lower() == b"connection":
                for token in value.decode("latin-1").split(","):
                    token = token.strip().lower()
                    if token:
                        names.add(token)
        return names

    @staticmethod
    async def _read_body(receive) -> bytes:
        """Drain the ASGI request body (all ``more_body`` chunks) into bytes."""
        chunks = []
        while True:
            message = await receive()
            chunks.append(message.get("body", b""))
            if not message.get("more_body", False):
                # b"".join avoids the quadratic cost of repeated bytes +=.
                return b"".join(chunks)

    @staticmethod
    async def _send_simple(send, status: int, content_type: bytes, body: bytes,
                           extra_headers=()):
        """Send a complete, non-streamed response with an explicit Content-Length."""
        headers = [
            (b"content-type", content_type),
            *extra_headers,
            (b"content-length", str(len(body)).encode()),
        ]
        await send({"type": "http.response.start", "status": status,
                    "headers": headers})
        await send({"type": "http.response.body", "body": body})

    async def _ensure_containers(self, route: ProxyRoute):
        """Ensure each of *route*'s containers is running and reset its idle timer.

        Exceptions from the Docker manager propagate to the caller.
        """
        for name in route.containers:
            await self.docker.ensure_running(name, route.load_seconds)
            self.docker.reset_idle(name, route.timeout_seconds)

    def _start_in_background(self, route: ProxyRoute) -> None:
        """Start *route*'s containers in a background task unless one is pending."""
        key = id(route)
        pending = self._starting.get(key)
        if pending is not None and not pending.done():
            return
        task = asyncio.create_task(self._ensure_containers(route))
        self._starting[key] = task
        task.add_done_callback(lambda done: self._start_finished(key, done))

    def _start_finished(self, key: int, task: asyncio.Task) -> None:
        """Forget a finished start task and log its failure, if any."""
        if self._starting.get(key) is task:
            del self._starting[key]
        # Retrieving the exception also prevents asyncio's
        # "Task exception was never retrieved" warning.
        if not task.cancelled() and task.exception() is not None:
            print(f"[http] Container start failed: {task.exception()}")

    # ------------------------------------------------------------------
    # HTTP proxy
    # ------------------------------------------------------------------

    async def _http(self, scope, receive, send):
        """Serve /health, a 404 or wait page, or proxy the request to its backend."""
        # Health check handled directly so it works regardless of Host header
        if scope["path"] == "/health":
            await self._send_simple(send, 200, b"application/json",
                                    b'{"status":"ok"}')
            return

        route = self._find_route(scope)
        if route is None:
            body = _load_template("404.html", SERVICE=self._host(scope))
            await self._send_simple(send, 404, b"text/html; charset=utf-8", body)
            return

        # If any container is not yet running, show the wait page and start
        # the containers in the background. The browser retries via the
        # Refresh header.
        for name in route.containers:
            if not await self.docker.is_running(name):
                self._start_in_background(route)
                body = _load_template("wait.html")
                await self._send_simple(send, 503, b"text/html; charset=utf-8",
                                        body, [(b"refresh", b"3")])
                return

        # All containers running – reset idle timers and proxy the request.
        for name in route.containers:
            self.docker.reset_idle(name, route.timeout_seconds)

        url = self._backend_url(scope, route, "http")
        # accept-encoding is dropped so httpx advertises only encodings it can
        # decode. httpx leaves unknown encodings compressed, and the response
        # below is forwarded decoded with content-encoding stripped.
        req_skip = self._hop_by_hop(scope["headers"]) | {"accept-encoding"}
        req_headers = [
            (k, v) for k, v in scope["headers"]
            if k.decode("latin-1").lower() not in req_skip
        ]
        body = await self._read_body(receive)

        # Only the upstream exchange is guarded. Errors raised while sending
        # to the client must not trigger a second http.response.start.
        try:
            async with httpx.AsyncClient() as client:
                resp = await client.request(
                    method=scope["method"],
                    url=url,
                    headers=req_headers,
                    content=body,
                    timeout=30.0,
                )
        except Exception as e:
            print(f"[http] Upstream error for {url}: {e}")
            await self._send_simple(send, 502, b"text/plain", b"Bad Gateway")
            return

        # Strip hop-by-hop headers + content-encoding/length (httpx
        # decompresses automatically, so the original values are wrong).
        resp_skip = (self._hop_by_hop(resp.headers.raw)
                     | {"content-encoding", "content-length"})
        resp_headers = [
            (k.lower(), v) for k, v in resp.headers.raw
            if k.decode("latin-1").lower() not in resp_skip
        ]
        # RFC 9110 §8.6: no Content-Length on 204. On 304 it would describe
        # the cached representation, not this empty body.
        if resp.status_code not in (204, 304):
            resp_headers.append(
                (b"content-length", str(len(resp.content)).encode())
            )
        await send({"type": "http.response.start",
                    "status": resp.status_code, "headers": resp_headers})
        await send({"type": "http.response.body", "body": resp.content})

    # ------------------------------------------------------------------
    # WebSocket proxy
    # ------------------------------------------------------------------

    async def _websocket(self, scope, receive, send):
        """Relay a WebSocket session between the client and the route's backend.

        The backend connection is opened before the client handshake is
        accepted. A backend failure therefore rejects the handshake instead
        of looking like a normal close, and the subprotocol chosen by the
        backend is passed back to the client.
        """
        route = self._find_route(scope)
        if route is None:
            # Reject before accept
            await send({"type": "websocket.close", "code": 4004})
            return

        try:
            await self._ensure_containers(route)
        except Exception as e:
            print(f"[ws] Container error: {e}")
            await send({"type": "websocket.close", "code": 1011})
            return

        url = self._backend_url(scope, route, "ws")
        # websockets sends a Sec-WebSocket-Protocol header whenever
        # subprotocols is not None, so an empty list becomes None.
        subprotocols = scope.get("subprotocols") or None
        # NOTE(review): client request headers (cookies, auth) are still not
        # forwarded to the backend; the keyword for that differs between
        # websockets versions (extra_headers / additional_headers) — confirm
        # the installed version before adding it.
        close_code = 1000
        try:
            async with websockets.connect(url, subprotocols=subprotocols) as backend_ws:
                await send({"type": "websocket.accept",
                            "subprotocol": backend_ws.subprotocol})
                await self._relay_websocket(receive, send, backend_ws)
        except Exception as e:
            print(f"[ws] Connection error to {url}: {e}")
            # 1011 (internal error) instead of 1000 so clients can tell a
            # failure from a normal close. Before accept, this is a rejection.
            close_code = 1011
        finally:
            try:
                await send({"type": "websocket.close", "code": close_code})
            except Exception:
                # Client already gone; nothing left to tell it.
                pass

    @staticmethod
    async def _relay_websocket(receive, send, backend_ws) -> None:
        """Pump messages both ways until either direction ends, then stop the other."""

        async def client_to_backend():
            try:
                while True:
                    msg = await receive()
                    if msg["type"] == "websocket.receive":
                        if msg.get("text") is not None:
                            await backend_ws.send(msg["text"])
                        elif msg.get("bytes") is not None:
                            await backend_ws.send(msg["bytes"])
                    elif msg["type"] == "websocket.disconnect":
                        break
            except Exception as e:
                print(f"[ws] client→backend error: {e}")

        async def backend_to_client():
            try:
                async for msg in backend_ws:
                    if isinstance(msg, str):
                        await send({"type": "websocket.send", "text": msg})
                    else:
                        await send({"type": "websocket.send", "bytes": msg})
            except Exception as e:
                print(f"[ws] backend→client error: {e}")

        tasks = [
            asyncio.create_task(client_to_backend()),
            asyncio.create_task(backend_to_client()),
        ]
        try:
            await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        finally:
            # Runs on normal completion and when this coroutine is cancelled,
            # so no pump task is ever left running on its own.
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
|