import asyncio
import json
import logging
import socket
import threading
from collections.abc import Iterable
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from urllib.parse import parse_qs, urlparse

from jinja2 import DictLoader, Environment
from tabulate import tabulate

from torch.distributed.debug._store import get_world_size, tcpstore_client
from torch.distributed.flight_recorder.components.builder import build_db
from torch.distributed.flight_recorder.components.config_manager import JobConfig
from torch.distributed.flight_recorder.components.types import (
    Collective,
    Group,
    Membership,
    NCCLCall,
)


logger: logging.Logger = logging.getLogger(__name__)


@dataclass(slots=True)
class Response:
    """Backend-independent HTTP response: a status code plus the body text.

    Both fetch backends in this module return instances of this class so
    callers do not depend on ``requests`` or ``aiohttp`` response objects.
    """

    status_code: int
    text: str

    def raise_for_status(self):
        """Raise ``RuntimeError`` unless the status code is exactly 200."""
        if self.status_code == 200:
            return
        raise RuntimeError(f"HTTP {self.status_code}: {self.text}")

    def json(self):
        """Return the body text decoded as JSON."""
        return json.loads(self.text)


def fetch_thread_pool(urls: list[str]) -> Iterable[Response]:
    """POST to every URL through a thread pool, using ``requests``.

    Args:
        urls: Fully formed URLs to POST to.

    Returns:
        The iterator produced by ``Executor.map``; it yields one
        ``Response`` per URL, in the same order as ``urls``. An exception
        raised by a request surfaces when that item is consumed.

    Raises:
        ImportError: if ``requests`` is not installed.
    """
    # Imported here so the module loads without this optional dependency.
    import requests

    def post(url: str) -> Response:
        reply = requests.post(url)
        return Response(reply.status_code, reply.text)

    # Leaving the ``with`` block waits for the pool to shut down, so every
    # request has finished by the time the caller receives the iterator.
    with ThreadPoolExecutor(max_workers=20) as pool:
        return pool.map(post, urls)


def fetch_aiohttp(urls: list[str]) -> Iterable[Response]:
    """POST to every URL concurrently on a private event loop, using ``aiohttp``.

    Args:
        urls: Fully formed URLs to POST to.

    Returns:
        The ``Response`` objects gathered for ``urls``, in the same order.

    Raises:
        ImportError: if ``aiohttp`` is not installed.
    """
    # Imported here so the module loads without this optional dependency.
    import aiohttp

    async def post(session: aiohttp.ClientSession, url: str) -> Response:
        async with session.post(url) as reply:
            body = await reply.text()
            return Response(reply.status, body)

    async def post_all(targets: list[str]) -> Iterable[Response]:
        async with aiohttp.ClientSession() as session:
            pending = [post(session, target) for target in targets]
            return await asyncio.gather(*pending)

    return asyncio.run(post_all(urls))


def fetch_all(endpoint: str, args: str = "") -> tuple[list[str], Iterable[Response]]:
    """Call one handler endpoint on every rank and collect the responses.

    The address of each rank is read from the TCPStore under the keys
    ``rank0`` .. ``rank{world_size - 1}``; the request URL is that address
    followed by ``/handler/<endpoint>?<args>``.

    Args:
        endpoint: Handler name appended after ``/handler/``.
        args: Raw query string appended after ``?`` (may be empty).

    Returns:
        A ``(urls, responses)`` pair, where ``urls`` are the request URLs
        in rank order and ``responses`` is whatever the fetch backend
        returned for them.
    """
    store = tcpstore_client()
    rank_keys = [f"rank{r}" for r in range(get_world_size())]
    urls = [
        f"{addr.decode()}/handler/{endpoint}?{args}"
        for addr in store.multi_get(rank_keys)
    ]

    # Prefer aiohttp; fall back to the requests-based pool if it is missing.
    try:
        responses = fetch_aiohttp(urls)
    except ImportError:
        responses = fetch_thread_pool(urls)

    return urls, responses


def format_json(blob: str):
    """Re-serialize a JSON document with two-space indentation."""
    return json.dumps(json.loads(blob), indent=2)


