2025-08-25 18:58:31 +02:00

101 lines
3.3 KiB
Python

from http.cookiejar import CookieJar, DefaultCookiePolicy

import httpx
from fastapi import Request, Depends, HTTPException
from fastapi.responses import Response
from starlette.responses import JSONResponse

from const import Hosts, HOP_BY_HOP_HEADERS, logger
from http_client_static_dns import AsyncCustomHost, NameSolver
# Configure upstream client: one shared httpx client used for every request
# forwarded to the official API (see call_official).
#
# Cookie persistence is disabled on purpose. An httpx client stores each
# Set-Cookie it receives in `client.cookies` and replays those cookies on later
# requests that carry no Cookie header of their own. Because this client is
# shared by all proxied callers, that would leak one caller's upstream session
# into another caller's requests. A DefaultCookiePolicy with an empty
# `allowed_domains` list refuses to store or return cookies for any domain.
# Cookies still travel end-to-end through the forwarded Cookie / Set-Cookie
# headers.
#
# NOTE(review): httpx only applies `verify` and `limits` when it builds its own
# default transport; with an explicit `transport=` they are presumably ignored.
# Confirm that AsyncCustomHost enforces equivalent TLS verification and pool
# limits itself.
official_client = httpx.AsyncClient(
    verify=True,
    timeout=httpx.Timeout(30.0, connect=10.0),
    limits=httpx.Limits(max_keepalive_connections=50, max_connections=100),
    transport=AsyncCustomHost(NameSolver()),
    cookies=CookieJar(policy=DefaultCookiePolicy(allowed_domains=[])),
)
def host_required(allowed_hosts: list[Hosts]):
    """Build a FastAPI dependency that rejects requests addressed to other hosts.

    Args:
        allowed_hosts: ``Hosts`` members whose ``.value`` is an accepted hostname.
            Any iterable works; it is consumed once, when the dependency is built.

    Returns:
        ``Depends(...)`` wrapping an async check. The check compares
        ``request.url.hostname`` with the allowed values. On a mismatch it logs a
        warning and raises ``HTTPException(status_code=403)``.
    """
    # Resolve the accepted hostnames once, at declaration time, instead of
    # rebuilding a list on every request. This also gives O(1) membership tests.
    # A generator argument is no longer exhausted after the first request.
    allowed_values = frozenset(h.value for h in allowed_hosts)

    async def dependency(request: Request):
        host = request.url.hostname
        if host not in allowed_values:
            logger.warning("Host '%s' not allowed for url '%s'", host, request.url)
            raise HTTPException(status_code=403, detail=f"Host '{host}' not allowed")

    return Depends(dependency)
async def call_official(request: Request, path: str) -> httpx.Response:
    """Forward the incoming request to the official API and return the response.

    The request is re-issued through ``official_client``. The target URL is the
    incoming request's base URL, with its scheme forced to ``https`` and
    ``/{path}`` appended. Method, body, query parameters and cookies are copied
    unchanged. Hop-by-hop headers are dropped, and ``Host`` is set to the
    incoming URL's hostname.

    Args:
        request: The incoming request to forward.
        path: Path appended to the base URL after a ``/``.

    Returns:
        The upstream ``httpx.Response``. Its ``headers`` attribute is replaced
        with an ``httpx.Headers`` that has the hop-by-hop entries removed.
        Repeated headers such as ``Set-Cookie`` stay as separate entries.
    """
    logger.debug("Forwarding request to official: %s %s", request.method, request.url)

    body = await request.body()

    # Copy request headers minus hop-by-hop ones, then point Host at the upstream.
    req_headers = {
        name: value
        for name, value in request.headers.items()
        if name.lower() not in HOP_BY_HOP_HEADERS
    }
    req_headers["host"] = request.url.hostname

    # Force https on the scheme only. A blanket str.replace could also rewrite an
    # "http://" appearing later in the URL. Building the URL outside the f-string
    # also avoids nesting same-type quotes, which is a SyntaxError before Python 3.12.
    base_url = str(request.base_url).rstrip("/")
    if base_url.startswith("http://"):
        base_url = "https://" + base_url.removeprefix("http://")

    upstream_response = await official_client.request(
        method=request.method,
        url=f"{base_url}/{path}",
        headers=req_headers,
        content=body,
        params=request.query_params,
        # NOTE(review): the Cookie header is normally already forwarded via
        # req_headers, and recent httpx releases deprecate per-request cookies=.
        # Consider dropping this argument once confirmed redundant.
        cookies=request.cookies,
    )

    # multi_items() keeps repeated headers (e.g. several Set-Cookie lines) as
    # separate entries; items() would join them with ", " and corrupt cookies.
    upstream_response.headers = httpx.Headers(
        [
            (name, value)
            for name, value in upstream_response.headers.multi_items()
            if name.lower() not in HOP_BY_HOP_HEADERS
        ]
    )
    return upstream_response
