From d1f965effe28f393b1b7fca1079eaf5fe3d97bec Mon Sep 17 00:00:00 2001 From: JonatanRek Date: Wed, 16 Apr 2025 18:55:11 +0200 Subject: [PATCH] Possible fixes needs testing --- .gitignore | 1 + app.py | 370 ++++++++++++++++------------------------------- requirements.txt | 5 +- tests.py | 13 -- 4 files changed, 126 insertions(+), 263 deletions(-) create mode 100644 .gitignore delete mode 100644 tests.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..96403d3 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +__pycache__/* diff --git a/app.py b/app.py index a1bcabc..1c83e57 100644 --- a/app.py +++ b/app.py @@ -1,276 +1,148 @@ -import http.server -import http.client +# main.py +import asyncio +from fastapi import FastAPI, Request, Response, status +from fastapi.responses import JSONResponse +import httpx import yaml import docker -import threading -import time -import os -import asyncio +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.types import ASGIApp +from threading import Thread +from time import time, sleep import websockets -import hashlib -import base64 -from datetime import datetime, timezone -from socketserver import ThreadingMixIn -from websockets.server import WebSocketServerProtocol +from starlette.types import Receive, Scope, Send +from starlette.websockets import WebSocket +from starlette.requests import HTTPConnection -# Define the target server to proxy requests to -class ProxyHandler(http.server.BaseHTTPRequestHandler): - def __init__(self, configuration, docker_client): - global activity - self.configuration = configuration - self.docker_client = docker_client - def __call__(self, *args, **kwargs): - """Handle a request.""" - super().__init__(*args, **kwargs) +# --- Configuration Loading --- +class Route: + def __init__(self, path_prefix: str, target: str, container: str | None = None): + self.path_prefix = path_prefix + self.target = target + self.container = container - def log_message(self, format, *args): - pass +def load_config(path: str): + with open(path, 'r') as f: + data = yaml.safe_load(f) + return [Route(**r) for r in data.get('routes', [])] - def finish(self,*args,**kw): - try: - if not self.wfile.closed: - self.wfile.flush() - self.wfile.close() - except socket.error: - pass - self.rfile.close() - def do_GET(self): - self.handle_request('GET') +# --- Docker Management --- +idle_containers = {} +idle_timeout = {} +idle_check_interval = 5 - def do_POST(self): - self.handle_request('POST') - - def do_PUT(self): - self.handle_request('PUT') - - def do_DELETE(self): - self.handle_request('DELETE') - - def do_HEAD(self): - self.handle_request('HEAD') - - def handle_request(self, method): - #print(self.headers.get('Host').split(":")[0]) - parsed_request_host = self.headers.get('Host') - - if (':' in parsed_request_host): - parsed_request_host = parsed_request_host.split(":")[0] - - proxy_host_configuration = next(filter(lambda host: host['domain'] == parsed_request_host, self.configuration['proxy_hosts'])) - - starting = False - for container in proxy_host_configuration['containers']: - container_objects = self.docker_client.containers.list(all=True, filters = { 'name' : container['container_name'] }) - if (container_objects == []): - self.send_404(proxy_host_configuration['domain']) - return - - container_object = container_objects[0] - if (container_object.status != 'running'): - print("starting container: {0}".format(container['container_name'])) - container_object.start() - starting = True - - if (starting == True): - self.send_loading(proxy_host_configuration['proxy_load_seconds'], proxy_host_configuration['domain']) - return - - activity[proxy_host_configuration['domain']] = datetime.now(timezone.utc) - - # Check if this is a WebSocket request - if self.headers.get("Upgrade", "").lower() == "websocket": - print("Request is WS connecting to {0}".format(container['container_name'])) - print("Request is WS connecting to {0}:{1}".format(proxy_host_configuration['proxy_host'],proxy_host_configuration['proxy_port'])) - activity[proxy_host_configuration['domain']] = True - self.upgrade_to_websocket(proxy_host_configuration['proxy_host'], proxy_host_configuration['proxy_port']) - return - - # Open a connection to the target server - conn = http.client.HTTPConnection(proxy_host_configuration['proxy_host'], proxy_host_configuration['proxy_port']) - conn.request(method, self.path, headers=self.headers) - response = conn.getresponse() - - self.send_response(response.status) - - self.send_header('host', proxy_host_configuration['proxy_host']) - for header, value in response.getheaders(): - self.send_header(header, value) - - self.end_headers() - self.wfile.write(response.read()) - - conn.close() - - def send_404(self, service_name): - self.send_response(404) - self.send_header('Content-Type', 'text/html; charset=utf-8') - self.send_header('Cache-Control', 'no-cache, no-store, must-revalidate') - self.send_header('Pragma', 'no-cache') - self.send_header('Expires', '0') - self.end_headers() - - with open(os.path.dirname(os.path.realpath(__file__)) + '/templates/404.html', 'r') as file: - html = file.read() - html = html.replace("{{SERVICE}}", service_name) - self.wfile.write(bytes(html,"utf-8")) - - self.wfile.flush() - - def send_loading(self, wait_time, service_name): - self.send_response(201) - self.send_header('Content-Type', 'text/html; charset=utf-8') - self.send_header('Cache-Control', 'no-cache, no-store, must-revalidate') - self.send_header('Pragma', 'no-cache') - self.send_header('Expires', '0') - self.send_header('refresh', wait_time) - self.end_headers() - - with open(os.path.dirname(os.path.realpath(__file__)) + '/templates/wait.html', 'r') as file: - html = file.read() - self.wfile.write(bytes(html,"utf-8")) - - #self.wfile.write(bytes("starting service: {0} waiting for {1}s".format(self.headers.get('Host').split(":")[0], proxy_host_configuration['proxy_timeout_seconds']),"utf-8")) - #self.wfile.write(bytes("\nlast started at: {0} ".format(activity[proxy_host_configuration['domain']]),"utf-8")) - - self.wfile.flush() - - async def websocket_proxy(self, target_host, target_port): - server_ws = None - try: - client_connection = self.connection - - # Establish server connection to backend - server_ws = await websockets.connect(f"ws://{target_host}:{target_port}") - print("connected") - # Bridge function to handle message forwarding - async def bridge_websockets(): +def idle_watcher(): + client = docker.from_env() + while True: + now = time() + for container, until in list(idle_containers.items()): + if until and now > until: try: - while True: - # Wait for a message from the client - client_message = await client_connection.recv() - print(f">: {client_message}") - # Send it to the server - await server_ws.send(client_message) - # Wait for a message from the server - server_message = await server_ws.recv() - # Send it to the client - await client_connection.send(server_message) - - except websockets.exceptions.ConnectionClosed as e: - print(f"WebSocket connection closed: {e}") + c = client.containers.get(container) + c.stop() + print(f"[watcher] Stopped idle container: {container}") except Exception as e: - print(f"Error during WebSocket communication: {e}") + print(f"[watcher] Failed to stop {container}: {e}") + finally: + idle_containers.pop(container, None) + sleep(idle_check_interval) - # Run the bridge coroutine - await bridge_websockets() - - except Exception as e: - print(f"WebSocket proxy encountered an error: {e}") - finally: - if server_ws: - await server_ws.close() - - def upgrade_to_websocket(self, target_host, target_port): - """ - Handles WebSocket upgrade requests and spawns an asyncio WebSocket proxy. - """ - key = self.headers['Sec-WebSocket-Key'] - accept_val = base64.b64encode(hashlib.sha1((key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode('utf-8')).digest()).decode('utf-8') +Thread(target=idle_watcher, daemon=True).start() - self.send_response(101) # Switching Protocols - self.send_header("Upgrade", "websocket") - self.send_header("Connection", "Upgrade") - self.send_header("Sec-WebSocket-Accept", accept_val) - self.end_headers() +# --- Proxy Middleware --- +class ReverseProxyMiddleware(BaseHTTPMiddleware): + def __init__(self, app: ASGIApp, routes: list[Route]): + super().__init__(app) + self.routes = routes + self.docker = docker.from_env() - # Upgrade the connection to a WebSocket connection - self.websocket = WebSocketServerProtocol() - self.websocket.connection_made(self.connection) - self.websocket.connection_open() + async def dispatch(self, request: Request, call_next): + path = request.url.path + route = next((r for r in self.routes if path.startswith(r.path_prefix)), None) - loop = asyncio.new_event_loop() - threading.Thread(target=loop.run_until_complete, args=(self.websocket_proxy(target_host, target_port),)).start() + if not route: + return JSONResponse(status_code=status.HTTP_404_NOT_FOUND, content={"detail": "Not Found"}) -class ThreadedHTTPServer(ThreadingMixIn, http.server.HTTPServer): - """Handle requests in a separate thread.""" + if route.container: + try: + container = self.docker.containers.get(route.container) + if container.status != 'running': + container.start() + print(f"Started container {route.container}") + timeout = idle_timeout.get(route.container) + if timeout: + idle_containers[route.container] = time() + timeout + else: + idle_containers[route.container] = None + except Exception as e: + print(f"[proxy] Failed to ensure container '{route.container}': {e}") + return JSONResponse(status_code=status.HTTP_502_BAD_GATEWAY, content={"detail": "Container Error"}) -class BackgroundTasks(threading.Thread): - def __init__(self, configuration, docker_client): - super(BackgroundTasks, self).__init__() - self.configuration = configuration - self.docker_client = docker_client + # WebSocket upgrade detection + if request.headers.get("upgrade", "").lower() == "websocket": + return await self.handle_websocket_upgrade(request, route) - def run(self,*args,**kwargs): - global activity - while True: - sleep_time = 900 - for apps in self.configuration['proxy_hosts']: - if(sleep_time > apps['proxy_timeout_seconds']): - sleep_time = apps['proxy_timeout_seconds'] + new_url = route.target.rstrip('/') + path[len(route.path_prefix):] + async with httpx.AsyncClient() as client: + try: + proxied = await client.request( + method=request.method, + url=new_url, + headers=request.headers.raw, + content=await request.body(), + timeout=30.0 + ) + return Response(content=proxied.content, status_code=proxied.status_code, headers=proxied.headers) + except Exception as e: + print(f"[proxy] Failed proxying to {new_url}: {e}") + return JSONResponse(status_code=status.HTTP_502_BAD_GATEWAY, content={"detail": "Upstream Error"}) - for container in apps['containers']: + async def handle_websocket_upgrade(self, request: Request, route: Route): + scope: Scope = request.scope + receive: Receive = request.receive + send: Send = request._send # Unsafe, but needed for ASGI hijack + ws = WebSocket(scope, receive=receive, send=send) + await ws.accept() + + target_ws_url = route.target.rstrip('/') + request.url.path[len(route.path_prefix):] + if target_ws_url.startswith("http"): + target_ws_url = target_ws_url.replace("http", "ws", 1) + + try: + async with websockets.connect(target_ws_url) as backend: + async def to_backend(): try: - container_object = self.docker_client.containers.get(container['container_name']) - if (container_object.status == 'running'): + while True: + data = await ws.receive_text() + await backend.send(data) + except Exception: + await backend.close() - dt = datetime.now(timezone.utc) - if (apps['domain'] in activity): - dt = activity[apps['domain']] + async def from_backend(): + try: + while True: + data = await backend.recv() + await ws.send_text(data) + except Exception: + await ws.close() - if (dt == True): - continue + await asyncio.gather(to_backend(), from_backend()) + except Exception as e: + print(f"[ws] Proxy error: {e}") + await ws.close(code=1011) + return Response(status_code=502, content=b"WebSocket proxy error") - diff_seconds = (datetime.now(timezone.utc) - dt).total_seconds() - if(diff_seconds > apps['proxy_timeout_seconds']): - print("stopping container: {0} ({1}) after {2}s".format(container['container_name'], container_object.id, diff_seconds)) - container_object.stop() - except docker.errors.NotFound: - pass - - time.sleep(sleep_time) -async def websocket_proxy(client_ws, target_host, target_port): - """ - Forwards WebSocket messages between the client and the target container. - """ - try: - async with websockets.connect(f"ws://{target_host}:{target_port}") as server_ws: - # Create tasks to read from both directions - async def forward_client_to_server(): - async for message in client_ws: - await server_ws.send(message) +# --- App Setup --- +app = FastAPI() +config = load_config("config.yml") +app.add_middleware(ReverseProxyMiddleware, routes=config) - async def forward_server_to_client(): - async for message in server_ws: - await client_ws.send(message) - - await asyncio.gather(forward_client_to_server(), forward_server_to_client()) - - except Exception as e: - print(f"WebSocket error: {e}") - finally: - await client_ws.close() - -# MAIN # - -if __name__ == '__main__': - activity = {} - - with open('config.yml', 'r') as file: - configuration = yaml.safe_load(file) - - docker_client = docker.from_env() - - t = BackgroundTasks(configuration, docker_client) - t.start() - - # Start the reverse proxy server on port 8888 - server_address = ('', configuration['proxy_port']) - proxy_handler = ProxyHandler(configuration, docker_client) - httpd = ThreadedHTTPServer(server_address, proxy_handler) - print('Reverse proxy server running on port {0}...'.format(configuration['proxy_port'])) - httpd.serve_forever() +# Optional: Health check +@app.get("/health") +async def health(): + return {"status": "ok"} diff --git a/requirements.txt b/requirements.txt index d89b982..9d3fee6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,6 @@ docker PyYAML -websockets \ No newline at end of file +websockets +fastapi +uvicorn[standard] +httpx \ No newline at end of file diff --git a/tests.py b/tests.py deleted file mode 100644 index 0b04e8a..0000000 --- a/tests.py +++ /dev/null @@ -1,13 +0,0 @@ -import asyncio -import time -import websockets - -async def main(): - async with websockets.connect('ws://localhost:8010') as ws: - while True: - await ws.send("testsage") - server_message = ws.recv() - print(server_message) - - -asyncio.get_event_loop().run_until_complete(main()) \ No newline at end of file