progress
This commit is contained in:
@@ -9,27 +9,32 @@ import websockets
|
||||
from config import ProxyRoute
|
||||
from docker_manager import DockerManager
|
||||
|
||||
# Headers that must not be forwarded between proxy and backend
|
||||
HOP_BY_HOP = {
|
||||
"connection", "keep-alive", "proxy-authenticate", "proxy-authorization",
|
||||
"te", "trailers", "transfer-encoding", "upgrade",
|
||||
# Hop-by-hop headers stored as bytes – avoids .decode() on every header (RFC 2616 §13.5.1)
|
||||
_HOP_BY_HOP = {
|
||||
b"connection", b"keep-alive", b"proxy-authenticate", b"proxy-authorization",
|
||||
b"te", b"trailers", b"transfer-encoding", b"upgrade",
|
||||
}
|
||||
|
||||
TEMPLATES = Path(__file__).parent / "templates"
|
||||
_TEMPLATES_DIR = Path(__file__).parent / "templates"
|
||||
_template_cache: dict[str, str] = {} # filename -> raw text, cached on first disk read
|
||||
|
||||
|
||||
def _load_template(name: str, **replacements: str) -> bytes:
|
||||
text = (TEMPLATES / name).read_text(encoding="utf-8")
|
||||
if name not in _template_cache:
|
||||
_template_cache[name] = (_TEMPLATES_DIR / name).read_text(encoding="utf-8")
|
||||
text = _template_cache[name]
|
||||
for key, value in replacements.items():
|
||||
text = text.replace("{{" + key + "}}", value)
|
||||
return text.encode()
|
||||
|
||||
|
||||
class ProxyMiddleware:
|
||||
def __init__(self, app, routes: dict[str, ProxyRoute], docker: DockerManager):
|
||||
def __init__(self, app, routes: dict[str, ProxyRoute], docker: DockerManager,
|
||||
http_client: httpx.AsyncClient):
|
||||
self.app = app
|
||||
self.routes = routes
|
||||
self.docker = docker
|
||||
self._client = http_client
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
if scope["type"] == "http":
|
||||
@@ -37,7 +42,6 @@ class ProxyMiddleware:
|
||||
elif scope["type"] == "websocket":
|
||||
await self._websocket(scope, receive, send)
|
||||
else:
|
||||
# Pass lifespan and other event types to the inner app
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
@@ -59,7 +63,6 @@ class ProxyMiddleware:
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _http(self, scope, receive, send):
|
||||
# Health check handled directly so it works regardless of Host header
|
||||
if scope["path"] == "/health":
|
||||
await send({"type": "http.response.start", "status": 200,
|
||||
"headers": [(b"content-type", b"application/json")]})
|
||||
@@ -72,7 +75,8 @@ class ProxyMiddleware:
|
||||
host = headers.get(b"host", b"").decode().split(":")[0]
|
||||
body = _load_template("404.html", SERVICE=host)
|
||||
await send({"type": "http.response.start", "status": 404,
|
||||
"headers": [(b"content-type", b"text/html; charset=utf-8")]})
|
||||
"headers": [(b"content-type", b"text/html; charset=utf-8"),
|
||||
(b"content-length", str(len(body)).encode())]})
|
||||
await send({"type": "http.response.body", "body": body})
|
||||
return
|
||||
|
||||
@@ -101,9 +105,10 @@ class ProxyMiddleware:
|
||||
if query:
|
||||
url += f"?{query}"
|
||||
|
||||
# bytes comparison – no .decode() needed per header
|
||||
req_headers = [
|
||||
(k, v) for k, v in scope["headers"]
|
||||
if k.lower().decode() not in HOP_BY_HOP
|
||||
if k.lower() not in _HOP_BY_HOP
|
||||
]
|
||||
|
||||
body = b""
|
||||
@@ -113,33 +118,31 @@ class ProxyMiddleware:
|
||||
if not msg.get("more_body", False):
|
||||
break
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
resp = await client.request(
|
||||
method=scope["method"],
|
||||
url=url,
|
||||
headers=req_headers,
|
||||
content=body,
|
||||
timeout=30.0,
|
||||
)
|
||||
# Strip hop-by-hop headers + content-encoding/length (httpx
|
||||
# decompresses automatically, so the original values are wrong).
|
||||
skip = HOP_BY_HOP | {"content-encoding", "content-length"}
|
||||
resp_headers = [
|
||||
(k.lower(), v) for k, v in resp.headers.raw
|
||||
if k.lower().decode() not in skip
|
||||
]
|
||||
resp_headers.append(
|
||||
(b"content-length", str(len(resp.content)).encode())
|
||||
)
|
||||
await send({"type": "http.response.start",
|
||||
"status": resp.status_code, "headers": resp_headers})
|
||||
await send({"type": "http.response.body", "body": resp.content})
|
||||
except Exception as e:
|
||||
print(f"[http] Upstream error for {url}: {e}")
|
||||
await send({"type": "http.response.start", "status": 502,
|
||||
"headers": [(b"content-type", b"text/plain")]})
|
||||
await send({"type": "http.response.body", "body": b"Bad Gateway"})
|
||||
try:
|
||||
resp = await self._client.request(
|
||||
method=scope["method"],
|
||||
url=url,
|
||||
headers=req_headers,
|
||||
content=body,
|
||||
)
|
||||
# Strip hop-by-hop + content-encoding/length (httpx decompresses
|
||||
# automatically so the original values would be wrong).
|
||||
skip = _HOP_BY_HOP | {b"content-encoding", b"content-length"}
|
||||
resp_headers = [
|
||||
(k.lower(), v) for k, v in resp.headers.raw
|
||||
if k.lower() not in skip
|
||||
]
|
||||
resp_headers.append(
|
||||
(b"content-length", str(len(resp.content)).encode())
|
||||
)
|
||||
await send({"type": "http.response.start",
|
||||
"status": resp.status_code, "headers": resp_headers})
|
||||
await send({"type": "http.response.body", "body": resp.content})
|
||||
except Exception as e:
|
||||
print(f"[http] Upstream error for {url}: {e}")
|
||||
await send({"type": "http.response.start", "status": 502,
|
||||
"headers": [(b"content-type", b"text/plain")]})
|
||||
await send({"type": "http.response.body", "body": b"Bad Gateway"})
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# WebSocket proxy
|
||||
@@ -148,7 +151,6 @@ class ProxyMiddleware:
|
||||
async def _websocket(self, scope, receive, send):
|
||||
route = self._find_route(scope)
|
||||
if not route:
|
||||
# Reject before accept
|
||||
await send({"type": "websocket.close", "code": 4004})
|
||||
return
|
||||
|
||||
@@ -165,7 +167,6 @@ class ProxyMiddleware:
|
||||
if query:
|
||||
url += f"?{query}"
|
||||
|
||||
# Accept the client WebSocket connection
|
||||
await send({"type": "websocket.accept"})
|
||||
|
||||
async def client_to_backend(backend_ws):
|
||||
@@ -194,8 +195,8 @@ class ProxyMiddleware:
|
||||
|
||||
try:
|
||||
async with websockets.connect(url) as backend_ws:
|
||||
t1 = asyncio.ensure_future(client_to_backend(backend_ws))
|
||||
t2 = asyncio.ensure_future(backend_to_client(backend_ws))
|
||||
t1 = asyncio.create_task(client_to_backend(backend_ws))
|
||||
t2 = asyncio.create_task(backend_to_client(backend_ws))
|
||||
_done, pending = await asyncio.wait(
|
||||
[t1, t2], return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user