from http.cookiejar import CookieJar, DefaultCookiePolicy

import httpx
from fastapi import Request, Depends, HTTPException
from fastapi.responses import Response
from starlette.responses import JSONResponse

from const import Hosts, HOP_BY_HOP_HEADERS, logger
from http_client_static_dns import AsyncCustomHost, NameSolver

# Configure upstream client.
# The client is shared by every proxied request, and httpx stores Set-Cookie
# responses on the client and replays them on later requests. A cookie policy
# with an empty allow-list makes the jar reject every cookie, so one user's
# cookies can never be sent on behalf of another user.
# NOTE(review): httpx ignores `verify` and `limits` when an explicit
# `transport` is passed; presumably AsyncCustomHost applies its own TLS and
# pool settings — confirm.
official_client = httpx.AsyncClient(
    verify=True,
    timeout=httpx.Timeout(30.0, connect=10.0),
    limits=httpx.Limits(max_keepalive_connections=50, max_connections=100),
    transport=AsyncCustomHost(NameSolver()),
    cookies=CookieJar(policy=DefaultCookiePolicy(allowed_domains=[])),
)

# Hostnames that always bypass the host allow-list.
_LOCAL_HOSTS = frozenset({"localhost", "127.0.0.1"})

# Upstream headers that describe the upstream wire encoding rather than the body
# we send back: httpx has already decoded the body (and we may rewrite it), so
# Starlette must compute Content-Length itself for the body it actually emits.
_STALE_RESPONSE_HEADERS = frozenset(
    {"content-encoding", "content-length", "transfer-encoding"}
)


def host_required(allowed_hosts: list[Hosts]):
    """Build a FastAPI dependency that rejects requests for unlisted hosts.

    Requests addressed to ``localhost`` or ``127.0.0.1`` are always accepted.
    Any other hostname must equal the ``value`` of one of ``allowed_hosts``.

    Returns:
        A ``Depends`` wrapper around the host check.

    The wrapped dependency raises ``HTTPException`` (403) for a disallowed host.
    """
    allowed = {h.value for h in allowed_hosts}

    async def dependency(request: Request):
        host = request.url.hostname
        if host in _LOCAL_HOSTS or host in allowed:
            return
        logger.warning(f"Host '{host}' not allowed for url '{request.url}'")
        raise HTTPException(status_code=403, detail=f"Host '{host}' not allowed")

    return Depends(dependency)


async def call_official(request: Request, path: str) -> httpx.Response:
    """Forward the incoming request to the official API and return the response.

    The request is replayed over HTTPS against the base URL it arrived on, with
    hop-by-hop headers removed in both directions. The returned response's
    ``headers`` are replaced by the filtered set. Repeated headers (such as
    several ``Set-Cookie`` lines) remain separate entries.
    """
    logger.debug(f"Forwarding request to official: {request.method} {request.url}")

    # Copy request body.
    body = await request.body()

    # Copy headers, filtering hop-by-hop.
    req_headers = {
        k: v
        for k, v in request.headers.items()
        if k.lower() not in HOP_BY_HOP_HEADERS
    }
    # Rewrite Host header to match upstream.
    req_headers["host"] = request.url.hostname
    # Forward cookies explicitly only if the Cookie header was filtered out
    # above. The per-request `cookies=` argument is deprecated in httpx.
    if "cookie" not in req_headers and request.cookies:
        req_headers["cookie"] = "; ".join(
            f"{name}={value}" for name, value in request.cookies.items()
        )

    # Upstream is always contacted over HTTPS; only the scheme prefix changes.
    base_url = str(request.base_url).rstrip("/")
    if base_url.startswith("http://"):
        base_url = "https://" + base_url[len("http://"):]

    # Forward request upstream. multi_items() keeps repeated query keys
    # (?a=1&a=2); passing the mapping itself would keep only one value per key.
    upstream_response = await official_client.request(
        method=request.method,
        url=f"{base_url}/{path}",
        headers=req_headers,
        content=body,
        params=request.query_params.multi_items(),
    )

    # Filter hop-by-hop response headers. multi_items() preserves each value of
    # a repeated header; items() would comma-join them, corrupting Set-Cookie.
    upstream_response.headers = httpx.Headers(
        [
            (k, v)
            for k, v in upstream_response.headers.multi_items()
            if k.lower() not in HOP_BY_HOP_HEADERS
        ],
        encoding=upstream_response.headers.encoding,
    )
    return upstream_response


