You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

152 lines
4.3 KiB

4 days ago
import math
import os
import torch
from torch._inductor.utils import get_device_tflops, get_gpu_dram_gbps
from torch.utils._ordered_set import OrderedSet
from .flop_counter import flop_registry
aten = torch.ops.aten
_FLOAT_TYPES = OrderedSet(
[
torch.float16,
torch.bfloat16,
torch.float32,
torch.float64,
]
)
# This value is hard-coded here:
# https://github.com/pytorch/pytorch/blob/5fba5d83f0703ff8077ab65448a998e9ad6598fd/c10/cuda/CUDACachingAllocator.cpp#L117
_PYTORCH_MIN_ALLOCATE = (
2**9 if int(os.environ.get("PYTORCH_NO_CUDA_MEMORY_CACHING", 0)) == 0 else 1
)
# No fall-back kernel needed/exists for view ops
_VIEW_OPS = OrderedSet(
[
aten.lift_fresh,
aten.t,
aten.transpose,
aten.view,
aten.detach,
aten._unsafe_view,
aten.split,
aten.adjoint,
aten.as_strided,
aten.diagonal,
aten.expand,
aten.expand_as,
aten.movedim,
aten.permute,
aten.select,
aten.squeeze,
aten.mT,
aten.mH,
aten.real,
aten.imag,
aten.view_as,
aten.unflatten,
aten.unfold,
aten.unbind,
aten.unsqueeze,
aten.vsplit,
aten.hsplit,
aten.split_with_sizes,
aten.swapaxes,
aten.swapdims,
aten.chunk,
]
)
# We can ignore benchmarking tensor create ops
_CREATE_OPS = OrderedSet(
[
aten.randint,
aten.randn,
aten.rand,
aten.randn_like,
aten.rand_like,
aten.randint_like,
aten.arange,
aten.ones_like,
aten.zeros_like,
]
)
_IGNORE_OPS = _VIEW_OPS | _CREATE_OPS
def get_compute_time(func_packet, args, kwargs, out, out_dtypes) -> float: # type: ignore[no-untyped-def]
"""
Estimates the compute time of an aten operator.
Args:
func_packet: The operator overload packet.
args: The arguments to the operator.
kwargs: The keyword arguments to the operator.
out: The output of the operator.
out_dtypes: The output data types.
Returns:
float: The estimated compute time in nanoseconds.
"""
if func_packet in flop_registry:
assert len(out_dtypes) == 1, (
f"Only support single out dtype got {out_dtypes} for {func_packet}"
)
dtype = out_dtypes.pop()
# This actually gives peta-FLOPs/s hence multiply by 1e15 to get the FLOPs/s
peak_gpu_flops = get_device_tflops(dtype) * 1e15
# We can expect to achieve 75% of theoretical peak flops
factor = 0.75
peak_empirical_flops = factor * peak_gpu_flops
flop_count_func = flop_registry[func_packet]
# We divide by a factor of 2 to get the MACs (multiply and accumulate)
flop_count = flop_count_func(*args, **kwargs, out_val=out) / 2
# We multiply by 1e9 to get the time in nano seconds
compute_time = (flop_count / peak_empirical_flops) * 1e9
return compute_time
return 0.0
def get_num_bytes(t: torch.Tensor) -> int:
    """
    Calculates the memory consumption of a tensor.

    The size is taken from the tensor's whole untyped storage (so a view reports
    the size of the storage it aliases, not only the viewed elements) and is
    rounded up to the next multiple of ``_PYTORCH_MIN_ALLOCATE``.

    Args:
        t (torch.Tensor): The input tensor.

    Returns:
        int: The memory consumption of the tensor in bytes.
    """
    num_bytes = t.untyped_storage().nbytes()
    # Exact integer ceiling division; ``math.ceil(num_bytes / block)`` goes
    # through a float and can round incorrectly for sizes beyond 2**53.
    num_blocks = -(-num_bytes // _PYTORCH_MIN_ALLOCATE)
    return num_blocks * _PYTORCH_MIN_ALLOCATE
def get_transfer_time(flat_args_kwargs, flat_outs) -> float:  # type: ignore[no-untyped-def]
    """
    Estimates how long moving an operator's tensors through GPU DRAM takes.

    Every ``torch.Tensor`` among the flattened inputs counts as bytes read and
    every ``torch.Tensor`` among the flattened outputs counts as bytes written;
    non-tensor entries are skipped. Per-tensor sizes come from
    ``get_num_bytes``. A tensor that appears more than once is counted once per
    appearance.

    Args:
        flat_args_kwargs (List[Any]): The flattened positional and keyword
            arguments; only tensor entries contribute.
        flat_outs (List[Any]): The flattened outputs; only tensor entries contribute.

    Returns:
        float: The estimated memory transfer time in nanoseconds.
    """
    # Bandwidth is in GB/s, i.e. bytes per nanosecond, so bytes / bandwidth is in ns.
    bandwidth = get_gpu_dram_gbps()

    read_bytes = 0
    for arg in flat_args_kwargs:
        if isinstance(arg, torch.Tensor):
            read_bytes += get_num_bytes(arg)

    write_bytes = 0
    for res in flat_outs:
        if isinstance(res, torch.Tensor):
            write_bytes += get_num_bytes(res)

    return (read_bytes + write_bytes) / bandwidth