tilelang.utils.tensor
Profiler utilities and helpers for converting data to torch.
Classes
- TensorSupplyType: Create a collection of name/value pairs.
Functions
- torch_assert_close(tensor_a, tensor_b, ...): Custom function to assert that two tensors are "close enough," allowing a specified percentage of mismatched elements.
Module Contents
- tilelang.utils.tensor.fp8_remove_negative_zeros_(tensor)
- Parameters:
tensor (torch.Tensor)
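Only the signature of fp8_remove_negative_zeros_ is documented. As a hedged illustration of what removing negative zeros means at the bit level: in the FP8 E4M3 and E5M2 encodings, negative zero is the byte 0x80 (sign bit set, exponent and mantissa bits zero). The sketch below assumes the function rewrites that pattern as positive zero in place (the trailing underscore follows PyTorch's in-place naming convention) and models the tensor as a raw bytearray; `fp8_remove_negative_zeros_sketch` is a hypothetical name, not part of tilelang.

```python
# Hypothetical stand-in: assumes the real function rewrites -0.0 as +0.0 in place.
FP8_NEG_ZERO = 0x80  # sign bit set, exponent and mantissa bits all zero (E4M3 and E5M2)
FP8_POS_ZERO = 0x00


def fp8_remove_negative_zeros_sketch(buf: bytearray) -> bytearray:
    """Replace every FP8 negative-zero byte in `buf` with positive zero, in place."""
    for i, byte in enumerate(buf):
        if byte == FP8_NEG_ZERO:
            buf[i] = FP8_POS_ZERO  # same length, so mutating during enumerate is safe
    return buf  # returned for convenience, like PyTorch's in-place ops


# Usage: 0x38 encodes 1.0 in E4M3, so only the two 0x80 bytes change.
data = bytearray([0x80, 0x38, 0x00, 0x80])
fp8_remove_negative_zeros_sketch(data)
print(data.hex())  # 00380000
```

With a real torch float8 tensor, the same idea could be applied through a `tensor.view(torch.uint8)` reinterpretation of the storage; that is an assumption about the approach, not a description of tilelang's implementation.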
- class tilelang.utils.tensor.TensorSupplyType(*args, **kwds)
Bases: enum.Enum
Create a collection of name/value pairs.
Example enumeration:
>>> class Color(Enum):
...     RED = 1
...     BLUE = 2
...     GREEN = 3
Access them by:
- attribute access:
  >>> Color.RED
  <Color.RED: 1>
- value lookup:
  >>> Color(1)
  <Color.RED: 1>
- name lookup:
  >>> Color['RED']
  <Color.RED: 1>
Enumerations can be iterated over, and know how many members they have:
>>> len(Color)
3
>>> list(Color)
[<Color.RED: 1>, <Color.BLUE: 2>, <Color.GREEN: 3>]
Methods can be added to enumerations, and members can have their own attributes; see the documentation for details.
- Integer = 1
- Uniform = 2
- Normal = 3
- Randn = 4
- Zero = 5
- One = 6
- Auto = 7
- tilelang.utils.tensor.map_torch_type(intype)
- Parameters:
intype (str)
- Return type:
torch.dtype
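map_torch_type turns a dtype string into a torch.dtype. One common way to build such a mapping is attribute lookup on the torch module plus a small alias table for spellings that differ from torch's attribute names. The stdlib-only sketch below resolves just the attribute name; the alias entries (`e4m3_float8`, `e5m2_float8`) are assumptions about TVM-style spellings, and `torch_dtype_name` is a hypothetical helper, not tilelang's code.

```python
# Hypothetical alias table: source spellings that differ from torch attribute names.
_DTYPE_ALIASES = {
    "e4m3_float8": "float8_e4m3fn",
    "e5m2_float8": "float8_e5m2",
}


def torch_dtype_name(intype: str) -> str:
    """Return the name of the torch dtype attribute that corresponds to `intype`."""
    # Names without an alias are assumed to match torch directly (e.g. "float16").
    return _DTYPE_ALIASES.get(intype, intype)


print(torch_dtype_name("float16"))      # float16
print(torch_dtype_name("e4m3_float8"))  # float8_e4m3fn
```

With torch installed, `getattr(torch, torch_dtype_name("float16"))` yields `torch.float16`; keeping the alias table separate from the lookup makes it easy to add new spellings without touching the common path.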
- tilelang.utils.tensor.get_tensor_supply(supply_type=TensorSupplyType.Integer)
- Parameters:
supply_type (TensorSupplyType)
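Only the parameter of get_tensor_supply is documented. Judging by its name and the TensorSupplyType members, it selects a strategy for filling input tensors, for example when profiling a kernel. The sketch below is a hypothetical stdlib-only model of such a factory: `SupplyType` mirrors the listed members, `get_supply_sketch` returns a callable that fills a flat list for a given shape, the value ranges for Integer and Uniform are assumptions, Normal and Randn both draw from N(0, 1) because their difference is not documented here, and Auto (which presumably picks a strategy from the tensor's dtype) is not modeled.

```python
import math
import random
from enum import Enum
from typing import Callable, List, Sequence


class SupplyType(Enum):
    # Stand-in mirroring TensorSupplyType's members; not the real class.
    Integer = 1
    Uniform = 2
    Normal = 3
    Randn = 4
    Zero = 5
    One = 6
    Auto = 7


def get_supply_sketch(
    supply_type: SupplyType = SupplyType.Integer,
) -> Callable[[Sequence[int]], List[float]]:
    """Return a function that fills a flat list with math.prod(shape) values."""
    generators = {
        SupplyType.Integer: lambda: float(random.randint(-2, 2)),  # assumed range
        SupplyType.Uniform: lambda: random.uniform(-1.0, 1.0),     # assumed range
        SupplyType.Normal: lambda: random.gauss(0.0, 1.0),
        SupplyType.Randn: lambda: random.gauss(0.0, 1.0),
        SupplyType.Zero: lambda: 0.0,
        SupplyType.One: lambda: 1.0,
    }
    if supply_type not in generators:
        # Auto presumably depends on the dtype, which this sketch does not model.
        raise ValueError(f"{supply_type} is not modeled in this sketch")
    gen = generators[supply_type]

    def supply(shape: Sequence[int]) -> List[float]:
        return [gen() for _ in range(math.prod(shape))]

    return supply


zeros = get_supply_sketch(SupplyType.Zero)((2, 3))
print(zeros)  # [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```

Returning a closure keeps the strategy choice separate from the tensor shape, so one supplier can be reused for every input of a kernel.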
- tilelang.utils.tensor.torch_assert_close(tensor_a, tensor_b, rtol=0.01, atol=0.001, max_mismatched_ratio=0.001, verbose=False, equal_nan=True, check_device=True, check_dtype=True, check_layout=True, check_stride=False, base_name='LHS', ref_name='RHS')
Custom function to assert that two tensors are "close enough," allowing a specified percentage of mismatched elements.
Parameters:
- tensor_a : torch.Tensor
The first tensor to compare.
- tensor_b : torch.Tensor
The second tensor to compare.
- rtol : float, optional
Relative tolerance for comparison. Default is 1e-2.
- atol : float, optional
Absolute tolerance for comparison. Default is 1e-3.
- max_mismatched_ratio : float, optional
Maximum ratio of mismatched elements allowed (relative to the total number of elements). Default is 0.001 (0.1% of total elements).
Raises:
- AssertionError:
If the ratio of mismatched elements exceeds max_mismatched_ratio.
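The mismatch-ratio check described above can be sketched in plain Python over flat lists. This sketch assumes the per-element test follows the torch.isclose convention |a - b| <= atol + rtol * |b|, that infinities only match an identical infinity, and that a ratio exactly equal to max_mismatched_ratio still passes (the docstring says the error is raised only when the ratio exceeds it). `mismatch_ratio` and `assert_close_ratio` are hypothetical names; the check_*, verbose, and naming options are not modeled.

```python
import math
from typing import Sequence


def mismatch_ratio(a: Sequence[float], b: Sequence[float],
                   rtol: float = 0.01, atol: float = 0.001,
                   equal_nan: bool = True) -> float:
    """Fraction of positions where a[i] and b[i] are not close."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same number of elements")
    if len(a) == 0:
        return 0.0
    mismatched = 0
    for x, y in zip(a, b):
        if math.isnan(x) or math.isnan(y):
            close = equal_nan and math.isnan(x) and math.isnan(y)
        elif math.isinf(x) or math.isinf(y):
            close = x == y  # an infinity only matches the same infinity
        else:
            close = abs(x - y) <= atol + rtol * abs(y)
        if not close:
            mismatched += 1
    return mismatched / len(a)


def assert_close_ratio(a, b, rtol=0.01, atol=0.001,
                       max_mismatched_ratio=0.001, equal_nan=True) -> None:
    """Raise AssertionError if too large a fraction of elements mismatch."""
    ratio = mismatch_ratio(a, b, rtol, atol, equal_nan)
    if ratio > max_mismatched_ratio:
        raise AssertionError(
            f"mismatch ratio {ratio:.4%} exceeds {max_mismatched_ratio:.4%}")


ref = [1.0] * 1000
out = [1.0] * 999 + [2.0]  # one element off: ratio 0.001, still allowed
assert_close_ratio(out, ref)
print(mismatch_ratio(out, ref))  # 0.001
```

Counting mismatches instead of failing on the first one lets low-precision kernels (for example FP8 or FP16 outputs) pass when a tiny fraction of elements round differently, while still catching systematic errors.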