import asyncio
import json
import logging
import socket
import threading
from collections.abc import Iterable
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from urllib.parse import parse_qs, urlparse

from jinja2 import DictLoader, Environment
from tabulate import tabulate

from torch.distributed.debug._store import get_world_size, tcpstore_client
from torch.distributed.flight_recorder.components.builder import build_db
from torch.distributed.flight_recorder.components.config_manager import JobConfig
from torch.distributed.flight_recorder.components.types import (
    Collective,
    Group,
    Membership,
    NCCLCall,
)


logger: logging.Logger = logging.getLogger(__name__)


@dataclass(slots=True)
class Response:
    """Minimal HTTP response record produced by both fetch backends."""

    # HTTP status code returned by the rank's handler.
    status_code: int
    # Decoded response body.
    text: str

    def raise_for_status(self) -> None:
        """Raise ``RuntimeError`` unless the status code is exactly 200."""
        if self.status_code != 200:
            raise RuntimeError(f"HTTP {self.status_code}: {self.text}")

    def json(self):
        """Decode the response body as JSON and return the resulting object."""
        return json.loads(self.text)


def fetch_thread_pool(urls: list[str]) -> Iterable[Response]:
    """POST to every URL concurrently using ``requests`` in a thread pool.

    This is the fallback backend used by ``fetch_all`` when ``aiohttp`` cannot
    be imported. Responses are yielded in the same order as ``urls``. The
    returned iterator is single-use; an exception raised by an individual
    request surfaces when its result is iterated.
    """
    # late import for optional dependency
    import requests

    max_workers = 20

    def get(url: str) -> Response:
        resp = requests.post(url)
        return Response(resp.status_code, resp.text)

    # Leaving the ``with`` block waits for every submitted request to finish.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        resps = executor.map(get, urls)
    return resps


def fetch_aiohttp(urls: list[str]) -> Iterable[Response]:
    """POST to every URL concurrently using a single ``aiohttp`` session.

    Raises ``ImportError`` if ``aiohttp`` is not installed (callers rely on
    this to pick a fallback). Responses are returned in the order of ``urls``.
    """
    # late import for optional dependency
    import aiohttp

    async def fetch(session: aiohttp.ClientSession, url: str) -> Response:
        async with session.post(url) as resp:
            text = await resp.text()
            return Response(resp.status, text)

    async def gather(urls: list[str]) -> Iterable[Response]:
        async with aiohttp.ClientSession() as session:
            return await asyncio.gather(*[fetch(session, url) for url in urls])

    return asyncio.run(gather(urls))


def fetch_all(endpoint: str, args: str = "") -> tuple[list[str], Iterable[Response]]:
    """Call ``<addr>/handler/<endpoint>?<args>`` on every rank.

    Per-rank base addresses are read from the TCPStore keys ``rank0`` ..
    ``rank{world_size - 1}``. The aiohttp backend is preferred; if it cannot be
    imported, the requests-based thread pool is used instead.

    Returns:
        A pair of (URLs queried, responses), both in rank order.
    """
    store = tcpstore_client()
    keys = [f"rank{r}" for r in range(get_world_size())]
    addrs = store.multi_get(keys)
    addrs = [f"{addr.decode()}/handler/{endpoint}?{args}" for addr in addrs]

    try:
        resps = fetch_aiohttp(addrs)
    except ImportError:
        resps = fetch_thread_pool(addrs)

    return addrs, resps


def format_json(blob: str):
    """Re-serialize a JSON string with 2-space indentation for display."""
    parsed = json.loads(blob)
    return json.dumps(parsed, indent=2)


