# PY_PROXY/app.py
# Snapshot: 2025-04-16 18:55:11 +02:00 (149 lines, 5.4 KiB, Python)
# main.py
import asyncio
from fastapi import FastAPI, Request, Response, status
from fastapi.responses import JSONResponse
import httpx
import yaml
import docker
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
from threading import Thread
from time import time, sleep
import websockets
from starlette.types import Receive, Scope, Send
from starlette.websockets import WebSocket
from starlette.requests import HTTPConnection
# --- Configuration Loading ---
class Route:
def __init__(self, path_prefix: str, target: str, container: str | None = None):
self.path_prefix = path_prefix
self.target = target
self.container = container
def load_config(path: str):
    """Load the proxy routes from a YAML config file.

    The file should contain a top-level mapping with a ``routes`` list. Each
    entry is a mapping of ``Route`` keyword arguments (``path_prefix``,
    ``target`` and optionally ``container``). List order is preserved.

    Args:
        path: Filesystem path of the YAML file.

    Returns:
        A list of ``Route`` objects. The list is empty when the file is empty
        or when ``routes`` is missing or null.

    Raises:
        ValueError: If the document's top level is not a mapping.
        TypeError: If a route entry has unknown or missing fields (raised by ``Route``).
    """
    with open(path, encoding="utf-8") as f:
        # safe_load (never yaml.load) because the config is parsed as plain data.
        # An empty file yields None, so fall back to an empty mapping.
        data = yaml.safe_load(f) or {}
    if not isinstance(data, dict):
        raise ValueError(f"{path}: expected a mapping at the top level, got {type(data).__name__}")
    # `routes:` with no value parses as None; treat it like a missing key.
    return [Route(**r) for r in data.get("routes") or []]
# --- Docker Management ---
# Shared between the request path (which writes deadlines) and the background
# watcher thread (which reads them and removes entries).
# container name -> time() timestamp after which the watcher stops it;
# None means the watcher never stops that container.
idle_containers: dict[str, float | None] = {}
# container name -> idle timeout in seconds, added to time() to form a deadline.
# NOTE(review): nothing visible in this file writes to this dict, so every
# deadline ends up None and no container is stopped for idleness; confirm
# whether it is populated elsewhere.
idle_timeout: dict[str, float] = {}
# Seconds the watcher sleeps between passes over idle_containers.
idle_check_interval = 5
def idle_watcher():
    """Stop containers whose idle deadline has passed. Loops forever.

    Every ``idle_check_interval`` seconds, each entry of ``idle_containers``
    with a truthy deadline earlier than now gets a stop request. The entry is
    then removed whether or not the stop succeeded. This function is meant to
    run in a daemon thread.
    """
    docker_client = docker.from_env()
    while True:
        current = time()
        # Snapshot the items first: request handlers mutate the dict from another thread.
        expired = [name for name, deadline in list(idle_containers.items())
                   if deadline and current > deadline]
        for name in expired:
            try:
                docker_client.containers.get(name).stop()
                print(f"[watcher] Stopped idle container: {name}")
            except Exception as exc:
                print(f"[watcher] Failed to stop {name}: {exc}")
            finally:
                idle_containers.pop(name, None)
        sleep(idle_check_interval)
