Files
PY_PROXY/middleware.py
Jonatan Rek db0853886f progress
2026-03-02 14:51:15 +01:00

216 lines
8.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations

import asyncio
import html
from pathlib import Path

import httpx
import websockets

from config import ProxyRoute
from docker_manager import DockerManager
# Hop-by-hop headers (RFC 2616 §13.5.1), stored as lowercase bytes so ASGI
# header names can be compared directly without a .decode() per header.
# RFC 2616's list spells one entry "Trailers", but the header that actually
# exists on the wire is "Trailer" (a known erratum), so both spellings are
# stripped.
_HOP_BY_HOP = frozenset({
    b"connection",
    b"keep-alive",
    b"proxy-authenticate",
    b"proxy-authorization",
    b"te",
    b"trailer",
    b"trailers",
    b"transfer-encoding",
    b"upgrade",
})

# Directory holding the HTML templates, located next to this module.
_TEMPLATES_DIR = Path(__file__).parent / "templates"

# filename -> raw template text, filled on the first disk read of each file
_template_cache: dict[str, str] = {}
def _load_template(name: str, **replacements: str) -> bytes:
    """Render the template file *name* and return it as UTF-8 bytes.

    The raw text is read from the templates directory the first time a file
    is requested and served from the in-memory cache afterwards.  Every
    ``{{KEY}}`` marker is then replaced by the matching keyword argument.

    Values are inserted verbatim: nothing is escaped here, so callers must
    escape untrusted text themselves before passing it in.
    """
    try:
        rendered = _template_cache[name]
    except KeyError:
        # First request for this file: read it once and remember the text.
        rendered = (_TEMPLATES_DIR / name).read_text(encoding="utf-8")
        _template_cache[name] = rendered

    for placeholder, substitute in replacements.items():
        rendered = rendered.replace("{{" + placeholder + "}}", substitute)
    return rendered.encode()
