You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
554 lines
17 KiB
554 lines
17 KiB
|
5 days ago
|
import asyncio
|
||
|
|
import json
|
||
|
|
import logging
|
||
|
|
import socket
|
||
|
|
import threading
|
||
|
|
from collections.abc import Iterable
|
||
|
|
from concurrent.futures import ThreadPoolExecutor
|
||
|
|
from dataclasses import dataclass
|
||
|
|
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||
|
|
from urllib.parse import parse_qs, urlparse
|
||
|
|
|
||
|
|
from jinja2 import DictLoader, Environment
|
||
|
|
from tabulate import tabulate
|
||
|
|
|
||
|
|
from torch.distributed.debug._store import get_world_size, tcpstore_client
|
||
|
|
from torch.distributed.flight_recorder.components.builder import build_db
|
||
|
|
from torch.distributed.flight_recorder.components.config_manager import JobConfig
|
||
|
|
from torch.distributed.flight_recorder.components.types import (
|
||
|
|
Collective,
|
||
|
|
Group,
|
||
|
|
Membership,
|
||
|
|
NCCLCall,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
# Module-level logger; `main` raises its level to INFO before starting the server.
logger: logging.Logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass(slots=True)
|
||
|
|
class Response:
|
||
|
|
status_code: int
|
||
|
|
text: str
|
||
|
|
|
||
|
|
def raise_for_status(self):
|
||
|
|
if self.status_code != 200:
|
||
|
|
raise RuntimeError(f"HTTP {self.status_code}: {self.text}")
|
||
|
|
|
||
|
|
def json(self):
|
||
|
|
return json.loads(self.text)
|
||
|
|
|
||
|
|
|
||
|
|
def fetch_thread_pool(
    urls: list[str], max_workers: int = 20
) -> Iterable[Response]:
    """POST to every URL with blocking ``requests`` calls on a thread pool.

    Fallback used by ``fetch_all`` when ``aiohttp`` is not importable.

    Args:
        urls: URLs to POST to; each one is requested exactly once.
        max_workers: maximum number of requests in flight at once.

    Returns:
        A list of ``Response`` objects in the same order as ``urls``. A list is
        returned (rather than the lazy, single-use ``Executor.map`` iterator)
        so callers can iterate it more than once, matching ``fetch_aiohttp``.

    Raises:
        Any exception raised by ``requests.post`` for an individual URL.
    """
    # late import for optional dependency
    import requests

    def _post(url: str) -> Response:
        resp = requests.post(url)
        return Response(resp.status_code, resp.text)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Materialize here so results, and any per-URL exception, surface in
        # this call rather than whenever the caller happens to iterate.
        return list(executor.map(_post, urls))
|
||
|
|
|
||
|
|
|
||
|
|
def fetch_aiohttp(urls: list[str]) -> Iterable[Response]:
    """POST to every URL concurrently using ``aiohttp``.

    All requests share one ``ClientSession`` and run on a private event loop
    started with ``asyncio.run``. Results come back as a list in the same
    order as ``urls``. If ``aiohttp`` is not installed, the import raises
    ``ImportError``; ``fetch_all`` relies on this to fall back to a thread pool.
    """
    # late import for optional dependency
    import aiohttp

    async def _post_one(session: aiohttp.ClientSession, target: str) -> Response:
        async with session.post(target) as reply:
            body = await reply.text()
            return Response(reply.status, body)

    async def _post_all(targets: list[str]) -> Iterable[Response]:
        async with aiohttp.ClientSession() as session:
            pending = [_post_one(session, target) for target in targets]
            return await asyncio.gather(*pending)

    return asyncio.run(_post_all(urls))
|
||
|
|
|
||
|
|
|
||
|
|
def fetch_all(endpoint: str, args: str = "") -> tuple[list[str], Iterable[Response]]:
    """POST ``<base>/handler/<endpoint>?<args>`` to every rank's debug server.

    Each rank's base URL is read from the TCPStore under the key ``rank<N>``
    for N in ``range(get_world_size())``.

    Returns:
        ``(urls, responses)``: ``urls[i]`` is the full URL requested for rank
        ``i`` and ``responses`` holds the results in the same order. The
        ``aiohttp`` backend is tried first; if it raises ``ImportError``, the
        ``requests`` thread-pool backend is used instead.
    """
    store = tcpstore_client()
    rank_keys = [f"rank{rank}" for rank in range(get_world_size())]
    urls = [
        f"{raw.decode()}/handler/{endpoint}?{args}"
        for raw in store.multi_get(rank_keys)
    ]

    try:
        responses = fetch_aiohttp(urls)
    except ImportError:
        responses = fetch_thread_pool(urls)

    return urls, responses
|
||
|
|
|
||
|
|
|
||
|
|
def format_json(blob: str) -> str:
    """Pretty-print a JSON document with two-space indentation.

    Exposed to the Jinja templates to render per-rank JSON responses. If
    ``blob`` is not valid JSON, it is returned unchanged. That way a single
    malformed rank response is shown as raw text instead of aborting the
    render of the whole page.
    """
    try:
        parsed = json.loads(blob)
    except ValueError:
        # json.JSONDecodeError is a subclass of ValueError.
        return blob
    return json.dumps(parsed, indent=2)
|
||
|
|
|
||
|
|
|
||
|
|
# Jinja2 templates served by FrontendServer (loaded through DictLoader), keyed
# by template name. Every page extends "base.html". Values piped through
# ``| safe`` are inserted verbatim:
#   - the tabulate-generated HTML tables in fr_trace.html;
#   - the raw JSON trace assigned to a JS variable in profile.html.
templates = {
    "base.html": """
<!doctype html>
<head>
    <meta charset="utf-8">
    <title>{% block title %}{% endblock %} - PyTorch Distributed</title>
    <link rel="shortcut icon" type="image/x-icon" href="https://pytorch.org/favicon.ico?">

    <style>
        body {
            margin: 0;
            font-family:
                -apple-system,BlinkMacSystemFont,"Segoe UI",Roboto,
                "Helvetica Neue",Arial,"Noto Sans",sans-serif,"Apple Color Emoji",
                "Segoe UI Emoji","Segoe UI Symbol","Noto Color Emoji";
            font-size: 1rem;
            font-weight: 400;
            line-height: 1.5;
            color: #212529;
            text-align: left;
            background-color: #fff;
        }
        h1, h2, h3, h4, h5, h6, .h1, .h2, .h3, .h4, .h5, .h6 {
            margin-bottom: .5rem;
            font-weight: 500;
            line-height: 1.2;
        }
        nav {
            background-color: rgba(0, 0, 0, 0.17);
            display: flex;
            align-items: center;
            padding: 16px;
            justify-content: flex-start;
        }
        nav h1 {
            display: inline-block;
            margin: 0;
        }
        nav a {
            margin: 0 8px;
        }
        section {
            max-width: 1280px;
            padding: 16px;
            margin: 0 auto;
        }
        pre {
            white-space: pre-wrap;
            max-width: 100%;
        }
    </style>
</head>

<nav>
    <h1>Torch Distributed Debug Server</h1>

    <a href="/">Home</a> <!--@lint-ignore-->
    <a href="/stacks">Python Stack Traces</a> <!--@lint-ignore-->
    <a href="/pyspy_dump">py-spy Stacks</a> <!--@lint-ignore-->
    <a href="/fr_trace">FlightRecorder CPU</a> <!--@lint-ignore-->
    <a href="/fr_trace_json">(JSON)</a> <!--@lint-ignore-->
    <a href="/fr_trace_nccl">FlightRecorder NCCL</a> <!--@lint-ignore-->
    <a href="/fr_trace_nccl_json">(JSON)</a> <!--@lint-ignore-->
    <a href="/profile">torch profiler</a> <!--@lint-ignore-->
    <a href="/wait_counters">Wait Counters</a> <!--@lint-ignore-->
    <a href="/tcpstore">TCPStore</a> <!--@lint-ignore-->
</nav>

<section class="content">
    {% block header %}{% endblock %}
    {% block content %}{% endblock %}
</section>
""",
    "index.html": """
{% extends "base.html" %}
{% block header %}
    <h1>{% block title %}Index{% endblock %}</h1>
{% endblock %}
{% block content %}
    Hi
{% endblock %}
""",
    "raw_resp.html": """
{% extends "base.html" %}
{% block header %}
    <h1>{% block title %}{{title}}{% endblock %}</h1>
{% endblock %}
{% block content %}
    {% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}
        <h2>Rank {{ i }}: {{ addr }}</h2>
        {% if resp.status_code != 200 %}
            <p>Failed to fetch: status={{ resp.status_code }}</p>
            <pre>{{ resp.text }}</pre>
        {% else %}
            <pre>{{ resp.text }}</pre>
        {% endif %}
    {% endfor %}
{% endblock %}
""",
    "json_resp.html": """
{% extends "base.html" %}
{% block header %}
    <h1>{% block title %}{{ title }}{% endblock %}</h1>
{% endblock %}
{% block content %}
    {% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}
        <h2>Rank {{ i }}: {{ addr }}</h2>
        {% if resp.status_code != 200 %}
            <p>Failed to fetch: status={{ resp.status_code }}</p>
            <pre>{{ resp.text }}</pre>
        {% else %}
            <pre>{{ format_json(resp.text) }}</pre>
        {% endif %}
    {% endfor %}
{% endblock %}
""",
    "profile.html": """
{% extends "base.html" %}
{% block header %}
    <h1>{% block title %}torch.profiler{% endblock %}</h1>
{% endblock %}

{% block content %}
    <form action="" method="get">
        <label for="duration">Duration (seconds):</label>
        <input type="number" id="duration" name="duration" value="{{ duration }}" min="1" max="60">
        <input type="submit" value="Submit">
    </form>

    <script>
    function stringToArrayBuffer(str) {
        const encoder = new TextEncoder();
        return encoder.encode(str).buffer;
    }
    async function openPerfetto(data) {
        const ui = window.open('https://ui.perfetto.dev/#!/');
        if (!ui) { alert('Popup blocked. Allow popups for this page and click again.'); return; }

        // Perfetto readiness handshake: PING until we receive PONG.
        // try/catch (not .catch()) so a failed handshake returns from
        // openPerfetto itself instead of posting the trace anyway.
        try {
            await new Promise((resolve, reject) => {
                const onMsg = (e) => {
                    if (e.source === ui && e.data === 'PONG') {
                        window.removeEventListener('message', onMsg);
                        clearInterval(pinger);
                        clearTimeout(deadline);
                        resolve();
                    }
                };
                window.addEventListener('message', onMsg);
                const pinger = setInterval(() => { try { ui.postMessage('PING', '*'); } catch (_e) {} }, 250);
                const deadline = setTimeout(() => { clearInterval(pinger); window.removeEventListener('message', onMsg); reject(); }, 20000);
            });
        } catch (_e) {
            alert('Perfetto UI did not respond. Try again.');
            return;
        }

        ui.postMessage({
            perfetto: {
                buffer: stringToArrayBuffer(JSON.stringify(data)),
                title: "torch profiler",
                fileName: "trace.json",
            }
        }, '*');
    }
    </script>

    {% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}
        <h2>Rank {{ i }}: {{ addr }}</h2>
        {% if resp.status_code != 200 %}
            <p>Failed to fetch: status={{ resp.status_code }}</p>
            <pre>{{ resp.text }}</pre>
        {% else %}
            <script>
            function run{{ i }}() {
                var data = {{ resp.text | safe }};
                openPerfetto(data);
            }
            </script>

            <button onclick="run{{ i }}()">View {{ i }}</button>
        {% endif %}
    {% endfor %}
{% endblock %}
""",
    "tcpstore.html": """
{% extends "base.html" %}
{% block header %}
    <h1>{% block title %}TCPStore Keys{% endblock %}</h1>
{% endblock %}
{% block content %}
    <pre>
{% for k, v in zip(keys, values) -%}
{{ k }}: {{ v | truncate(100) }}
{% endfor %}
    </pre>
{% endblock %}
""",
    "fr_trace.html": """
{% extends "base.html" %}
{% block header %}
    <h1>{% block title %}{{ title }}{% endblock %}</h1>
{% endblock %}
{% block content %}
    <h2>Groups</h2>
    {{ groups | safe }}
    <h2>Memberships</h2>
    {{ memberships | safe }}
    <h2>Collectives</h2>
    {{ collectives | safe }}
    <h2>NCCL Calls</h2>
    {{ ncclcalls | safe }}
{% endblock %}
""",
    "pyspy_dump.html": """
{% extends "base.html" %}
{% block header %}
    <h1>{% block title %}py-spy Stack Traces{% endblock %}</h1>
{% endblock %}
{% block content %}
    <form action="" method="get">
        <input type="checkbox" id="native" name="native" value="1"/>
        <label for="native">Native</label>
        <input type="checkbox" id="subprocesses" name="subprocesses" value="1"/>
        <label for="subprocesses">Subprocesses</label>
        <input type="submit" value="Submit">
    </form>

    {% for i, (addr, resp) in enumerate(zip(addrs, resps)) %}
        <h2>Rank {{ i }}: {{ addr }}</h2>
        {% if resp.status_code != 200 %}
            <p>Failed to fetch: status={{ resp.status_code }}</p>
            <pre>{{ resp.text }}</pre>
        {% else %}
            <pre>{{ resp.text }}</pre>
        {% endif %}
    {% endfor %}
{% endblock %}
""",
}
|
||
|
|
|
||
|
|
|
||
|
|
class _IPv6HTTPServer(ThreadingHTTPServer):
|
||
|
|
address_family: socket.AddressFamily = socket.AF_INET6 # pyre-ignore
|
||
|
|
request_queue_size: int = 1024
|
||
|
|
|
||
|
|
|
||
|
|
class HTTPRequestHandler(BaseHTTPRequestHandler):
    """Request handler that hands every GET to a ``FrontendServer``.

    ``FrontendServer`` builds a subclass whose ``frontend`` attribute points
    back at itself. The remaining helpers expose pieces of the request URL.
    """

    frontend: "FrontendServer"

    def log_message(self, format, *args):
        # Send http.server's per-request log line through this module's
        # logger rather than the base class's default stderr output.
        client_host = self.client_address[0]
        logger.info("%s %s", client_host, format % args)

    def do_GET(self):
        self.frontend._handle_request(self)

    def get_path(self) -> str:
        """Return the path portion of the request URL, without the query."""
        parts = urlparse(self.path)
        return parts.path

    def get_raw_query(self) -> str:
        """Return the undecoded query string (text after ``?``), or ``""``."""
        return urlparse(self.path).query

    def get_query(self) -> dict[str, list[str]]:
        """Parse the query string into a mapping of name -> list of values."""
        raw = self.get_raw_query()
        return parse_qs(raw)

    def get_query_arg(
        self, name: str, default: object = None, type: type = str
    ) -> object:
        """Return the first value of query parameter ``name`` passed through ``type``.

        ``default`` is returned as-is when the parameter is absent. Any error
        raised by ``type`` during conversion propagates to the caller.
        """
        values = self.get_query().get(name)
        if values is None:
            return default
        first = values[0]
        return type(first)
|
||
|
|
|
||
|
|
|
||
|
|
class FrontendServer:
    """Web front end that aggregates debug information from every rank.

    Constructing an instance binds an ``_IPv6HTTPServer`` on ``port`` and
    serves it on a daemon thread; ``join`` blocks on that thread. Each route
    either fans a request out to all ranks with ``fetch_all`` or reads the
    TCPStore directly, then renders one of the Jinja2 ``templates``.
    """

    def __init__(self, port: int):
        # Setup templates. Autoescaping makes rank output render literally
        # instead of being parsed as HTML; this matters e.g. for Python
        # tracebacks that contain "<module>". Template values that are already
        # HTML or JSON opt out with ``| safe``.
        loader = DictLoader(templates)
        self._jinja_env = Environment(
            loader=loader, enable_async=True, autoescape=True
        )
        self._jinja_env.globals.update(
            zip=zip,
            format_json=format_json,
            enumerate=enumerate,
        )

        # Create routes
        self._routes = {
            "/": self._handle_index,
            "/stacks": self._handle_stacks,
            "/pyspy_dump": self._handle_pyspy_dump,
            "/fr_trace": self._handle_fr_trace,
            "/fr_trace_json": self._handle_fr_trace_json,
            "/fr_trace_nccl": self._handle_fr_trace_nccl,
            "/fr_trace_nccl_json": self._handle_fr_trace_nccl_json,
            "/profile": self._handle_profiler,
            "/wait_counters": self._handle_wait_counters,
            "/tcpstore": self._handle_tcpstore,
        }

        # Create HTTP server. The handler subclass carries a reference back to
        # this instance so each request can be dispatched to _handle_request.
        RequestHandlerClass = type(
            "HTTPRequestHandler",
            (HTTPRequestHandler,),
            {"frontend": self},
        )

        server_address = ("", port)
        self._server = _IPv6HTTPServer(server_address, RequestHandlerClass)

        self._thread = threading.Thread(
            target=self._serve,
            args=(),
            daemon=True,
            name="distributed.debug.FrontendServer",
        )
        self._thread.start()

    def _serve(self) -> None:
        """Run the HTTP server loop, logging any exception that escapes it."""
        try:
            self._server.serve_forever()
        except Exception:
            logger.exception("got exception in frontend server")

    def join(self) -> None:
        """Block until the serving thread exits."""
        self._thread.join()

    def _handle_request(self, req: HTTPRequestHandler) -> None:
        """Dispatch ``req`` to the handler for its path and write the reply.

        Unknown paths get a 404. Any exception (or ``SystemExit``) raised by a
        handler is logged and reported as a 500.
        """
        path = req.get_path()
        handler = self._routes.get(path)
        if handler is None:
            req.send_error(404, f"Handler not found: {path}")
            return

        try:
            resp = handler(req)
        # Catch SystemExit to not crash when FlightRecorder errors.
        except (Exception, SystemExit) as e:
            logger.exception(
                "Exception in frontend server when handling %s",
                path,
            )
            req.send_error(500, f"Exception: {repr(e)}")
            return

        req.send_response(200)
        # _render_template encodes pages as UTF-8; declare it so browsers do
        # not have to guess, and send the body length explicitly.
        req.send_header("Content-type", "text/html; charset=utf-8")
        req.send_header("Content-Length", str(len(resp)))
        req.end_headers()
        req.wfile.write(resp)

    def _render_template(self, template: str, **kwargs: object) -> bytes:
        """Render the named template with ``kwargs`` and return UTF-8 bytes."""
        return self._jinja_env.get_template(template).render(**kwargs).encode()

    def _handle_index(self, req: HTTPRequestHandler) -> bytes:
        return self._render_template("index.html")

    def _handle_stacks(self, req: HTTPRequestHandler) -> bytes:
        addrs, resps = fetch_all("dump_traceback")
        return self._render_template(
            "raw_resp.html", title="Stacks", addrs=addrs, resps=resps
        )

    def _handle_pyspy_dump(self, req: HTTPRequestHandler) -> bytes:
        # The page's checkbox options are forwarded to every rank as-is.
        addrs, resps = fetch_all("pyspy_dump", req.get_raw_query())
        return self._render_template(
            "pyspy_dump.html",
            addrs=addrs,
            resps=resps,
        )

    def _render_fr_trace(self, addrs: list[str], resps: list[Response]) -> bytes:
        """Build a FlightRecorder database from per-rank dumps and render it.

        Raises:
            RuntimeError: if any rank responded with a non-200 status.
            ValueError: if there are no rank responses at all.
        """
        config = JobConfig()
        # pyrefly: ignore [bad-assignment]
        args = config.parse_args(args=[])
        args.allow_incomplete_ranks = True
        args.verbose = True

        details = {}
        for rank, resp in enumerate(resps):
            resp.raise_for_status()
            dump = {
                "rank": rank,
                "host_name": addrs[rank],
                **resp.json(),
            }
            if "entries" not in dump:
                dump["entries"] = []
            details[f"rank{rank}.json"] = dump

        if not details:
            # Without this check, next() below raises a bare StopIteration.
            raise ValueError("FlightRecorder: no rank responses to analyze")
        # Every rank's dump is assumed to share the first rank's version.
        version = next(iter(details.values()))["version"]

        db = build_db(details, args, version)

        return self._render_template(
            "fr_trace.html",
            title="FlightRecorder",
            groups=tabulate(db.groups, headers=Group._fields, tablefmt="html"),
            memberships=tabulate(
                db.memberships, headers=Membership._fields, tablefmt="html"
            ),
            collectives=tabulate(
                db.collectives, headers=Collective._fields, tablefmt="html"
            ),
            ncclcalls=tabulate(db.ncclcalls, headers=NCCLCall._fields, tablefmt="html"),
        )

    def _handle_fr_trace(self, req: HTTPRequestHandler) -> bytes:
        addrs, resps = fetch_all("fr_trace_json")

        return self._render_fr_trace(addrs, list(resps))

    def _handle_fr_trace_json(self, req: HTTPRequestHandler) -> bytes:
        addrs, resps = fetch_all("fr_trace_json")

        return self._render_template(
            "json_resp.html",
            title="FlightRecorder",
            addrs=addrs,
            resps=resps,
        )

    def _handle_fr_trace_nccl(self, req: HTTPRequestHandler) -> bytes:
        addrs, resps = fetch_all("dump_nccl_trace_json", "onlyactive=true")

        return self._render_fr_trace(addrs, list(resps))

    def _handle_fr_trace_nccl_json(self, req: HTTPRequestHandler) -> bytes:
        addrs, resps = fetch_all("dump_nccl_trace_json", "onlyactive=true")

        return self._render_template(
            "json_resp.html",
            title="FlightRecorder NCCL",
            addrs=addrs,
            resps=resps,
        )

    def _handle_profiler(self, req: HTTPRequestHandler) -> bytes:
        duration = req.get_query_arg("duration", default=1.0, type=float)

        addrs, resps = fetch_all("torch_profile", f"duration={duration}")

        # profile.html fills its duration input from ``duration``; pass it so
        # the form shows the value that was just profiled.
        return self._render_template(
            "profile.html", addrs=addrs, resps=resps, duration=duration
        )

    def _handle_wait_counters(self, req: HTTPRequestHandler) -> bytes:
        addrs, resps = fetch_all("wait_counter_values")
        return self._render_template(
            "json_resp.html", title="Wait Counters", addrs=addrs, resps=resps
        )

    def _handle_tcpstore(self, req: HTTPRequestHandler) -> bytes:
        """List every TCPStore key (sorted) with the ``repr`` of its value."""
        store = tcpstore_client(prefix="")
        keys = store.list_keys()
        keys.sort()
        values = [repr(v) for v in store.multi_get(keys)]
        return self._render_template("tcpstore.html", keys=keys, values=values)
|
||
|
|
|
||
|
|
|
||
|
|
def main(port: int) -> None:
    """Start the debug frontend on ``port`` and block until its thread ends.

    Raises this module's logger to INFO so the startup message can pass the
    logger's level check. That message reports the port actually bound, which
    can differ from ``port`` (e.g. when ``port`` is 0).
    """
    logger.setLevel(logging.INFO)

    frontend = FrontendServer(port=port)
    bound_port = frontend._server.server_port
    logger.info("Frontend server started on port %d", bound_port)

    frontend.join()
|