# Launch the idle watcher when the module is imported. As a daemon thread it
# does not keep the interpreter alive at exit.
Thread(daemon=True, target=idle_watcher).start()
# --- Proxy Middleware ---
class ReverseProxyMiddleware(BaseHTTPMiddleware):
    """Reverse-proxy requests to upstreams chosen by path prefix.

    Routes are tried in list order. The first route whose ``path_prefix`` is a
    string prefix of the request path wins. That prefix is stripped, and the
    rest of the path plus the query string is appended to ``route.target``.
    Paths that match no route go to the wrapped app, so its own endpoints stay
    reachable and unknown paths get the app's 404.

    When a route names a ``container``, the container is started if Docker does
    not report it running. The Docker call runs in a worker thread so the event
    loop is not blocked. The container's idle deadline in ``idle_containers`` is
    then refreshed. NOTE: the deadline is only refreshed when a request or
    WebSocket connection starts, so a long-lived WebSocket does not by itself
    keep its container alive.

    ``BaseHTTPMiddleware`` only runs ``dispatch`` for ``http`` scopes.
    WebSocket connections arrive as a separate ``websocket`` ASGI scope, so
    ``__call__`` intercepts them and relays frames to the upstream.
    """

    # Hop-by-hop headers describe a single connection and must not be forwarded.
    _HOP_BY_HOP = frozenset({
        b"connection", b"keep-alive", b"proxy-authenticate", b"proxy-authorization",
        b"te", b"trailer", b"trailers", b"transfer-encoding", b"upgrade",
    })
    # httpx derives Host from the upstream URL and Content-Length from the buffered body.
    _DROP_REQUEST = _HOP_BY_HOP | frozenset({b"host", b"content-length"})
    # httpx's `.content` is already decompressed and Starlette sets Content-Length
    # from it, so the upstream's values for these would describe the wrong body.
    _DROP_RESPONSE = _HOP_BY_HOP | frozenset({b"content-length", b"content-encoding"})

    def __init__(self, app: ASGIApp, routes: list[Route], timeout: float = 30.0):
        """
        Args:
            app: The wrapped ASGI application.
            routes: Proxy rules, tried in order.
            timeout: Upstream HTTP timeout in seconds (30 by default, as before).
        """
        super().__init__(app)
        self.routes = routes
        self.timeout = timeout
        self.docker = docker.from_env()

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Proxy matching WebSocket scopes; delegate everything else to BaseHTTPMiddleware."""
        if scope["type"] == "websocket":
            route = self._match_route(HTTPConnection(scope).url.path)
            if route is not None:
                websocket = WebSocket(scope, receive=receive, send=send)
                if await self._ensure_container(route):
                    await self.handle_websocket_upgrade(websocket, route)
                else:
                    # Per the ASGI spec, closing before accept rejects the handshake.
                    await websocket.close(code=1011)
                return
        await super().__call__(scope, receive, send)

    async def dispatch(self, request: Request, call_next):
        """Forward an HTTP request to its route's upstream and relay the response.

        Returns 502 with ``{"detail": "Container Error"}`` if the container
        cannot be ensured, and 502 with ``{"detail": "Upstream Error"}`` if the
        upstream request fails.
        """
        route = self._match_route(request.url.path)
        if route is None:
            # Let the wrapped app answer: its own endpoints, or its default 404.
            return await call_next(request)
        if not await self._ensure_container(route):
            return JSONResponse(status_code=status.HTTP_502_BAD_GATEWAY, content={"detail": "Container Error"})

        new_url = self._upstream_url(route, request.url.path, request.url.query)
        headers = [(k, v) for k, v in request.headers.raw if k.lower() not in self._DROP_REQUEST]
        async with httpx.AsyncClient() as client:
            try:
                proxied = await client.request(
                    method=request.method,
                    url=new_url,
                    headers=headers,
                    content=await request.body(),
                    timeout=self.timeout,
                )
            except Exception as e:
                print(f"[proxy] Failed proxying to {new_url}: {e}")
                return JSONResponse(status_code=status.HTTP_502_BAD_GATEWAY, content={"detail": "Upstream Error"})

        response = Response(content=proxied.content, status_code=proxied.status_code)
        # Copy the raw header list so repeated headers (e.g. several Set-Cookie)
        # stay separate. A Headers mapping would join them with commas.
        response.raw_headers.extend(
            (k.lower(), v) for k, v in proxied.headers.raw if k.lower() not in self._DROP_RESPONSE
        )
        return response

    def _match_route(self, path: str) -> Route | None:
        """Return the first route whose prefix ``path`` starts with, or None."""
        return next((r for r in self.routes if path.startswith(r.path_prefix)), None)

    @staticmethod
    def _upstream_url(route: Route, path: str, query: str) -> str:
        """Join ``route.target`` with ``path`` minus the route prefix, keeping the query string."""
        url = route.target.rstrip("/") + path[len(route.path_prefix):]
        return f"{url}?{query}" if query else url

    async def _ensure_container(self, route: Route) -> bool:
        """Start ``route.container`` if needed and refresh its idle deadline.

        Returns True when there is no container or the Docker calls succeed.
        Returns False, after logging, when a Docker call raises.
        """
        if not route.container:
            return True
        try:
            # docker-py calls are synchronous; run them off the event loop.
            await asyncio.to_thread(self._start_if_stopped, route.container)
        except Exception as e:
            print(f"[proxy] Failed to ensure container '{route.container}': {e}")
            return False
        timeout = idle_timeout.get(route.container)
        idle_containers[route.container] = time() + timeout if timeout else None
        return True

    def _start_if_stopped(self, name: str) -> None:
        """Blocking helper: start container ``name`` unless its status is 'running'."""
        container = self.docker.containers.get(name)
        if container.status != "running":
            container.start()
            print(f"Started container {name}")

    async def handle_websocket_upgrade(self, websocket: WebSocket, route: Route):
        """Relay a client WebSocket to the upstream WebSocket for ``route``.

        The upstream URL is built the same way as for HTTP, with an
        ``http``/``https`` target mapped to ``ws``/``wss``. The client is
        accepted only after the upstream connection opens. Text and binary
        frames are relayed both ways until either side finishes. On any error
        the client is closed with code 1011; otherwise it is closed with 1000.
        """
        target_ws_url = self._upstream_url(route, websocket.url.path, websocket.url.query)
        if target_ws_url.startswith("http"):
            # "http://..." -> "ws://...", "https://..." -> "wss://..."
            target_ws_url = target_ws_url.replace("http", "ws", 1)
        try:
            async with websockets.connect(target_ws_url) as backend:
                await websocket.accept()
                await self._relay(websocket, backend)
        except Exception as e:
            print(f"[ws] Proxy error: {e}")
            await self._close_quietly(websocket, 1011)
        else:
            await self._close_quietly(websocket, 1000)

    @staticmethod
    async def _relay(websocket: WebSocket, backend) -> None:
        """Pump frames both ways until one direction finishes, then stop the other.

        ``backend`` is the connection object yielded by ``websockets.connect``.
        An exception in either direction is re-raised to the caller.
        """
        async def client_to_backend():
            while True:
                message = await websocket.receive()
                if message["type"] == "websocket.disconnect":
                    return
                if message.get("text") is not None:
                    await backend.send(message["text"])
                elif message.get("bytes") is not None:
                    await backend.send(message["bytes"])

        async def backend_to_client():
            # Iteration ends when the upstream closes the connection normally.
            async for data in backend:
                if isinstance(data, bytes):
                    await websocket.send_bytes(data)
                else:
                    await websocket.send_text(data)

        tasks = [asyncio.create_task(client_to_backend()), asyncio.create_task(backend_to_client())]
        try:
            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        finally:
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
        for task in done:
            task.result()  # re-raise the first direction's error, if any

    @staticmethod
    async def _close_quietly(websocket: WebSocket, code: int) -> None:
        """Close the client socket; best effort, since it may already be closed or gone."""
        try:
            await websocket.close(code=code)
        except Exception:
            # close() can raise once the client has disconnected or the socket was already closed.
            pass
# --- App Setup ---
# Build the FastAPI app and wrap it with the proxy, using routes read once at
# import time from config.yml (relative to the working directory).
app = FastAPI()
config = load_config(path="config.yml")
app.add_middleware(ReverseProxyMiddleware, routes=config)
# Optional health check served by the FastAPI app itself.
# NOTE(review): only reachable for paths that ReverseProxyMiddleware hands on
# to the app instead of answering or proxying itself.
@app.get("/health")
async def health():
    """Return a fixed ``{"status": "ok"}`` payload."""
    return dict(status="ok")