tilelang.profiler

The profiler and utilities for converting to torch tensors.

Submodules

Classes

Profiler

A profiler class for benchmarking and validating kernel implementations.

Package Contents

class tilelang.profiler.Profiler

A profiler class for benchmarking and validating kernel implementations.

params

List of kernel parameters defining the input/output specifications

result_idx

Indices indicating which parameters are output tensors

supply_type

Type of tensor supply to use (e.g., random, zeros, etc.)

adapter

Optional kernel adapter for interfacing with different backends

params: list[tilelang.engine.param.KernelParam]
result_idx: list[int]
supply_type: tilelang.utils.tensor.TensorSupplyType
adapter: tilelang.jit.adapter.BaseKernelAdapter | None = None
__post_init__()

Initialize tensor supply after dataclass initialization

with_default_adapter(adapter)
Parameters:

adapter (tilelang.jit.adapter.BaseKernelAdapter)

Return type:

Profiler

assert_allclose(reference_program, input_tensors=None, atol=0.01, rtol=0.01, max_mismatched_ratio=0.01)

Validates kernel output against a reference implementation.

Parameters:
  • reference_program (Callable) -- Reference implementation to compare against

  • input_tensors (list[torch.Tensor] | None) -- Optional pre-generated input tensors

  • atol (float) -- Absolute tolerance for comparison

  • rtol (float) -- Relative tolerance for comparison

  • max_mismatched_ratio -- Maximum allowed ratio of mismatched elements
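The interaction of the three tolerances can be illustrated with a minimal pure-Python sketch. It assumes the common rule |actual - expected| <= atol + rtol * |expected| (the rule used by torch.isclose); whether Profiler applies exactly this rule is an assumption, and the helper names mismatch_ratio and check_allclose are hypothetical.

```python
def mismatch_ratio(actual, expected, atol=0.01, rtol=0.01):
    """Return the fraction of element pairs outside the tolerance band."""
    if len(actual) != len(expected):
        raise ValueError("inputs must have the same length")
    if not actual:
        return 0.0
    # Assumed rule: an element mismatches when the absolute error exceeds
    # atol plus rtol scaled by the magnitude of the reference value.
    mismatched = sum(
        1 for a, e in zip(actual, expected)
        if abs(a - e) > atol + rtol * abs(e)
    )
    return mismatched / len(actual)


def check_allclose(actual, expected, atol=0.01, rtol=0.01,
                   max_mismatched_ratio=0.01):
    """Raise AssertionError if too many elements fall outside tolerance."""
    ratio = mismatch_ratio(actual, expected, atol, rtol)
    assert ratio <= max_mismatched_ratio, (
        f"{ratio:.2%} of elements mismatched, "
        f"allowed {max_mismatched_ratio:.2%}"
    )
    return ratio
```

Under this reading, max_mismatched_ratio lets a small share of outliers through: with 200 elements and the default 0.01, up to two elements may exceed the tolerance before the check fails.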

manual_assert_close(reference_program, input_tensors=None, manual_check_prog=None)

Validates kernel output against a reference implementation using a user-supplied check function.

Parameters:
  • reference_program (Callable) -- Reference implementation to compare against

  • input_tensors (list[torch.Tensor] | None) -- Optional pre-generated input tensors

  • manual_check_prog (Callable) -- User-supplied function that performs the comparison
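A custom check is useful when elementwise tolerances are the wrong criterion. The sketch below is pure Python and makes two assumptions that the signature above does not confirm: that the check callable receives the kernel outputs and the reference outputs, in that order, and that it signals failure by raising. The names argmax_matches and run_manual_check are hypothetical.

```python
def argmax_matches(kernel_outputs, reference_outputs):
    """Hypothetical check: each output must peak at the same index as its reference."""
    for out, ref in zip(kernel_outputs, reference_outputs):
        assert out.index(max(out)) == ref.index(max(ref)), "argmax differs"


def run_manual_check(kernel, reference, inputs, check):
    """Run both programs on the same inputs and delegate the comparison."""
    # Assumed calling convention: check(kernel_outputs, reference_outputs).
    check(kernel(*inputs), reference(*inputs))


# The scaled output differs numerically but preserves the argmax, so this passes.
run_manual_check(
    kernel=lambda xs: [[v * 1.1 for v in xs]],
    reference=lambda xs: [list(xs)],
    inputs=([0.1, 0.9, 0.3],),
    check=argmax_matches,
)
```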

assert_consistent(repeat=10)

Checks for kernel consistency across multiple runs.

Parameters:

repeat -- Number of times to repeat the consistency check
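The idea behind a consistency check can be sketched in plain Python: run the same function on the same inputs several times and require every result to equal the first one. This is only an illustration; how Profiler compares tensors internally (exact equality or a tolerance) is not stated here, and check_consistent is a hypothetical helper.

```python
def check_consistent(fn, inputs, repeat=10):
    """Call fn(*inputs) repeatedly and require identical results every time."""
    baseline = fn(*inputs)
    for i in range(repeat):
        out = fn(*inputs)
        # Exact equality is an assumption made for this sketch.
        if out != baseline:
            raise AssertionError(f"run {i} differs from the first run")
    return baseline
```

A check of this kind mainly catches nondeterminism such as data races or reads of uninitialized memory, which a single comparison against a reference can miss.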

run_once(func=None)
Parameters:

func (Callable | None)

do_bench(func=None, warmup=25, rep=100, n_warmup=0, n_repeat=0, input_tensors=None, backend='event', quantiles=None, return_mode='mean', dynamic_symbolic_constraints=None)

Benchmarks the execution time of a given function.

Parameters:
  • func (Callable | None) -- Function to benchmark (uses adapter if None)

  • warmup (int) -- Warmup time in milliseconds

  • rep (int) -- Number of repetitions for timing

  • n_warmup (int) -- Number of warmup iterations

  • n_repeat (int) -- Number of timing iterations

  • backend (Literal['event', 'cupti', 'cudagraph']) -- Which profiling backend to use - "event", "cupti", or "cudagraph"

  • input_tensors (list[torch.Tensor]) -- Optional pre-generated input tensors

  • dynamic_symbolic_constraints (dict[str, int] | None) -- Optional dict mapping dynamic symbolic variable names to concrete int values. Use this when benchmarking kernels with dynamic shapes, e.g., {"m": 2048, "n": 1024}

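  • quantiles (list[float] | None) -- Optional quantiles of the measured times to report

  • return_mode (Literal['min', 'max', 'mean', 'median']) -- How the measured times are aggregated into the returned value; defaults to 'mean'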

Returns:

Execution time in milliseconds, aggregated according to return_mode (the mean by default)

Return type:

float
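The roles of the warmup iterations, timed iterations, and return_mode can be illustrated with a minimal CPU-only sketch. It uses time.perf_counter rather than CUDA events, so it is not how the 'event', 'cupti', or 'cudagraph' backends measure GPU kernels; the helpers aggregate and bench are hypothetical and only mirror the parameter names above.

```python
import statistics
import time

# Maps each return_mode value to the function that reduces the samples.
_AGGREGATORS = {
    "min": min,
    "max": max,
    "mean": statistics.fmean,
    "median": statistics.median,
}


def aggregate(times_ms, return_mode="mean"):
    """Reduce a list of per-run timings to a single number."""
    if return_mode not in _AGGREGATORS:
        raise ValueError(f"unknown return_mode: {return_mode!r}")
    return _AGGREGATORS[return_mode](times_ms)


def bench(fn, n_warmup=3, n_repeat=10, return_mode="mean"):
    """Time fn() n_repeat times after n_warmup untimed calls; result in ms."""
    if n_repeat < 1:
        raise ValueError("n_repeat must be at least 1")
    for _ in range(n_warmup):
        fn()  # untimed: lets caches and lazy initialization settle
    times_ms = []
    for _ in range(n_repeat):
        start = time.perf_counter()
        fn()
        times_ms.append((time.perf_counter() - start) * 1e3)
    return aggregate(times_ms, return_mode)
```

For example, bench(lambda: sum(range(10_000)), return_mode="median") reports the median per-call time, which is less sensitive to occasional slow runs than the mean.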

property func
__call__(*args, **kwds)
Parameters:
  • args (Any)

  • kwds (Any)

Return type:

Any