Files
PY_PROXY/middleware.py
Jonatan Rek f601bb3fc8 Progress
2026-03-02 12:16:00 +01:00

215 lines
8.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations

import asyncio
import html
from pathlib import Path

import httpx
import websockets

from config import ProxyRoute
from docker_manager import DockerManager
# Headers that describe a single transport connection and therefore must not
# be forwarded between client, proxy and backend. Based on the RFC 2616
# §13.5.1 list, whose "Trailers" entry is a known erratum for the real
# "Trailer" header; both spellings are kept. "proxy-connection" is a
# non-standard connection header that is also connection-specific. The proxy
# buffers whole bodies and never relays trailer fields, so forwarding a
# "Trailer" announcement would promise fields that never arrive.
HOP_BY_HOP = frozenset({
    "connection", "keep-alive", "proxy-authenticate", "proxy-authorization",
    "proxy-connection", "te", "trailer", "trailers", "transfer-encoding",
    "upgrade",
})

# Directory holding the HTML pages the proxy serves itself (e.g. 404, wait).
TEMPLATES = Path(__file__).parent / "templates"
def _load_template(name: str, **replacements: str) -> bytes:
    """Read template *name* from TEMPLATES and return it rendered as bytes.

    Every ``{{KEY}}`` placeholder is substituted with the keyword argument of
    the same name. Substitutions run one key at a time, in argument order.
    Values are inserted verbatim (no HTML escaping), so callers must escape
    untrusted data themselves.
    """
    rendered = (TEMPLATES / name).read_text(encoding="utf-8")
    for placeholder, substitute in replacements.items():
        # f"{{{{{x}}}}}" produces the literal text "{{x}}".
        rendered = rendered.replace(f"{{{{{placeholder}}}}}", substitute)
    return rendered.encode()
