tilelang.utils.tensor
Utilities for the profiler and for conversion to torch.
Classes
TensorSupplyType | Create a collection of name/value pairs. |
Functions
torch_assert_close | Custom function to assert that two tensors are "close enough," allowing a specified percentage of mismatched elements. |
Module Contents
- tilelang.utils.tensor.fp8_remove_negative_zeros_(tensor)
- Parameters:
tensor (torch.Tensor)
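The page documents only the parameter. Judging from the name, and assuming the trailing underscore follows PyTorch's convention for in-place operations, the function likely rewrites any negative zeros (-0.0) in an FP8 tensor as positive zeros. Negative zero compares equal to positive zero but has a different sign bit, so it can only be detected by inspecting the sign. The stdlib sketch below illustrates that idea on a plain Python list; `remove_negative_zeros_` is a hypothetical helper, not the tilelang implementation.

```python
import math


def remove_negative_zeros_(values: list[float]) -> list[float]:
    """Hypothetical in-place sketch: replace every -0.0 with +0.0."""
    for i, v in enumerate(values):
        # -0.0 == 0.0 is True, so the sign bit must be checked explicitly.
        if v == 0.0 and math.copysign(1.0, v) < 0:
            values[i] = 0.0
    # Returned for convenience; the list itself has already been modified.
    return values


data = [-0.0, 1.5, 0.0, -2.0]
remove_negative_zeros_(data)
print(data)  # the leading -0.0 is now 0.0; -2.0 is untouched
```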
- class tilelang.utils.tensor.TensorSupplyType(*args, **kwds)
Bases: enum.Enum
Create a collection of name/value pairs.
Example enumeration:
>>> class Color(Enum):
...     RED = 1
...     BLUE = 2
...     GREEN = 3
Access them by:
attribute access:
>>> Color.RED
<Color.RED: 1>
value lookup:
>>> Color(1)
<Color.RED: 1>
name lookup:
>>> Color['RED']
<Color.RED: 1>
Enumerations can be iterated over, and know how many members they have:
>>> len(Color)
3
>>> list(Color)
[<Color.RED: 1>, <Color.BLUE: 2>, <Color.GREEN: 3>]
Methods can be added to enumerations, and members can have their own attributes -- see the documentation for details.
- Integer = 1
- Uniform = 2
- Normal = 3
- Randn = 4
- Zero = 5
- One = 6
- Auto = 7
- tilelang.utils.tensor.map_torch_type(intype)
- Return type:
torch.dtype
- tilelang.utils.tensor.get_tensor_supply(supply_type=TensorSupplyType.Integer)
- Parameters:
supply_type (TensorSupplyType)
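The page does not state what get_tensor_supply returns. One plausible reading, stated here only as an assumption, is that it returns a callable producing input data according to the chosen TensorSupplyType. The stdlib sketch below mirrors the member values listed above and dispatches on them. `make_supply`, its flat-list output, and the value ranges it uses are hypothetical and arbitrary; it treats Normal and Randn identically, which may not match tilelang, and leaves Auto unimplemented because its behavior is not described here.

```python
import enum
import random
from typing import Callable


class TensorSupplyType(enum.Enum):
    # Local mirror of the documented member values (not imported from tilelang).
    Integer = 1
    Uniform = 2
    Normal = 3
    Randn = 4
    Zero = 5
    One = 6
    Auto = 7


def make_supply(
    supply_type: TensorSupplyType = TensorSupplyType.Integer,
) -> Callable[[int], list]:
    """Hypothetical dispatcher: return a function that yields `n` values."""
    if supply_type is TensorSupplyType.Integer:
        # Small random integers; the [-2, 2] range is an arbitrary choice.
        return lambda n: [random.randint(-2, 2) for _ in range(n)]
    if supply_type is TensorSupplyType.Uniform:
        # Uniform floats; the [-1, 1] range is an arbitrary choice.
        return lambda n: [random.uniform(-1.0, 1.0) for _ in range(n)]
    if supply_type in (TensorSupplyType.Normal, TensorSupplyType.Randn):
        # Standard normal samples for both members in this sketch.
        return lambda n: [random.gauss(0.0, 1.0) for _ in range(n)]
    if supply_type is TensorSupplyType.Zero:
        return lambda n: [0] * n
    if supply_type is TensorSupplyType.One:
        return lambda n: [1] * n
    raise NotImplementedError(f"{supply_type.name} is not modeled in this sketch")


fill = make_supply(TensorSupplyType.One)
print(fill(4))
```

Returning a callable rather than data keeps the strategy choice separate from the shape of each input, so the same supplier can be reused for several buffers.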
- tilelang.utils.tensor.torch_assert_close(tensor_a, tensor_b, rtol=0.01, atol=0.001, max_mismatched_ratio=0.001, verbose=False, equal_nan=True, check_device=True, check_dtype=True, check_layout=True, check_stride=False, base_name='LHS', ref_name='RHS')
Custom function to assert that two tensors are "close enough," allowing a specified percentage of mismatched elements.
Parameters:
- tensor_a : torch.Tensor
The first tensor to compare.
- tensor_b : torch.Tensor
The second tensor to compare.
- rtol : float, optional
Relative tolerance for comparison. Default is 1e-2.
- atol : float, optional
Absolute tolerance for comparison. Default is 1e-3.
- max_mismatched_ratio : float, optional
Maximum ratio of mismatched elements allowed (relative to the total number of elements). Default is 0.001 (0.1% of total elements).
Raises:
- AssertionError:
If the ratio of mismatched elements exceeds max_mismatched_ratio.
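The check described above can be sketched in pure Python, assuming elementwise closeness follows the torch.isclose rule |a - b| <= atol + rtol * |b| and that the number of tolerated mismatches is max_mismatched_ratio times the element count. `assert_mostly_close` is a hypothetical name that works on flat lists of floats rather than tensors, and it does not model check_device, check_dtype, check_layout, check_stride, verbose, or the base_name/ref_name labels.

```python
import math


def assert_mostly_close(a, b, rtol=0.01, atol=0.001,
                        max_mismatched_ratio=0.001, equal_nan=True):
    """Sketch of a ratio-tolerant closeness check on flat sequences of floats."""
    if len(a) != len(b):
        raise AssertionError(f"length mismatch: {len(a)} vs {len(b)}")
    mismatched = 0
    for x, y in zip(a, b):
        if math.isnan(x) or math.isnan(y):
            # NaNs match only each other, and only when equal_nan is set.
            close = equal_nan and math.isnan(x) and math.isnan(y)
        else:
            # Equal values (including matching infinities) are always close;
            # otherwise apply the torch.isclose rule |x - y| <= atol + rtol * |y|.
            close = x == y or abs(x - y) <= atol + rtol * abs(y)
        if not close:
            mismatched += 1
    allowed = max_mismatched_ratio * len(a)
    if mismatched > allowed:
        raise AssertionError(
            f"{mismatched}/{len(a)} elements mismatched "
            f"(ratio {mismatched / len(a):.4%} > {max_mismatched_ratio:.4%})"
        )
    return mismatched


ref = [1.0] * 1000
out = list(ref)
out[0] = 1.5  # one element far outside tolerance
assert_mostly_close(out, ref)  # passes: 1 mismatch <= 0.001 * 1000
```

With 1,000 elements and the default ratio of 0.001, the budget is 0.001 × 1000 = 1 element: a single mismatch passes, while a second one raises AssertionError.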