def _prepare_response_headers(
    headers: httpx.Headers | dict[str, str],
) -> tuple[dict[str, str], list[tuple[str, str]]]:
    """Split upstream headers into a first-value mapping plus repeated entries.

    Starlette's ``headers=`` argument is a plain mapping and cannot represent
    repeated keys, so every additional occurrence is returned separately.
    Headers in ``_STALE_RESPONSE_HEADERS`` are dropped.
    """
    pairs = headers.multi_items() if hasattr(headers, "multi_items") else headers.items()
    first: dict[str, str] = {}
    repeated: list[tuple[str, str]] = []
    for key, value in pairs:
        name = key.lower()
        if name in _STALE_RESPONSE_HEADERS:
            continue
        if name in first:
            repeated.append((name, value))
        else:
            first[name] = value
    return first, repeated


def _build_response(
    response_cls: type[Response],
    upstream_headers: httpx.Headers | dict[str, str],
    **kwargs,
) -> Response:
    """Instantiate ``response_cls`` carrying the (filtered) upstream headers."""
    first, repeated = _prepare_response_headers(upstream_headers)
    result = response_cls(headers=first, **kwargs)
    for name, value in repeated:
        result.headers.append(name, value)
    return result


async def return_edited_response(
    response: httpx.Response,
    new_content: dict | bytes | list | str | None = None,
    ignore_official_data: bool = False,
) -> Response:
    """
    Return possibly modified response.

    - If new_content is bytes or str, it replaces the response body entirely,
      whatever the content type.
    - For JSON responses (content-type starting with ``application/json``):
        * a dict new_content is merged over the upstream object (its keys win).
          If the upstream JSON is not an object, or ignore_official_data is
          True, new_content replaces the data instead;
        * a list new_content replaces the data when the upstream data is a
          list or ignore_official_data is True;
        * otherwise the upstream JSON is returned unchanged.
    - Any other response body is passed through unchanged.

    Status code and headers are carried over, except Content-Encoding,
    Content-Length and Transfer-Encoding. httpx has already decoded the body
    and it may have been rewritten, so Starlette recomputes the length. The
    input ``response`` is not modified.
    """
    content_type = response.headers.get("content-type", "")

    if isinstance(new_content, (bytes, str)):
        return _build_response(
            Response,
            response.headers,
            content=new_content,
            status_code=response.status_code,
            media_type=content_type,
        )

    if not content_type.startswith("application/json"):
        return _build_response(
            Response,
            response.headers,
            content=response.content,
            status_code=response.status_code,
            media_type=content_type,
        )

    if ignore_official_data and isinstance(new_content, (dict, list)):
        # The upstream body is discarded, so it need not be valid JSON.
        data = new_content
    else:
        data = response.json()
        if isinstance(new_content, dict):
            # Allow overrides; merging only makes sense into a JSON object.
            data = {**data, **new_content} if isinstance(data, dict) else new_content
        elif isinstance(new_content, list) and isinstance(data, list):
            data = new_content  # replace entire list

    return _build_response(
        JSONResponse,
        response.headers,
        content=data,
        status_code=response.status_code,
    )


def fake_response(content: dict | None = None, status_code: int = 200) -> Response:
    """Build a locally generated response with permissive CORS headers.

    Returns a body-less ``Response`` when ``content`` is None. Otherwise it
    returns a ``JSONResponse`` that serialises ``content``.
    """
    headers = {
        # NOTE(review): could echo request.headers["origin"] instead of "*".
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Methods": "POST, OPTIONS",
        "Access-Control-Allow-Headers": "*",
    }
    if content is None:
        return Response(status_code=status_code, headers=headers)
    return JSONResponse(content=content, status_code=status_code, headers=headers)