def _apply_json_edit(data, new_content):
    """Return ``data`` edited by ``new_content``, or None when the edit does not apply."""
    if isinstance(new_content, dict) and isinstance(data, dict):
        return {**data, **new_content}  # shallow merge; new_content keys win
    if isinstance(new_content, list) and isinstance(data, list):
        return new_content  # replace the whole array
    return None


def _restore_set_cookies(source_headers, target: Response) -> None:
    """Re-emit each upstream Set-Cookie value as its own header line on ``target``.

    Starlette builds response headers from ``headers.items()``. For an
    ``httpx.Headers`` that joins repeated values with ", ", which breaks
    Set-Cookie. Plain dicts hold one value per key, so there is nothing to
    restore for them.
    """
    get_list = getattr(source_headers, "get_list", None)
    if get_list is None:
        return
    cookies = get_list("set-cookie")
    if len(cookies) < 2:
        return
    del target.headers["set-cookie"]
    for cookie in cookies:
        target.headers.append("set-cookie", cookie)


async def return_edited_response(
    response: httpx.Response,
    new_content: dict | list | bytes | str | None = None,
) -> Response:
    """
    Return possibly modified response.

    - If new_content is bytes or str, it replaces the response body entirely,
      whatever the upstream content type.
    - If new_content is a dict and the response is a JSON object, its keys are
      shallow-merged over the body (new_content wins).
    - If new_content and the response data are both lists, new_content replaces
      the entire response data.
    - Otherwise (None, or a dict/list that does not match the body), the
      upstream body is passed through unchanged; a non-None override that could
      not be applied is logged.

    Status code and headers are copied from ``response``. Note that
    ``response.headers`` is modified in place. ``Content-Encoding`` is always
    dropped because httpx has already decoded the body. ``Content-Length`` is
    dropped whenever the body sent differs from what upstream counted (decoded,
    replaced or re-rendered), so Starlette recomputes it.

    Args:
        response: Upstream response, e.g. as returned by ``call_official``.
        new_content: Optional replacement/override, as described above.

    Returns:
        A ``JSONResponse`` when a JSON body was edited, otherwise a ``Response``.
    """
    headers = response.headers
    content_type = headers.get("content-type", "")
    media_type = content_type or None  # never emit an empty Content-Type header
    was_encoded = headers.pop("content-encoding", None) is not None  # body is already decoded

    replace_body = isinstance(new_content, (bytes, str))
    edited_json = None
    if isinstance(new_content, (dict, list)) and content_type.startswith("application/json"):
        edited_json = _apply_json_edit(response.json(), new_content)

    if new_content is not None and not replace_body and edited_json is None:
        logger.warning(
            "Ignoring %s override for response with content-type '%s'",
            type(new_content).__name__,
            content_type,
        )

    # Upstream's Content-Length counts the bytes it sent. Once the body is
    # decoded or rewritten, that value is wrong and would desync the response.
    if was_encoded or replace_body or edited_json is not None:
        headers.pop("content-length", None)

    if edited_json is not None:
        result = JSONResponse(
            content=edited_json,
            status_code=response.status_code,
            headers=headers,
        )
    else:
        result = Response(
            content=new_content if replace_body else response.content,
            status_code=response.status_code,
            headers=headers,
            media_type=media_type,
        )

    _restore_set_cookies(headers, result)
    return result