templates = {
    "base.html": """ {% block title %}{% endblock %} - PyTorch Distributed
{% block header %}{% endblock %} {% block content %}{% endblock %}
""",
    "index.html": """ {% extends "base.html" %} {% block header %}

{% block title %}Index{% endblock %}

{% endblock %} {% block content %} Hi {% endblock %} """,
    "raw_resp.html": """ {% extends "base.html" %} {% block header %}

{% block title %}{{title}}{% endblock %}

{% endblock %} {% block content %} {% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}

Rank {{ i }}: {{ addr }}

{% if resp.status_code != 200 %}

Failed to fetch: status={{ resp.status_code }}

{{ resp.text }}
{% else %}
{{ resp.text }}
{% endif %} {% endfor %} {% endblock %} """,
    "json_resp.html": """ {% extends "base.html" %} {% block header %}

{% block title %}{{ title }}{% endblock %}

{% endblock %} {% block content %} {% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}

Rank {{ i }}: {{ addr }}

{% if resp.status_code != 200 %}

Failed to fetch: status={{ resp.status_code }}

{{ resp.text }}
{% else %}
{{ format_json(resp.text) }}
{% endif %} {% endfor %} {% endblock %} """,
    "profile.html": """ {% extends "base.html" %} {% block header %}

{% block title %}torch.profiler{% endblock %}

{% endblock %} {% block content %}
{% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}

Rank {{ i }}: {{ addr }}

{% if resp.status_code != 200 %}

Failed to fetch: status={{ resp.status_code }}

{{ resp.text }}
{% else %} {% endif %} {% endfor %} {% endblock %} """,
    "tcpstore.html": """ {% extends "base.html" %} {% block header %}

{% block title %}TCPStore Keys{% endblock %}

{% endblock %} {% block content %}
    {% for k, v in zip(keys, values) -%}
{{ k }}: {{ v | truncate(100) }}
    {% endfor %}
    
{% endblock %} """,
    "fr_trace.html": """ {% extends "base.html" %} {% block header %}

{% block title %}{{ title }}{% endblock %}

{% endblock %} {% block content %}

Groups

{{ groups | safe }}

Memberships

{{ memberships | safe }}

Collectives

{{ collectives | safe }}

NCCL Calls

{{ ncclcalls | safe }} {% endblock %} """,
    "pyspy_dump.html": """ {% extends "base.html" %} {% block header %}

{% block title %}py-spy Stack Traces{% endblock %}

{% endblock %} {% block content %}
{% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}

Rank {{ i }}: {{ addr }}

{% if resp.status_code != 200 %}

Failed to fetch: status={{ resp.status_code }}

{{ resp.text }}
{% else %}
{{ resp.text }}
{% endif %} {% endfor %} {% endblock %} """,
}


class _IPv6HTTPServer(ThreadingHTTPServer):
    """Threading HTTP server that listens on an IPv6 socket."""

    address_family: socket.AddressFamily = socket.AF_INET6  # pyre-ignore
    # Listen backlog; larger than socketserver's default of 5.
    request_queue_size: int = 1024


class HTTPRequestHandler(BaseHTTPRequestHandler):
    """Request handler that forwards every GET to its ``FrontendServer``."""

    # Bound per server instance via a dynamically created subclass.
    frontend: "FrontendServer"

    def log_message(self, format, *args):
        """Route http.server's access log through this module's logger."""
        logger.info(
            "%s %s",
            self.client_address[0],
            format % args,
        )

    def do_GET(self):
        self.frontend._handle_request(self)

    def get_path(self) -> str:
        """Return the path component of the request URL."""
        return urlparse(self.path).path

    def get_query(self) -> dict[str, list[str]]:
        """Return the parsed query string (each name maps to all its values)."""
        return parse_qs(self.get_raw_query())

    def get_raw_query(self) -> str:
        """Return the undecoded query string of the request URL."""
        return urlparse(self.path).query

    def get_query_arg(
        self, name: str, default: object = None, type: type = str
    ) -> object:
        """Return the first value of query parameter ``name`` converted by ``type``.

        Returns ``default`` (unconverted) when the parameter is absent. A
        conversion failure propagates as whatever ``type`` raises.
        """
        query = self.get_query()
        if name not in query:
            return default
        return type(query[name][0])