# Jinja2 template sources keyed by template name. The text between the
# triple quotes is runtime data and is kept exactly as-is.
templates = {
    "base.html": """
Failed to fetch: status={{ resp.status_code }}
{{ resp.text }}
{% else %}
{{ resp.text }}
{% endif %}
{% endfor %}
{% endblock %}
""",
    "json_resp.html": """
{% extends "base.html" %}
{% block header %}
Failed to fetch: status={{ resp.status_code }}
{{ resp.text }}
{% else %}
{{ format_json(resp.text) }}
{% endif %}
{% endfor %}
{% endblock %}
""",
    "profile.html": """
{% extends "base.html" %}
{% block header %}
Failed to fetch: status={{ resp.status_code }}
{{ resp.text }}
{% else %}
{% endif %}
{% endfor %}
{% endblock %}
""",
    "tcpstore.html": """
{% extends "base.html" %}
{% block header %}
{% for k, v in zip(keys, values) -%}
{{ k }}: {{ v | truncate(100) }}
{% endfor %}
{% endblock %}
""",
    "fr_trace.html": """
{% extends "base.html" %}
{% block header %}
Failed to fetch: status={{ resp.status_code }}
{{ resp.text }}
{% else %}
{{ resp.text }}
{% endif %}
{% endfor %}
{% endblock %}
""",
}
class _IPv6HTTPServer(ThreadingHTTPServer):
    """``ThreadingHTTPServer`` variant that opens an ``AF_INET6`` socket."""

    # Listen backlog handed to the socket; overrides the base-class value.
    request_queue_size: int = 1024

    # Use an IPv6 socket instead of the base-class address family.
    address_family: socket.AddressFamily = socket.AF_INET6  # pyre-ignore
class HTTPRequestHandler(BaseHTTPRequestHandler):
    """Request handler that forwards GET requests to a ``FrontendServer``.

    The ``frontend`` class attribute is not set here; it is expected to be
    supplied by a subclass before any request is handled.
    """

    frontend: "FrontendServer"

    def log_message(self, format, *args):
        """Send the access-log line to the module logger at INFO level."""
        client_host = self.client_address[0]
        line = format % args
        logger.info("%s %s", client_host, line)

    def do_GET(self):
        """Delegate the GET request to the owning frontend."""
        self.frontend._handle_request(self)

    def get_path(self) -> str:
        """Return the path component of the request target."""
        target = urlparse(self.path)
        return target.path

    def get_query(self) -> dict[str, list[str]]:
        """Return the query string parsed into ``{name: [values, ...]}``."""
        raw = self.get_raw_query()
        return parse_qs(raw)

    def get_raw_query(self) -> str:
        """Return the unparsed query string of the request target."""
        target = urlparse(self.path)
        return target.query

    def get_query_arg(
        self, name: str, default: object = None, type: type = str
    ) -> object:
        """Return the first value of query argument ``name``, converted.

        Args:
            name: Query parameter to look up.
            default: Returned unchanged (not converted) when ``name`` is
                absent from the parsed query.
            type: Callable applied to the first value found.

        Returns:
            ``type(first_value)`` when present, otherwise ``default``.
        """
        try:
            values = self.get_query()[name]
        except KeyError:
            return default
        return type(values[0])
