128 lines
4.6 KiB
Python

import json
import httpx
from fastapi import Request, Depends, HTTPException
from fastapi.responses import Response
from starlette.responses import JSONResponse
from const import Hosts, HOP_BY_HOP_HEADERS, logger, LOG_REQUESTS, LOG_FILE
from http_client_static_dns import AsyncCustomHost, NameSolver
# Shared upstream client used for every forwarded request.
# Timeouts: 30 s default for each phase, 10 s for establishing the connection.
_upstream_timeout = httpx.Timeout(30.0, connect=10.0)
# Pool bounds: at most 100 connections, of which 50 may be kept alive.
_upstream_limits = httpx.Limits(max_connections=100, max_keepalive_connections=50)
# NOTE(review): httpx uses `verify` and `limits` when it builds its own default
# transport; with an explicit `transport` they may have no effect — confirm
# that AsyncCustomHost configures TLS verification and pooling itself.
official_client = httpx.AsyncClient(
    transport=AsyncCustomHost(NameSolver()),
    limits=_upstream_limits,
    timeout=_upstream_timeout,
    verify=True,
)
def host_required(allowed_hosts: list[Hosts]):
    """Build a FastAPI dependency that restricts a route to specific hosts.

    The dependency reads the hostname from the request URL. ``localhost`` and
    ``127.0.0.1`` are always accepted. Any other hostname must equal the
    ``value`` of one of ``allowed_hosts``; otherwise a warning is logged and
    an ``HTTPException`` with status 403 is raised.

    Returns:
        The ``Depends(...)`` wrapper around the checking coroutine.
    """
    async def dependency(request: Request):
        host = request.url.hostname

        # Local access bypasses the allow-list entirely.
        if host == "localhost" or host == "127.0.0.1":
            return

        permitted = [entry.value for entry in allowed_hosts]
        if host in permitted:
            return

        logger.warning(f"Host '{host}' not allowed for url '{request.url}'")
        raise HTTPException(status_code=403, detail=f"Host '{host}' not allowed")

    return Depends(dependency)
async def call_official(request: Request, path: str) -> httpx.Response:
    """Forward the incoming request to the official API and return the response.

    The upstream URL is the incoming request's base URL joined with ``path``;
    method, body, query string, cookies and all non-hop-by-hop headers are
    forwarded through ``official_client`` with redirects followed.

    Args:
        request: The incoming request to replay upstream.
        path: Path (relative to the base URL) to request upstream.

    Returns:
        The upstream ``httpx.Response`` with hop-by-hop headers removed.
        Headers that occur several times (e.g. ``Set-Cookie``) are kept as
        separate entries and can be read with ``headers.get_list(...)``.
    """
    logger.debug(f"Forwarding request to official: {request.method} {request.url}")

    # Copy request body
    body = await request.body()

    # Copy headers, filtering hop-by-hop
    req_headers = {
        k: v
        for k, v in request.headers.items()
        if k.lower() not in HOP_BY_HOP_HEADERS
    }
    # Rewrite Host header to match upstream
    req_headers["host"] = request.url.hostname

    # Forward request upstream
    upstream_response = await official_client.request(
        method=request.method,
        url=f"{str(request.base_url).rstrip('/')}/{path}",
        headers=req_headers,
        content=body,
        # Pass (key, value) pairs rather than the multi-dict itself: reading a
        # Starlette multi-dict through .items() yields one value per key, which
        # would silently drop repeated parameters such as ?id=1&id=2.
        params=request.query_params.multi_items(),
        follow_redirects=True,
        cookies=request.cookies,
    )

    # Filter hop-by-hop response headers. Iterate multi_items() so repeated
    # headers stay separate instead of being comma-joined into one value,
    # which would corrupt multiple Set-Cookie headers.
    upstream_response.headers = httpx.Headers([
        (k, v)
        for k, v in upstream_response.headers.multi_items()
        if k.lower() not in HOP_BY_HOP_HEADERS
    ])
    return upstream_response
async def return_edited_response(
    response: httpx.Response,
    new_content: dict | bytes | list | str | None = None,
    ignore_official_data: bool = False
) -> Response:
    """
    Return the upstream response, optionally with its body modified.

    JSON responses (content type starting with ``application/json``):
      * ``new_content`` is a dict: its keys are merged over the upstream
        object. It replaces the upstream data entirely when
        ``ignore_official_data`` is true or when the upstream data is not a
        JSON object (there is nothing to merge into).
      * ``new_content`` is a list: it replaces the upstream data when that is
        a list too, or when ``ignore_official_data`` is true.
      * ``bytes``/``str`` ``new_content`` is not applied to JSON responses.
      * When ``LOG_REQUESTS`` is set, the exchange is appended to ``LOG_FILE``.

    Other responses: ``bytes``/``str`` ``new_content`` replaces the body;
    anything else leaves the upstream body unchanged.

    The caller's ``response`` object is not modified. ``Content-Encoding`` and
    ``Content-Length`` are not forwarded: the body is already decoded and may
    have been rewritten, so the length is recomputed for the outgoing body.
    """
    content_type = response.headers.get("content-type", "")

    # Build the outgoing headers on a copy. The upstream Content-Encoding no
    # longer applies (httpx decoded the body) and the upstream Content-Length
    # would be wrong for a decoded / re-serialised / replaced body.
    stale_headers = ("content-encoding", "content-length")
    headers = {
        k: v
        for k, v in response.headers.items()
        if k.lower() not in stale_headers
    }

    if content_type.startswith("application/json"):
        data = response.json()
        if isinstance(new_content, dict):  # allow overrides
            if ignore_official_data or not isinstance(data, dict):
                # Unpacking a non-mapping with ** would raise TypeError, so
                # the override simply wins when there is no object to merge.
                data = new_content
            else:
                data = {**data, **new_content}
        elif isinstance(new_content, list) and (ignore_official_data or isinstance(data, list)):
            data = new_content  # replace entire list

        if LOG_REQUESTS:
            beautify_headers = json.dumps(headers, indent=2)
            beautify_request_headers = json.dumps(dict(response.request.headers), indent=2)
            beautify_data = json.dumps(data, indent=2, ensure_ascii=False)
            with open(LOG_FILE, "a", encoding="utf-8") as f:
                f.write(f"---\n{response.status_code} {response.url}\n\nRequest Headers: {beautify_request_headers}\n\n"
                        f"Response Headers: {beautify_headers}\n\nResponse Data: {beautify_data}\n\n\n\n")

        return JSONResponse(
            content=data,
            status_code=response.status_code,
            headers=headers
        )

    content = new_content if isinstance(new_content, (bytes, str)) else response.content
    return Response(
        content=content,
        status_code=response.status_code,
        headers=headers,
        # An empty string would be emitted as an empty Content-Type header.
        media_type=content_type or None
    )
def fake_response(content: dict | None = None, status_code: int = 200) -> Response:
    """Build a locally generated response with permissive CORS headers.

    Args:
        content: JSON-serialisable payload, or None for an empty body.
        status_code: HTTP status code of the response (default 200).

    Returns:
        A ``JSONResponse`` carrying ``content`` when it is given, otherwise a
        body-less ``Response``; both include the same CORS headers.
    """
    # Wildcard origin is used here; echoing the request's Origin header would
    # be the stricter alternative.
    cors_headers = {
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Methods": "POST, OPTIONS",
        "Access-Control-Allow-Headers": "*",
    }
    if content is not None:
        return JSONResponse(content=content, status_code=status_code, headers=cors_headers)
    return Response(status_code=status_code, headers=cors_headers)