101 lines
3.2 KiB
Python
101 lines
3.2 KiB
Python
import httpx
|
|
from fastapi import Request, Depends, HTTPException
|
|
from fastapi.responses import Response
|
|
from starlette.responses import JSONResponse
|
|
|
|
from const import Hosts, HOP_BY_HOP_HEADERS
|
|
from http_client_static_dns import AsyncCustomHost, NameSolver
|
|
|
|
# Shared upstream HTTP client: every forwarded request goes through this one
# instance. Timeouts: 30 s for read/write/pool, 10 s to establish a connection.
# NOTE(review): httpx applies `verify` and `limits` only to transports it builds
# itself; because an explicit `transport=` is supplied, they most likely do not
# affect the main transport — confirm, and configure TLS verification / pool
# sizes on AsyncCustomHost if they are meant to take effect.
# NOTE(review): an httpx client persists Set-Cookie values from responses in its
# own cookie jar and sends them on later requests; since this client is shared
# by all callers, confirm that cross-request cookie reuse is acceptable.
official_client = httpx.AsyncClient(
    verify=True,
    timeout=httpx.Timeout(timeout=30.0, connect=10.0),
    limits=httpx.Limits(max_connections=100, max_keepalive_connections=50),
    transport=AsyncCustomHost(NameSolver()),
)
|
|
|
|
|
|
def host_required(allowed_hosts: list[Hosts]):
    """Build a FastAPI dependency that rejects requests for non-allowed hosts.

    Args:
        allowed_hosts: Hosts members whose ``.value`` is an accepted hostname.
            Any iterable of members works; it is consumed once, here.

    Returns:
        A ``Depends(...)`` marker wrapping the check. The wrapped dependency
        raises ``HTTPException`` with status 403 when ``request.url.hostname``
        is not one of the allowed values.
    """
    # Resolve the accepted values once instead of rebuilding a list on every
    # request; this also keeps one-shot iterables (e.g. generators) working
    # after the first request.
    allowed_values = frozenset(h.value for h in allowed_hosts)

    async def dependency(request: Request):
        host = request.url.hostname
        if host not in allowed_values:
            raise HTTPException(status_code=403, detail=f"Host '{host}' not allowed")

    return Depends(dependency)
|
|
|
|
|
|
async def call_official(request: Request, path: str) -> httpx.Response:
    """Forward the incoming request to the official API and return the response.

    The method, body, headers (minus hop-by-hop ones), query string and cookies
    of ``request`` are replayed against ``<request base URL>/<path>`` with the
    scheme forced to ``https``. Name resolution is left to
    ``official_client``'s transport (presumably static DNS pointing at the
    real upstream — confirm).

    Args:
        request: Incoming request to forward.
        path: Path appended to the request's base URL.

    Returns:
        The upstream ``httpx.Response``; its ``headers`` are replaced by a copy
        without hop-by-hop headers.
    """
    print(f"Processing request to {request.url}")

    # Copy request body
    body = await request.body()

    # Copy headers, filtering hop-by-hop. A list of pairs (not a dict) keeps
    # repeated headers; latin-1 matches how Starlette decoded them, so
    # non-ASCII values are forwarded byte-for-byte.
    req_headers = httpx.Headers(
        [
            (name, value)
            for name, value in request.headers.items()
            if name.lower() not in HOP_BY_HOP_HEADERS
        ],
        encoding="latin-1",
    )

    # Rewrite Host header to match upstream (bare hostname, no port). If the
    # URL has no hostname, keep the forwarded value instead of crashing on None.
    hostname = request.url.hostname
    if hostname is not None:
        req_headers["host"] = hostname

    # Force HTTPS for the upstream call; only the scheme prefix is rewritten.
    base_url = str(request.base_url).rstrip("/")
    if base_url.startswith("http://"):
        base_url = "https://" + base_url[len("http://"):]

    # Forward request upstream
    upstream_response = await official_client.request(
        method=request.method,
        url=f"{base_url}/{path}",
        headers=req_headers,
        content=body,
        # multi_items() keeps repeated keys (?a=1&a=2); passing the mapping
        # itself forwards only the last value of each key.
        params=request.query_params.multi_items(),
        # NOTE(review): recent httpx versions deprecate per-request cookies=,
        # and a shared client also stores upstream Set-Cookie values in its own
        # jar and may send them on later requests from other callers — confirm
        # this is intended.
        cookies=request.cookies,
    )

    # Filter response headers hop-by-hop. Rebuilding an httpx.Headers (not a
    # dict) keeps case-insensitive lookups and repeated headers such as
    # Set-Cookie; reusing the upstream encoding round-trips values exactly.
    upstream_headers = upstream_response.headers
    upstream_response.headers = httpx.Headers(
        [
            (name, value)
            for name, value in upstream_headers.multi_items()
            if name.lower() not in HOP_BY_HOP_HEADERS
        ],
        encoding=upstream_headers.encoding,
    )

    return upstream_response
|
|
|
|
|
|
async def return_edited_response(
    response: httpx.Response,
    new_content: dict | list | bytes | str | None = None,
) -> Response:
    """
    Return possibly modified response.

    Build a Starlette response from an upstream httpx response, copying its
    status code and headers and optionally editing the body:

    - If new_content is bytes or str, it replaces the response body entirely,
      whatever the upstream content type.
    - If new_content is a dict and the response is JSON, it is merged into the
      decoded JSON object, its keys overriding upstream ones (a JSON body that
      is not an object raises TypeError).
    - If new_content and the decoded JSON data are both lists, new_content
      replaces the entire response data.
    - Otherwise (including None) the upstream body is passed through. Empty
      bodies are never parsed as JSON.

    Note: response.headers is modified in place — content-encoding is always
    removed, and content-length whenever the outgoing body may differ from
    the upstream one.
    """
    content_type = response.headers.get("content-type", "")

    # httpx hands over the already-decoded body, so the upstream encoding no
    # longer applies; if there was one, the upstream content-length also
    # describes the encoded bytes rather than what is sent here.
    was_encoded = response.headers.pop("content-encoding", None) is not None

    if isinstance(new_content, (bytes, str)):
        # A raw replacement wins regardless of the upstream content type.
        body = new_content
        body_changed = True
    elif content_type.startswith("application/json") and response.content:
        data = response.json()

        if isinstance(new_content, dict):  # allow overrides
            data = {**data, **new_content}
        elif isinstance(data, list) and isinstance(new_content, list):
            data = new_content  # replace entire list

        # The JSON is always re-serialized, so the upstream length is stale;
        # without the header Starlette computes it from the rendered body.
        response.headers.pop("content-length", None)
        return JSONResponse(
            content=data,
            status_code=response.status_code,
            headers=response.headers,
        )
    else:
        body = response.content
        body_changed = False

    if body_changed or was_encoded:
        # Let Starlette compute Content-Length from the body actually sent.
        response.headers.pop("content-length", None)

    return Response(
        content=body,
        status_code=response.status_code,
        headers=response.headers,
        # "" would make Starlette emit an empty Content-Type header when the
        # upstream sent none; None leaves the header out.
        media_type=content_type or None,
    )
|