class ProxyMiddleware:
    """ASGI middleware that forwards requests to per-host backend services.

    The request's Host header selects a ``ProxyRoute`` from ``routes``.
    HTTP and WebSocket traffic for that route is relayed to
    ``route.target_host:route.target_port`` once the route's containers have
    been started through ``docker``. Every other ASGI scope type (e.g.
    lifespan) is handed to the wrapped ``app`` unchanged.
    """

    def __init__(self, app, routes: dict[str, ProxyRoute], docker: DockerManager):
        self.app = app
        self.routes = routes
        self.docker = docker
        # Background container start-ups in flight, keyed by the route's
        # container names. Storing the task keeps a strong reference to it
        # (the event loop only keeps weak references to tasks). It also lets
        # concurrent requests for a sleeping service share one start-up
        # instead of spawning a new task per request.
        self._starting: dict[tuple, asyncio.Task] = {}

    async def __call__(self, scope, receive, send):
        """ASGI entry point: dispatch on ``scope["type"]``."""
        if scope["type"] == "http":
            await self._http(scope, receive, send)
        elif scope["type"] == "websocket":
            await self._websocket(scope, receive, send)
        else:
            # Pass lifespan and other event types to the inner app
            await self.app(scope, receive, send)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _host(scope) -> str:
        """Return the Host header of *scope* without its port ("" if absent).

        The value is decoded as Latin-1, so arbitrary header bytes cannot
        raise. A bracketed IPv6 literal such as ``[::1]:8000`` keeps its
        brackets (``[::1]``) instead of being cut at its first colon.
        """
        raw = dict(scope.get("headers", [])).get(b"host", b"")
        host = raw.decode("latin-1")
        if host.startswith("["):
            end = host.find("]")
            return host if end == -1 else host[: end + 1]
        return host.split(":", 1)[0]

    def _find_route(self, scope) -> ProxyRoute | None:
        """Return the route registered for the request's host, or None."""
        return self.routes.get(self._host(scope))

    async def _ensure_containers(self, route: ProxyRoute):
        """Await ``docker.ensure_running`` for each container, then reset its idle timer."""
        for name in route.containers:
            await self.docker.ensure_running(name, route.load_seconds)
            self.docker.reset_idle(name, route.timeout_seconds)

    def _start_in_background(self, route: ProxyRoute) -> None:
        """Schedule ``_ensure_containers(route)`` unless one is already running."""
        key = tuple(route.containers)
        running = self._starting.get(key)
        if running is not None and not running.done():
            return
        task = asyncio.create_task(self._ensure_containers(route))
        self._starting[key] = task
        task.add_done_callback(lambda t: self._on_start_done(key, t))

    def _on_start_done(self, key: tuple, task: asyncio.Task) -> None:
        """Forget a finished start-up task and log its failure, if any."""
        if self._starting.get(key) is task:
            del self._starting[key]
        if task.cancelled():
            return
        # Retrieving the exception also silences asyncio's
        # "Task exception was never retrieved" warning.
        exc = task.exception()
        if exc is not None:
            print(f"[http] Background start of {list(key)} failed: {exc!r}")

    @staticmethod
    def _backend_url(scope, route: ProxyRoute, scheme: str) -> str:
        """Build the backend URL (``scheme://host:port/path?query``) for *scope*.

        The undecoded ``raw_path`` is preferred when the server provides it.
        The decoded ``path`` would turn ``%23``/``%3F``/``%2F`` into
        ``#``/``?``/``/`` and so change the request's meaning.
        """
        path = scope["path"]
        raw_path = scope.get("raw_path")
        if raw_path:
            try:
                # Defensive: some servers have included the query in raw_path.
                path = raw_path.partition(b"?")[0].decode("ascii")
            except UnicodeDecodeError:
                pass  # non-ASCII raw bytes: fall back to the decoded path
        url = f"{scheme}://{route.target_host}:{route.target_port}{path}"
        query = scope.get("query_string", b"").decode()
        if query:
            url += f"?{query}"
        return url

    @staticmethod
    async def _respond(send, status: int, content_type: bytes, body: bytes,
                       extra_headers=()) -> None:
        """Send a complete, non-streamed response generated by the proxy itself."""
        headers = [
            (b"content-type", content_type),
            (b"content-length", str(len(body)).encode()),
            *extra_headers,
        ]
        await send({"type": "http.response.start", "status": status, "headers": headers})
        await send({"type": "http.response.body", "body": body})

    @staticmethod
    async def _read_body(receive) -> bytes | None:
        """Collect the whole request body; return None if the client disconnects."""
        chunks = []
        while True:
            msg = await receive()
            if msg.get("type") == "http.disconnect":
                return None
            chunks.append(msg.get("body", b""))
            if not msg.get("more_body", False):
                return b"".join(chunks)

    # ------------------------------------------------------------------
    # HTTP proxy
    # ------------------------------------------------------------------

    async def _http(self, scope, receive, send):
        """Answer /health, unknown hosts and sleeping services; otherwise proxy."""
        # Health check handled directly so it works regardless of Host header
        if scope["path"] == "/health":
            await self._respond(send, 200, b"application/json", b'{"status":"ok"}')
            return

        route = self._find_route(scope)
        if not route:
            # The Host header is client-controlled: escape it before
            # inserting it into HTML to prevent reflected XSS.
            body = _load_template("404.html", SERVICE=html.escape(self._host(scope)))
            await self._respond(send, 404, b"text/html; charset=utf-8", body)
            return

        # If any container is not yet running, show the wait page and start
        # the route in the background (at most one start-up per route). The
        # browser retries via the Refresh header.
        for name in route.containers:
            if not await self.docker.is_running(name):
                self._start_in_background(route)
                body = _load_template("wait.html")
                await self._respond(send, 503, b"text/html; charset=utf-8", body,
                                    [(b"refresh", b"3")])
                return

        # All containers running: reset idle timers and proxy the request.
        for name in route.containers:
            self.docker.reset_idle(name, route.timeout_seconds)
        await self._forward_http(scope, receive, send, route)

    async def _forward_http(self, scope, receive, send, route: ProxyRoute) -> None:
        """Relay one buffered HTTP request to the backend and send back its response."""
        url = self._backend_url(scope, route, "http")
        body = await self._read_body(receive)
        if body is None:
            return  # client went away; nobody is left to answer
        req_headers = [
            (k, v) for k, v in scope["headers"]
            if k.lower().decode("latin-1") not in HOP_BY_HOP
        ]

        # Only the upstream call is guarded: once we start responding, a
        # failure must not trigger a second http.response.start.
        try:
            async with httpx.AsyncClient() as client:
                resp = await client.request(
                    method=scope["method"],
                    url=url,
                    headers=req_headers,
                    content=body,
                    timeout=30.0,
                )
        except Exception as e:
            print(f"[http] Upstream error for {url}: {e}")
            await self._respond(send, 502, b"text/plain", b"Bad Gateway")
            return

        # Strip hop-by-hop headers + content-encoding/length (httpx
        # decompresses automatically, so the original values are wrong).
        skip = HOP_BY_HOP | {"content-encoding", "content-length"}
        resp_headers = [
            (k.lower(), v) for k, v in resp.headers.raw
            if k.lower().decode("latin-1") not in skip
        ]
        # HEAD responses carry no body, and 204/304 must not advertise a
        # zero length. Only real bodies get a Content-Length.
        if scope["method"] != "HEAD" and resp.status_code not in (204, 304):
            resp_headers.append((b"content-length", str(len(resp.content)).encode()))
        await send({"type": "http.response.start",
                    "status": resp.status_code, "headers": resp_headers})
        await send({"type": "http.response.body", "body": resp.content})

    # ------------------------------------------------------------------
    # WebSocket proxy
    # ------------------------------------------------------------------

    async def _websocket(self, scope, receive, send):
        """Proxy a WebSocket session to the route's backend."""
        route = self._find_route(scope)
        if not route:
            # Closing before accept rejects the handshake.
            await send({"type": "websocket.close", "code": 4004})
            return
        try:
            await self._ensure_containers(route)
        except Exception as e:
            print(f"[ws] Container error: {e}")
            await send({"type": "websocket.close", "code": 1011})
            return

        url = self._backend_url(scope, route, "ws")
        close_code = 1000
        try:
            # Connect to the backend *before* accepting the client, so that:
            # - an unreachable backend rejects the handshake instead of
            #   accepting and then closing "normally";
            # - the backend's chosen subprotocol can be echoed to the client.
            async with websockets.connect(
                url, subprotocols=scope.get("subprotocols") or None
            ) as backend_ws:
                accept = {"type": "websocket.accept"}
                if backend_ws.subprotocol:
                    accept["subprotocol"] = backend_ws.subprotocol
                await send(accept)
                # NOTE(review): idle timers are only reset at connect time; a
                # long-lived session may outlive route.timeout_seconds. Confirm
                # against DockerManager's idle semantics.
                await self._relay_websocket(receive, send, backend_ws)
        except Exception as e:
            print(f"[ws] Connection error to {url}: {e}")
            close_code = 1011
        finally:
            try:
                await send({"type": "websocket.close", "code": close_code})
            except Exception:
                pass  # client already disconnected / connection already closed

    @staticmethod
    async def _relay_websocket(receive, send, backend_ws) -> None:
        """Copy messages in both directions until either side stops.

        Both pump tasks are cancelled and awaited on exit, including when
        this coroutine itself is cancelled, so none are leaked.
        """

        async def client_to_backend():
            try:
                while True:
                    msg = await receive()
                    # Other message types (e.g. the initial websocket.connect)
                    # are ignored.
                    if msg["type"] == "websocket.receive":
                        if msg.get("text") is not None:
                            await backend_ws.send(msg["text"])
                        elif msg.get("bytes") is not None:
                            await backend_ws.send(msg["bytes"])
                    elif msg["type"] == "websocket.disconnect":
                        break
            except Exception as e:
                print(f"[ws] client→backend error: {e}")

        async def backend_to_client():
            try:
                async for msg in backend_ws:
                    if isinstance(msg, str):
                        await send({"type": "websocket.send", "text": msg})
                    else:
                        await send({"type": "websocket.send", "bytes": msg})
            except Exception as e:
                print(f"[ws] backend→client error: {e}")

        tasks = [
            asyncio.create_task(client_to_backend()),
            asyncio.create_task(backend_to_client()),
        ]
        try:
            await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        finally:
            for task in tasks:
                task.cancel()  # no-op for the task that already finished
            await asyncio.gather(*tasks, return_exceptions=True)