tilelang.utils.tensor

Utilities used by the profiler, plus helpers for converting to PyTorch (torch) types.

Classes

TensorSupplyType

Enumeration of the values accepted by the supply_type parameter of get_tensor_supply.

Functions

is_float8_dtype(dtype)

fp8_remove_negative_zeros_(tensor)

map_torch_type(intype)

get_tensor_supply([supply_type])

torch_assert_close(tensor_a, tensor_b[, rtol, atol, ...])

Custom function to assert that two tensors are "close enough," allowing a specified fraction of mismatched elements.

Module Contents

tilelang.utils.tensor.is_float8_dtype(dtype)
Parameters:

dtype (torch.dtype)

Return type:

bool
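A minimal sketch of the check, assuming that "FP8" means the four float8 dtypes PyTorch exposes (`float8_e4m3fn`, `float8_e4m3fnuz`, `float8_e5m2`, `float8_e5m2fnuz`); the exact set tilelang tests against is not stated in this reference. To run without PyTorch installed, the sketch compares dtype names as printed by `str(dtype)` (for example `"torch.float8_e4m3fn"`); `is_float8_dtype_name` is a hypothetical helper.

```python
# Assumption: these four PyTorch float8 dtypes are the ones treated as FP8.
FLOAT8_DTYPE_NAMES = frozenset({
    "float8_e4m3fn",
    "float8_e4m3fnuz",
    "float8_e5m2",
    "float8_e5m2fnuz",
})


def is_float8_dtype_name(dtype_name: str) -> bool:
    """Hypothetical helper: True if dtype_name (with or without 'torch.') is an FP8 dtype."""
    name = dtype_name.removeprefix("torch.")  # requires Python 3.9+
    return name in FLOAT8_DTYPE_NAMES


print(is_float8_dtype_name("torch.float8_e5m2"))  # True
print(is_float8_dtype_name("torch.float16"))      # False
```

With PyTorch available, the same idea is a membership test on dtype objects, e.g. `dtype in {torch.float8_e4m3fn, torch.float8_e5m2}`.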

tilelang.utils.tensor.fp8_remove_negative_zeros_(tensor)
Parameters:

tensor (torch.Tensor)
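The trailing underscore follows the PyTorch convention for functions that modify their argument in place. In the OCP FP8 formats (`float8_e4m3fn` and `float8_e5m2`), negative zero is the bit pattern `0x80`: sign bit set, exponent and mantissa bits all zero. The sketch below illustrates the idea by rewriting that pattern to `0x00` (+0) in a `bytearray` of raw FP8 bytes; it is not tilelang's implementation, and `remove_negative_zeros_` is a hypothetical name. In the `fnuz` variants `0x80` encodes NaN rather than -0, so the sketch does not apply to them.

```python
NEG_ZERO_FP8 = 0x80  # sign bit set, exponent and mantissa zero (e4m3fn / e5m2)
POS_ZERO_FP8 = 0x00


def remove_negative_zeros_(raw: bytearray) -> None:
    """Rewrite every FP8 negative zero in `raw` to positive zero, in place."""
    for i, byte in enumerate(raw):
        if byte == NEG_ZERO_FP8:
            raw[i] = POS_ZERO_FP8  # same length, so in-place assignment is safe


# e4m3fn bytes: -0, +0, 1.0 (0x38), -1.0 (0xB8), -0
data = bytearray([0x80, 0x00, 0x38, 0xB8, 0x80])
remove_negative_zeros_(data)
print(data.hex())  # 000038b800
```

For a PyTorch tensor, the same raw bytes are reachable through `tensor.view(torch.uint8)`, which reinterprets the storage without copying.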

class tilelang.utils.tensor.TensorSupplyType(*args, **kwds)

Bases: enum.Enum

Enumeration of the values accepted by the supply_type parameter of get_tensor_supply.

Integer = 1
Uniform = 2
Normal = 3
Randn = 4
Zero = 5
One = 6
Auto = 7
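Because TensorSupplyType is a plain `enum.Enum`, a member can be selected by name (for example from a command-line flag or config file) or by its integer value. The sketch below redeclares the documented members locally so it runs without tilelang; `parse_supply_type` is a hypothetical helper that adds case-insensitive name lookup on top of the standard `Enum` indexing.

```python
from enum import Enum


class TensorSupplyType(Enum):
    # Local mirror of the members documented above.
    Integer = 1
    Uniform = 2
    Normal = 3
    Randn = 4
    Zero = 5
    One = 6
    Auto = 7


def parse_supply_type(text: str) -> TensorSupplyType:
    """Hypothetical helper: map 'normal', 'ZERO', ... to a member, ignoring case."""
    wanted = text.strip().lower()
    for member in TensorSupplyType:
        if member.name.lower() == wanted:
            return member
    valid = ", ".join(m.name for m in TensorSupplyType)
    raise ValueError(f"unknown supply type {text!r}; expected one of: {valid}")


print(TensorSupplyType["Randn"])      # TensorSupplyType.Randn  (lookup by name)
print(TensorSupplyType(7))            # TensorSupplyType.Auto   (lookup by value)
print(parse_supply_type(" normal "))  # TensorSupplyType.Normal
```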

tilelang.utils.tensor.map_torch_type(intype)
Parameters:

intype (str)

Return type:

torch.dtype
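A sketch of the mapping under two assumptions not stated in this reference: intype is a TVM-style dtype string such as `"float16"`, `"bfloat16"` or `"int8"`, and FP8 types arrive under TVM's older spellings `"e4m3_float8"` / `"e5m2_float8"`, which do not match PyTorch attribute names. So that it runs without PyTorch, the hypothetical `torch_dtype_name` returns the attribute name; real code would pass it to `getattr(torch, ...)` to obtain the `torch.dtype`.

```python
# Assumed aliases for dtype strings whose spelling differs from PyTorch's.
_SPECIAL_CASES = {
    "e4m3_float8": "float8_e4m3fn",
    "e5m2_float8": "float8_e5m2",
}


def torch_dtype_name(intype: str) -> str:
    """Hypothetical helper: torch attribute name for a TVM-style dtype string."""
    # Strings such as "float16" or "int32" are assumed to match PyTorch names already.
    return _SPECIAL_CASES.get(intype, intype)


# With PyTorch installed: dtype = getattr(torch, torch_dtype_name("e4m3_float8"))
print(torch_dtype_name("e4m3_float8"))  # float8_e4m3fn
print(torch_dtype_name("bfloat16"))     # bfloat16
```

Falling back to the input string keeps the table small: only names that differ between the two ecosystems need an explicit entry.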

tilelang.utils.tensor.get_tensor_supply(supply_type=TensorSupplyType.Integer)
Parameters:

supply_type (TensorSupplyType)
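The return type is not documented here. Assuming get_tensor_supply is a factory that returns a callable which produces input data on demand, the pattern can be sketched with a closure. The helper `make_supply`, its shape-based signature, and the value ranges for each member are hypothetical placeholders, not tilelang's actual distributions (Normal and Randn both draw from a standard normal here only for illustration, and Auto is not modeled because its behavior is not described). Flat Python lists stand in for `torch.Tensor`s so only the standard library is needed; the default argument mirrors the documented default, `TensorSupplyType.Integer`.

```python
import math
import random
from enum import Enum
from typing import Callable, List, Sequence


class TensorSupplyType(Enum):
    # Local mirror of the documented members.
    Integer = 1
    Uniform = 2
    Normal = 3
    Randn = 4
    Zero = 5
    One = 6
    Auto = 7


Supply = Callable[[Sequence[int]], List[float]]


def make_supply(supply_type: TensorSupplyType = TensorSupplyType.Integer,
                seed: int = 0) -> Supply:
    """Hypothetical factory: return a function that fills a flat buffer of a given shape."""
    rng = random.Random(seed)  # RNG state is captured by the closure

    def supply(shape: Sequence[int]) -> List[float]:
        n = math.prod(shape)  # number of elements
        if supply_type is TensorSupplyType.Zero:
            return [0.0] * n
        if supply_type is TensorSupplyType.One:
            return [1.0] * n
        if supply_type is TensorSupplyType.Integer:
            # Placeholder range: small integers stored as floats.
            return [float(rng.randint(-2, 2)) for _ in range(n)]
        if supply_type is TensorSupplyType.Uniform:
            return [rng.uniform(-1.0, 1.0) for _ in range(n)]
        if supply_type in (TensorSupplyType.Normal, TensorSupplyType.Randn):
            return [rng.gauss(0.0, 1.0) for _ in range(n)]
        raise NotImplementedError(f"{supply_type.name} is not modeled in this sketch")

    return supply


zeros = make_supply(TensorSupplyType.Zero)
print(zeros((2, 3)))  # [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```

Binding the supply type once means callers can generate fresh inputs repeatedly without passing the configuration again.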

tilelang.utils.tensor.torch_assert_close(tensor_a, tensor_b, rtol=0.01, atol=0.001, max_mismatched_ratio=0.001, verbose=False, equal_nan=True, check_device=True, check_dtype=True, check_layout=True, check_stride=False, base_name='LHS', ref_name='RHS')

Custom function to assert that two tensors are "close enough," allowing a specified fraction of mismatched elements.

Parameters:

tensor_a : torch.Tensor

The first tensor to compare.

tensor_b : torch.Tensor

The second tensor to compare.

rtol : float, optional

Relative tolerance for comparison. Default is 1e-2.

atol : float, optional

Absolute tolerance for comparison. Default is 1e-3.

max_mismatched_ratio : float, optional

Maximum ratio of mismatched elements allowed (relative to the total number of elements). Default is 0.001 (0.1% of total elements).

Raises:

AssertionError:

If the ratio of mismatched elements exceeds max_mismatched_ratio.

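Additional parameters (not described in the docstring; defaults taken from the signature):
  • verbose (bool, default False)

  • equal_nan (bool, default True)

  • check_device (bool, default True)

  • check_dtype (bool, default True)

  • check_layout (bool, default True)

  • check_stride (bool, default False)

  • base_name (str, default 'LHS')

  • ref_name (str, default 'RHS')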
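The acceptance rule can be sketched in plain Python. The per-element test assumes the same formula as `torch.isclose`, |a - b| <= atol + rtol * |b|, with exactly equal values always close and NaNs in matching positions counted as equal when equal_nan is true; the docstring states only the tolerances and the mismatch-ratio rule, so treat the formula as an assumption. `assert_mostly_close` is a hypothetical name, flat lists stand in for tensors, and "exceeds" is read as strictly greater than max_mismatched_ratio.

```python
import math
from typing import Sequence


def assert_mostly_close(a: Sequence[float], b: Sequence[float],
                        rtol: float = 0.01, atol: float = 0.001,
                        max_mismatched_ratio: float = 0.001,
                        equal_nan: bool = True) -> float:
    """Raise AssertionError if too many elements differ; return the mismatch ratio."""
    if len(a) != len(b):
        raise AssertionError(f"length mismatch: {len(a)} vs {len(b)}")
    mismatched = 0
    for x, y in zip(a, b):
        if math.isnan(x) or math.isnan(y):
            close = equal_nan and math.isnan(x) and math.isnan(y)
        else:
            # Assumed isclose-style test; rtol scales |y|, the second (reference) value.
            close = x == y or abs(x - y) <= atol + rtol * abs(y)
        if not close:
            mismatched += 1
    ratio = mismatched / len(a) if len(a) else 0.0
    if ratio > max_mismatched_ratio:  # "exceeds" = strictly greater
        raise AssertionError(
            f"{mismatched}/{len(a)} elements mismatched "
            f"(ratio {ratio:.4g} > {max_mismatched_ratio})")
    return ratio


ref = [1.0] * 1000
out = ref.copy()
out[0] = 2.0                          # 1 of 1000 elements off: ratio 0.001, allowed
print(assert_mostly_close(out, ref))  # 0.001
```

A second mismatched element would push the ratio to 0.002 and raise, which is the behavior the Raises entry describes for the default max_mismatched_ratio.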