tilelang.tileop.gemm.gemm_base
==============================

.. py:module:: tilelang.tileop.gemm.gemm_base


Classes
-------

.. autoapisummary::

   tilelang.tileop.gemm.gemm_base.GemmBase


Module Contents
---------------

.. py:class:: GemmBase

   Base class for GEMM tile operators.

   Classifies the GEMM variant by the memory scopes of operands A and B
   (SS, SR, RS, TS, RR) and provides common property accessors for the
   underlying ``gemm_node`` IR node. In each variant name, the first letter
   gives the location of A and the second gives the location of B
   (S: shared memory, R: registers, T: tensor memory).


   .. py:attribute:: gemm_node
      :type: tvm.ir.base.Node


   .. py:method:: infer_layout(target, thread_nums)
      :abstractmethod:


   .. py:method:: lower(layout_map, target, thread_bounds, thread_var)
      :abstractmethod:


   .. py:method:: is_gemm_ss()

      Return ``True`` if both A and B are in shared memory (SS variant).


   .. py:method:: is_gemm_sr()

      Return ``True`` if A is in shared memory and B is in registers (SR variant).


   .. py:method:: is_gemm_rs()

      Return ``True`` if A is in registers and B is in shared memory (RS variant).


   .. py:method:: is_gemm_ts()

      Return ``True`` if A is in tensor memory and B is in shared memory (TS variant).


   .. py:method:: is_gemm_rr()

      Return ``True`` if both A and B are in registers (RR variant).


   .. py:property:: M
      :type: int


   .. py:property:: N
      :type: int


   .. py:property:: K
      :type: int


   .. py:property:: trans_A
      :type: bool


   .. py:property:: trans_B
      :type: bool


   .. py:property:: in_dtype
      :type: str

      Input data type for the multiplication.

      For the TS variant, A resides in tensor memory (TMEM) with the
      accumulator dtype, so the actual input dtype is derived from B.


   .. py:property:: accum_dtype
      :type: str


   .. py:property:: chunk
      :type: int


   .. py:property:: A
      :type: tvm.tir.Buffer


   .. py:property:: B
      :type: tvm.tir.Buffer


   .. py:property:: C
      :type: tvm.tir.Buffer


   .. py:property:: ARegion


   .. py:property:: BRegion


   .. py:property:: CRegion


   .. py:property:: stride_A
      :type: int


   .. py:property:: stride_B
      :type: int


   .. py:property:: offset_A
      :type: int


   .. py:property:: offset_B
      :type: int


   .. py:property:: clear_accum
      :type: tvm.ir.PrimExpr


   .. py:property:: k_pack
      :type: int


   .. py:property:: wg_wait
      :type: int


   .. py:property:: policy
      :type: tilelang.tileop.base.GemmWarpPolicy


   .. py:property:: mbarptr
      :type: tvm.ir.PrimExpr


   .. py:property:: mbar
      :type: tvm.tir.BufferLoad | None


   .. py:property:: C_coords


   .. py:method:: get_region_base_offsets(region)

      Get the base offset (start index) for each dimension from a ``BufferRegion``.

      For example, if ``region`` is ``A_shared[ko % 2, 0:128, 0:64]``, this
      returns ``[ko % 2, 0, 0]``.

      :param region: The ``BufferRegion`` object to inspect.
      :returns: A list of ``PrimExpr``, one base offset per dimension.


   .. py:property:: A_base_offsets

      Get the base offsets for each dimension of the A region.


   .. py:property:: B_base_offsets

      Get the base offsets for each dimension of the B region.


   .. py:property:: C_base_offsets

      Get the base offsets for each dimension of the C region.
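
The plain-Python sketch below illustrates three behaviours documented for
``GemmBase``: naming the variant from the memory scopes of A and B, taking
the input dtype from B in the TS variant, and returning the start index of
each dimension of a region, as ``get_region_base_offsets`` does. It is not
the tilelang implementation. The scope strings (``"shared"``,
``"shared.dyn"``, ``"local.fragment"``, ``"shared.tmem"``), the ``Range``
stand-in for a TVM range, and the helpers ``classify_gemm``,
``input_dtype`` and ``region_base_offsets`` are hypothetical and chosen
only for illustration.

.. code-block:: python

   from dataclasses import dataclass
   from typing import List, Sequence

   # Hypothetical scope strings; assumed, not read from the GemmBase source.
   SHARED_SCOPES = ("shared", "shared.dyn")
   REGISTER_SCOPE = "local.fragment"
   TMEM_SCOPE = "shared.tmem"

   SUPPORTED_VARIANTS = ("SS", "SR", "RS", "TS", "RR")


   def classify_gemm(a_scope: str, b_scope: str) -> str:
       """Return the variant name: first letter for A, second for B."""

       def kind(scope: str) -> str:
           if scope in SHARED_SCOPES:
               return "S"  # shared memory
           if scope == REGISTER_SCOPE:
               return "R"  # registers (fragment)
           if scope == TMEM_SCOPE:
               return "T"  # tensor memory
           raise ValueError(f"unsupported scope: {scope!r}")

       variant = kind(a_scope) + kind(b_scope)
       if variant not in SUPPORTED_VARIANTS:
           raise ValueError(f"unsupported GEMM variant: {variant}")
       return variant


   def input_dtype(variant: str, a_dtype: str, b_dtype: str) -> str:
       """In TS, A sits in TMEM with the accumulator dtype, so use B's dtype."""
       return b_dtype if variant == "TS" else a_dtype


   @dataclass
   class Range:
       """Hypothetical stand-in for a TVM range: a start index and an extent."""
       min: object
       extent: int


   def region_base_offsets(region: Sequence[Range]) -> List[object]:
       """Return the start index of every dimension of a region."""
       return [r.min for r in region]

With these helpers, ``classify_gemm("shared.dyn", "local.fragment")``
returns ``"SR"``, and the region ``A_shared[ko % 2, 0:128, 0:64]`` modelled
as ``[Range("ko % 2", 1), Range(0, 128), Range(0, 64)]`` yields
``["ko % 2", 0, 0]``. Combinations outside the five documented variants,
such as A in registers with B in tensor memory, raise ``ValueError`` in
this sketch.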