class FrontendServer:
    """Debug web UI that aggregates per-rank debug endpoints into HTML pages.

    Constructing the object binds an IPv6 HTTP server on ``port`` and starts
    serving immediately on a daemon thread; call ``join`` to block on it.
    """

    def __init__(self, port: int):
        # Setup templates
        loader = DictLoader(templates)
        # autoescape: rank responses (e.g. Python tracebacks containing
        # "<module>"), TCPStore values and addresses are arbitrary text and
        # must be shown literally, not interpreted as markup. The pre-rendered
        # tabulate tables are explicitly marked ``| safe`` in the templates.
        self._jinja_env = Environment(
            loader=loader, enable_async=True, autoescape=True
        )
        self._jinja_env.globals.update(
            zip=zip,
            format_json=format_json,
            enumerate=enumerate,
        )

        # Create routes
        self._routes = {
            "/": self._handle_index,
            "/stacks": self._handle_stacks,
            "/pyspy_dump": self._handle_pyspy_dump,
            "/fr_trace": self._handle_fr_trace,
            "/fr_trace_json": self._handle_fr_trace_json,
            "/fr_trace_nccl": self._handle_fr_trace_nccl,
            "/fr_trace_nccl_json": self._handle_fr_trace_nccl_json,
            "/profile": self._handle_profiler,
            "/wait_counters": self._handle_wait_counters,
            "/tcpstore": self._handle_tcpstore,
        }

        # Create HTTP server. A per-instance handler subclass lets the
        # stateless handler reach this FrontendServer via ``frontend``.
        RequestHandlerClass = type(
            "HTTPRequestHandler",
            (HTTPRequestHandler,),
            {"frontend": self},
        )

        server_address = ("", port)
        self._server = _IPv6HTTPServer(server_address, RequestHandlerClass)

        self._thread = threading.Thread(
            target=self._serve,
            args=(),
            daemon=True,
            name="distributed.debug.FrontendServer",
        )
        self._thread.start()

    def _serve(self) -> None:
        """Run the server loop, logging (not raising) any exception."""
        try:
            self._server.serve_forever()
        except Exception:
            logger.exception("got exception in frontend server")

    def join(self) -> None:
        """Block until the serving thread exits."""
        self._thread.join()

    def _handle_request(self, req: HTTPRequestHandler) -> None:
        """Dispatch a request to the route registered for its path.

        Unknown paths get a 404. Any exception (including ``SystemExit``)
        raised by the route handler is logged and reported as a 500. On
        success the rendered HTML is written with status 200.
        """
        path = req.get_path()
        if path not in self._routes:
            req.send_error(404, f"Handler not found: {path}")
            return

        handler = self._routes[path]
        try:
            resp = handler(req)
        # Catch SystemExit to not crash when FlightRecorder errors.
        except (Exception, SystemExit) as e:
            logger.exception(
                "Exception in frontend server when handling %s",
                path,
            )
            req.send_error(500, f"Exception: {repr(e)}")
            return

        req.send_response(200)
        # Bodies come from str.encode() (UTF-8); declare the charset so
        # browsers do not have to guess.
        req.send_header("Content-type", "text/html; charset=utf-8")
        req.send_header("Content-Length", str(len(resp)))
        req.end_headers()
        req.wfile.write(resp)

    def _render_template(self, template: str, **kwargs: object) -> bytes:
        """Render a named template with ``kwargs`` and return UTF-8 bytes."""
        return self._jinja_env.get_template(template).render(**kwargs).encode()

    def _handle_index(self, req: HTTPRequestHandler) -> bytes:
        return self._render_template("index.html")

    def _handle_stacks(self, req: HTTPRequestHandler) -> bytes:
        addrs, resps = fetch_all("dump_traceback")
        return self._render_template(
            "raw_resp.html", title="Stacks", addrs=addrs, resps=resps
        )

    def _handle_pyspy_dump(self, req: HTTPRequestHandler) -> bytes:
        # The client's query string is forwarded to every rank unchanged.
        addrs, resps = fetch_all("pyspy_dump", req.get_raw_query())
        return self._render_template(
            "pyspy_dump.html",
            addrs=addrs,
            resps=resps,
        )

    def _render_fr_trace(self, addrs: list[str], resps: list[Response]) -> bytes:
        """Build a FlightRecorder database from per-rank JSON dumps and render it.

        Raises ``RuntimeError`` (via ``raise_for_status``) if any rank returned
        a non-200 status. The ``version`` field of the first rank's dump is the
        one passed to ``build_db``.
        """
        config = JobConfig()
        # pyrefly: ignore [bad-assignment]
        args = config.parse_args(args=[])
        args.allow_incomplete_ranks = True
        args.verbose = True

        details = {}
        for rank, resp in enumerate(resps):
            resp.raise_for_status()
            dump = {
                "rank": rank,
                "host_name": addrs[rank],
                **resp.json(),
            }
            # Ranks with nothing recorded may omit "entries".
            if "entries" not in dump:
                dump["entries"] = []
            details[f"rank{rank}.json"] = dump

        version = next(iter(details.values()))["version"]

        db = build_db(details, args, version)

        return self._render_template(
            "fr_trace.html",
            title="FlightRecorder",
            groups=tabulate(db.groups, headers=Group._fields, tablefmt="html"),
            memberships=tabulate(
                db.memberships, headers=Membership._fields, tablefmt="html"
            ),
            collectives=tabulate(
                db.collectives, headers=Collective._fields, tablefmt="html"
            ),
            ncclcalls=tabulate(db.ncclcalls, headers=NCCLCall._fields, tablefmt="html"),
        )

    def _handle_fr_trace(self, req: HTTPRequestHandler) -> bytes:
        addrs, resps = fetch_all("fr_trace_json")

        return self._render_fr_trace(addrs, list(resps))

    def _handle_fr_trace_json(self, req: HTTPRequestHandler) -> bytes:
        addrs, resps = fetch_all("fr_trace_json")

        return self._render_template(
            "json_resp.html",
            title="FlightRecorder",
            addrs=addrs,
            resps=resps,
        )

    def _handle_fr_trace_nccl(self, req: HTTPRequestHandler) -> bytes:
        addrs, resps = fetch_all("dump_nccl_trace_json", "onlyactive=true")

        return self._render_fr_trace(addrs, list(resps))

    def _handle_fr_trace_nccl_json(self, req: HTTPRequestHandler) -> bytes:
        addrs, resps = fetch_all("dump_nccl_trace_json", "onlyactive=true")

        return self._render_template(
            "json_resp.html",
            title="FlightRecorder NCCL",
            addrs=addrs,
            resps=resps,
        )

    def _handle_profiler(self, req: HTTPRequestHandler) -> bytes:
        # ``duration`` is parsed as a float (default 1.0) before being forwarded.
        duration = req.get_query_arg("duration", default=1.0, type=float)

        addrs, resps = fetch_all("torch_profile", f"duration={duration}")

        return self._render_template("profile.html", addrs=addrs, resps=resps)

    def _handle_wait_counters(self, req: HTTPRequestHandler) -> bytes:
        addrs, resps = fetch_all("wait_counter_values")
        return self._render_template(
            "json_resp.html", title="Wait Counters", addrs=addrs, resps=resps
        )

    def _handle_tcpstore(self, req: HTTPRequestHandler) -> bytes:
        """List every TCPStore key (sorted) with the repr of its value."""
        store = tcpstore_client(prefix="")
        keys = store.list_keys()
        keys.sort()
        values = [repr(v) for v in store.multi_get(keys)]
        return self._render_template("tcpstore.html", keys=keys, values=values)


def main(port: int) -> None:
    """Start the frontend server on ``port``, log the bound port, and block."""
    logger.setLevel(logging.INFO)
    server = FrontendServer(port=port)
    logger.info("Frontend server started on port %d", server._server.server_port)
    server.join()