class FrontendServer:
    """HTTP frontend that queries every rank's debug handlers and renders HTML.

    Each route maps a URL path to a ``_handle_*`` method that returns the
    encoded page body. The HTTP server runs on a daemon thread started by
    the constructor.
    """

    def __init__(self, port: int):
        """Set up templates and routes, then start serving in the background.

        Args:
            port: Port number passed to the HTTP server's bind address.
        """
        # Setup templates
        loader = DictLoader(templates)
        self._jinja_env = Environment(loader=loader, enable_async=True)
        self._jinja_env.globals.update(
            zip=zip,
            format_json=format_json,
            enumerate=enumerate,
        )

        # Create routes
        self._routes = {
            "/": self._handle_index,
            "/stacks": self._handle_stacks,
            "/pyspy_dump": self._handle_pyspy_dump,
            "/fr_trace": self._handle_fr_trace,
            "/fr_trace_json": self._handle_fr_trace_json,
            "/fr_trace_nccl": self._handle_fr_trace_nccl,
            "/fr_trace_nccl_json": self._handle_fr_trace_nccl_json,
            "/profile": self._handle_profiler,
            "/wait_counters": self._handle_wait_counters,
            "/tcpstore": self._handle_tcpstore,
        }

        # Create HTTP server. The handler class is subclassed on the fly so
        # that its ``frontend`` class attribute points back at this instance.
        RequestHandlerClass = type(
            "HTTPRequestHandler",
            (HTTPRequestHandler,),
            {"frontend": self},
        )

        server_address = ("", port)
        self._server = _IPv6HTTPServer(server_address, RequestHandlerClass)

        self._thread = threading.Thread(
            target=self._serve,
            args=(),
            daemon=True,
            name="distributed.debug.FrontendServer",
        )
        self._thread.start()

    def _serve(self) -> None:
        """Thread target: run the server loop, logging any exception."""
        try:
            self._server.serve_forever()
        except Exception:
            logger.exception("got exception in frontend server")

    def join(self) -> None:
        """Block until the server thread exits."""
        self._thread.join()

    def _handle_request(self, req: "HTTPRequestHandler") -> None:
        """Dispatch a request to its route handler and write the response.

        Unknown paths get a 404. If the handler raises, the exception is
        logged and a 500 is sent. Otherwise the handler's bytes are written
        as a 200 ``text/html`` response.
        """
        path = req.get_path()
        if path not in self._routes:
            req.send_error(404, f"Handler not found: {path}")
            return

        handler = self._routes[path]
        try:
            resp = handler(req)
        # Catch SystemExit to not crash when FlightRecorder errors.
        except (Exception, SystemExit) as e:
            logger.exception(
                "Exception in frontend server when handling %s",
                path,
            )
            # Put the exception text in ``explain`` (response body only)
            # rather than ``message``: ``message`` is written into the HTTP
            # status line, which is encoded as strict latin-1, so a repr
            # containing other characters would raise while reporting the
            # error.
            req.send_error(500, explain=f"Exception: {e!r}")
            return

        req.send_response(200)
        req.send_header("Content-type", "text/html")
        req.end_headers()
        req.wfile.write(resp)

    def _render_template(self, template: str, **kwargs: object) -> bytes:
        """Render the named template with ``kwargs`` and return encoded bytes."""
        return self._jinja_env.get_template(template).render(**kwargs).encode()

    def _handle_index(self, req: "HTTPRequestHandler") -> bytes:
        """Render the index page."""
        return self._render_template("index.html")

    def _handle_stacks(self, req: "HTTPRequestHandler") -> bytes:
        """Fetch ``dump_traceback`` from every rank and render the raw text."""
        addrs, resps = fetch_all("dump_traceback")
        return self._render_template(
            "raw_resp.html", title="Stacks", addrs=addrs, resps=resps
        )

    def _handle_pyspy_dump(self, req: "HTTPRequestHandler") -> bytes:
        """Fetch ``pyspy_dump`` from every rank, forwarding the raw query string."""
        addrs, resps = fetch_all("pyspy_dump", req.get_raw_query())
        return self._render_template(
            "pyspy_dump.html",
            addrs=addrs,
            resps=resps,
        )

    def _render_fr_trace(self, addrs: list[str], resps: "list[Response]") -> bytes:
        """Build a FlightRecorder database from per-rank dumps and render it.

        Args:
            addrs: Request URLs, indexed by rank; stored as ``host_name``.
            resps: One response per rank, in rank order, with JSON bodies.

        Returns:
            The rendered ``fr_trace.html`` page as bytes.

        Raises:
            RuntimeError: from ``raise_for_status`` if any response is not 200.
        """
        config = JobConfig()
        # pyrefly: ignore [bad-assignment]
        args = config.parse_args(args=[])
        args.allow_incomplete_ranks = True
        args.verbose = True

        details = {}
        for rank, resp in enumerate(resps):
            resp.raise_for_status()
            dump = {
                "rank": rank,
                "host_name": addrs[rank],
                **resp.json(),
            }
            # Make sure every dump has an "entries" key before building.
            if "entries" not in dump:
                dump["entries"] = []
            details[f"rank{rank}.json"] = dump

        # The version is taken from the first rank's dump.
        version = next(iter(details.values()))["version"]

        db = build_db(details, args, version)

        return self._render_template(
            "fr_trace.html",
            title="FlightRecorder",
            groups=tabulate(db.groups, headers=Group._fields, tablefmt="html"),
            memberships=tabulate(
                db.memberships, headers=Membership._fields, tablefmt="html"
            ),
            collectives=tabulate(
                db.collectives, headers=Collective._fields, tablefmt="html"
            ),
            ncclcalls=tabulate(db.ncclcalls, headers=NCCLCall._fields, tablefmt="html"),
        )

    def _handle_fr_trace(self, req: "HTTPRequestHandler") -> bytes:
        """Fetch ``fr_trace_json`` from every rank and render the analyzed tables."""
        addrs, resps = fetch_all("fr_trace_json")
        return self._render_fr_trace(addrs, list(resps))

    def _handle_fr_trace_json(self, req: "HTTPRequestHandler") -> bytes:
        """Fetch ``fr_trace_json`` from every rank and render the JSON bodies."""
        addrs, resps = fetch_all("fr_trace_json")
        return self._render_template(
            "json_resp.html",
            title="FlightRecorder",
            addrs=addrs,
            resps=resps,
        )

    def _handle_fr_trace_nccl(self, req: "HTTPRequestHandler") -> bytes:
        """Fetch ``dump_nccl_trace_json`` (``onlyactive=true``) and render tables."""
        addrs, resps = fetch_all("dump_nccl_trace_json", "onlyactive=true")
        return self._render_fr_trace(addrs, list(resps))

    def _handle_fr_trace_nccl_json(self, req: "HTTPRequestHandler") -> bytes:
        """Fetch ``dump_nccl_trace_json`` (``onlyactive=true``) and render JSON."""
        addrs, resps = fetch_all("dump_nccl_trace_json", "onlyactive=true")
        return self._render_template(
            "json_resp.html",
            title="FlightRecorder NCCL",
            addrs=addrs,
            resps=resps,
        )

    def _handle_profiler(self, req: "HTTPRequestHandler") -> bytes:
        """Fetch ``torch_profile`` from every rank.

        The ``duration`` query argument is parsed as a float and defaults to
        1.0 when absent; it is forwarded to the ranks as ``duration=<value>``.
        """
        duration = req.get_query_arg("duration", default=1.0, type=float)
        addrs, resps = fetch_all("torch_profile", f"duration={duration}")
        return self._render_template("profile.html", addrs=addrs, resps=resps)

    def _handle_wait_counters(self, req: "HTTPRequestHandler") -> bytes:
        """Fetch ``wait_counter_values`` from every rank and render the JSON."""
        addrs, resps = fetch_all("wait_counter_values")
        return self._render_template(
            "json_resp.html", title="Wait Counters", addrs=addrs, resps=resps
        )

    def _handle_tcpstore(self, req: "HTTPRequestHandler") -> bytes:
        """List every key in the TCPStore, sorted, with the repr of its value."""
        store = tcpstore_client(prefix="")
        keys = store.list_keys()
        keys.sort()
        values = [repr(v) for v in store.multi_get(keys)]
        return self._render_template("tcpstore.html", keys=keys, values=values)
def main(port: int) -> None:
    """Start the frontend server on ``port`` and block until its thread exits.

    Sets the module logger to INFO so the startup line below is emitted.
    """
    logger.setLevel(logging.INFO)

    frontend = FrontendServer(port=port)
    # Report the port from the server object rather than echoing the argument.
    bound_port = frontend._server.server_port
    logger.info("Frontend server started on port %d", bound_port)

    frontend.join()