class ProxyMiddleware:
    """ASGI middleware that reverse-proxies HTTP and WebSocket traffic.

    The request's Host header (without port) selects a ``ProxyRoute`` from
    ``routes``.  The route's containers are started on demand through the
    ``DockerManager`` and their idle timers are refreshed on every proxied
    request.  Scopes that are neither ``http`` nor ``websocket`` (for
    example ``lifespan``) are passed to the wrapped ``app`` unchanged.
    """

    def __init__(self, app, routes: dict[str, ProxyRoute], docker: DockerManager,
                 http_client: httpx.AsyncClient):
        """Store the wrapped app, the host -> route table and the shared clients."""
        self.app = app
        self.routes = routes
        self.docker = docker
        self._client = http_client
        # In-flight container start-up tasks keyed by id(route).  Holding the
        # reference keeps the task from being garbage-collected mid-run and
        # lets repeated wait-page refreshes share a single start-up.
        self._startup_tasks: dict[int, asyncio.Task] = {}

    async def __call__(self, scope, receive, send):
        """ASGI entry point: dispatch on the scope type."""
        if scope["type"] == "http":
            await self._http(scope, receive, send)
        elif scope["type"] == "websocket":
            await self._websocket(scope, receive, send)
        else:
            await self.app(scope, receive, send)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _request_host(scope) -> str:
        """Return the Host header without its port ("" when absent).

        Bracketed IPv6 literals such as ``[::1]:8080`` are returned with
        their brackets (``[::1]``) instead of being cut at the first colon.
        """
        raw = dict(scope.get("headers", [])).get(b"host", b"")
        # latin-1 maps every byte value, so a malformed Host header cannot
        # raise UnicodeDecodeError; it simply fails to match any route.
        host = raw.decode("latin-1")
        if host.startswith("["):
            end = host.find("]")
            return host[: end + 1] if end != -1 else host
        return host.split(":", 1)[0]

    @staticmethod
    def _upstream_url(scheme: str, route: ProxyRoute, scope) -> str:
        """Build the upstream URL for *scope* on *route*.

        The still-encoded ``raw_path`` is preferred when the server provides
        it, so percent-encoded reserved characters (``%2F``, ``%3F``,
        ``%25`` ...) reach the backend unchanged.  Without it the decoded
        ``path`` is used.
        """
        raw_path = scope.get("raw_path")
        if raw_path:
            # Some servers have included the query string in raw_path.
            path = raw_path.split(b"?", 1)[0].decode("latin-1")
        else:
            path = scope["path"]
        url = f"{scheme}://{route.target_host}:{route.target_port}{path}"
        query = scope.get("query_string", b"").decode("latin-1")
        if query:
            url += f"?{query}"
        return url

    @staticmethod
    async def _respond(send, status: int, headers, body: bytes) -> None:
        """Send a complete, single-chunk HTTP response."""
        await send({"type": "http.response.start", "status": status,
                    "headers": headers})
        await send({"type": "http.response.body", "body": body})

    def _find_route(self, scope) -> ProxyRoute | None:
        """Return the route configured for the request's host, or None."""
        return self.routes.get(self._request_host(scope))

    async def _ensure_containers(self, route: ProxyRoute):
        """Start every container of *route* and refresh its idle timer."""
        for name in route.containers:
            await self.docker.ensure_running(name, route.load_seconds)
            self.docker.reset_idle(name, route.timeout_seconds)

    def _schedule_startup(self, route: ProxyRoute) -> asyncio.Task:
        """Start *route*'s containers in the background, at most once at a time.

        Returns the already-running start-up task when one exists; otherwise
        creates one, keeps a reference to it and logs its failure, if any.
        """
        key = id(route)
        running = self._startup_tasks.get(key)
        if running is not None and not running.done():
            return running

        task = asyncio.create_task(self._ensure_containers(route))
        self._startup_tasks[key] = task

        def _finished(done: asyncio.Task) -> None:
            if self._startup_tasks.get(key) is done:
                del self._startup_tasks[key]
            if done.cancelled():
                return
            # Retrieve the exception so it is reported here rather than as
            # an "exception was never retrieved" warning.
            error = done.exception()
            if error is not None:
                print(f"[http] Container start-up error: {error}")

        task.add_done_callback(_finished)
        return task

    # ------------------------------------------------------------------
    # HTTP proxy
    # ------------------------------------------------------------------
    async def _http(self, scope, receive, send):
        """Answer /health, 404 unknown hosts, show a wait page, or proxy."""
        if scope["path"] == "/health":
            await self._respond(send, 200,
                                [(b"content-type", b"application/json")],
                                b'{"status":"ok"}')
            return

        route = self._find_route(scope)
        if route is None:
            # The host comes straight from the client: escape it before it
            # is embedded in the HTML page.
            page = _load_template(
                "404.html", SERVICE=html.escape(self._request_host(scope)))
            await self._respond(send, 404, [
                (b"content-type", b"text/html; charset=utf-8"),
                (b"content-length", str(len(page)).encode()),
            ], page)
            return

        # If any container is not yet running, show the wait page and start
        # the containers in the background.  The browser will auto-refresh
        # via the Refresh header.
        for name in route.containers:
            if not await self.docker.is_running(name):
                self._schedule_startup(route)
                page = _load_template("wait.html")
                await self._respond(send, 503, [
                    (b"content-type", b"text/html; charset=utf-8"),
                    (b"refresh", b"3"),
                    (b"content-length", str(len(page)).encode()),
                ], page)
                return

        # All containers are running: reset idle timers and proxy the request.
        for name in route.containers:
            self.docker.reset_idle(name, route.timeout_seconds)

        url = self._upstream_url("http", route, scope)

        chunks = []
        while True:
            msg = await receive()
            if msg.get("type") == "http.disconnect":
                # The client went away mid-upload: do not forward a
                # truncated body, and there is nobody left to answer.
                return
            chunks.append(msg.get("body", b""))
            if not msg.get("more_body", False):
                break
        body = b"".join(chunks)

        # Header names are compared as bytes; no .decode() needed per header.
        req_headers = [
            (k, v) for k, v in scope["headers"]
            if k.lower() not in _HOP_BY_HOP
        ]

        # Only the upstream call is guarded: once a response has started,
        # a second "http.response.start" would violate the ASGI protocol.
        try:
            resp = await self._client.request(
                method=scope["method"],
                url=url,
                headers=req_headers,
                content=body,
            )
        except Exception as e:
            print(f"[http] Upstream error for {url}: {e}")
            await self._respond(send, 502,
                                [(b"content-type", b"text/plain")],
                                b"Bad Gateway")
            return

        # Strip hop-by-hop + content-encoding/length (httpx decompresses
        # automatically so the original values would be wrong).
        skip = _HOP_BY_HOP | {b"content-encoding", b"content-length"}
        payload = resp.content
        resp_headers = [
            (k.lower(), v) for k, v in resp.headers.raw
            if k.lower() not in skip
        ]
        resp_headers.append((b"content-length", str(len(payload)).encode()))
        await self._respond(send, resp.status_code, resp_headers, payload)

    # ------------------------------------------------------------------
    # WebSocket proxy
    # ------------------------------------------------------------------
    async def _websocket(self, scope, receive, send):
        """Bridge the client WebSocket to the route's backend WebSocket.

        Closes with 4004 for an unknown host and with 1011 when the
        containers cannot be started or the backend connection fails.
        """
        route = self._find_route(scope)
        if route is None:
            await send({"type": "websocket.close", "code": 4004})
            return

        # NOTE(review): idle timers are only refreshed here, when the socket
        # opens - confirm a long-lived socket cannot outlive the idle timeout.
        try:
            await self._ensure_containers(route)
        except Exception as e:
            print(f"[ws] Container error: {e}")
            await send({"type": "websocket.close", "code": 1011})
            return

        url = self._upstream_url("ws", route, scope)
        await send({"type": "websocket.accept"})

        async def client_to_backend(backend_ws):
            """Forward client frames upstream until the client disconnects."""
            try:
                while True:
                    msg = await receive()
                    if msg["type"] == "websocket.receive":
                        if msg.get("text") is not None:
                            await backend_ws.send(msg["text"])
                        elif msg.get("bytes") is not None:
                            await backend_ws.send(msg["bytes"])
                    elif msg["type"] == "websocket.disconnect":
                        break
            except Exception as e:
                print(f"[ws] client→backend error: {e}")

        async def backend_to_client(backend_ws):
            """Forward backend frames to the client until the backend closes."""
            try:
                async for msg in backend_ws:
                    if isinstance(msg, str):
                        await send({"type": "websocket.send", "text": msg})
                    else:
                        await send({"type": "websocket.send", "bytes": msg})
            except Exception as e:
                print(f"[ws] backend→client error: {e}")

        close_code = 1000
        try:
            async with websockets.connect(url) as backend_ws:
                pumps = [
                    asyncio.create_task(client_to_backend(backend_ws)),
                    asyncio.create_task(backend_to_client(backend_ws)),
                ]
                try:
                    # Either direction finishing ends the whole bridge.
                    await asyncio.wait(
                        pumps, return_when=asyncio.FIRST_COMPLETED)
                finally:
                    # Also runs when this handler itself is cancelled, so
                    # neither pump is left running in the background.
                    for task in pumps:
                        if not task.done():
                            task.cancel()
                    await asyncio.gather(*pumps, return_exceptions=True)
        except Exception as e:
            print(f"[ws] Connection error to {url}: {e}")
            close_code = 1011
        finally:
            try:
                await send({"type": "websocket.close", "code": close_code})
            except Exception:
                # Best effort: the client may already have gone away, in
                # which case the server rejects the close message.
                pass