tilelang.intrinsics.wmma_layout¶

Layout functions for AMD RDNA WMMA instructions (gfx11/gfx12).

EMPIRICALLY VERIFIED hardware layouts for wmma_f32_16x16x16_f16_w32_gfx12:

A[M=16][K=16]:

thread t, elem e -> A[M=t%16][K=(t//16)*8+e]
Forward: (M, K) -> (thread=(K//8)*16+M, local=K%8)
Reverse: (thread, local) -> (M=thread%16, K=(thread//16)*8+local)
Memory load: A[M=t%16][K=(t//16)*8..+7] -> CONTIGUOUS in K (vectorized)
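As a quick sanity check, the gfx12 A formulas can be restated in plain Python. This is a minimal sketch that copies the arithmetic above; the helper names `a_forward_gfx12` and `a_reverse_gfx12` are illustrative and are not the module's own functions.

```python
# Standalone restatement of the gfx12 A formulas; not the tilelang source.
def a_forward_gfx12(m, k):
    """(M, K) -> (thread, local)."""
    return (k // 8) * 16 + m, k % 8


def a_reverse_gfx12(thread, local):
    """(thread, local) -> (M, K)."""
    return thread % 16, (thread // 16) * 8 + local


# Every (M, K) element should round-trip through (thread, local).
round_trip_ok = all(
    a_reverse_gfx12(*a_forward_gfx12(m, k)) == (m, k)
    for m in range(16)
    for k in range(16)
)

# The 256 elements should occupy 32 lanes x 8 slots without collisions.
slots = {a_forward_gfx12(m, k) for m in range(16) for k in range(16)}

print(round_trip_ok, len(slots))  # True 256
```

Each lane owns one row M and a run of 8 consecutive K values, which is why the load is contiguous in K.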

B[K=16][N=16] (non-transposed, K x N storage):

thread t, elem e -> B[K=(t//16)*8+e][N=t%16]
Forward: (K, N) -> (thread=(K//8)*16+N, local=K%8)
Reverse: (thread, local) -> (K=(thread//16)*8+local, N=thread%16)

B_T[N=16][K=16] (transposed storage of B):

B_T[N=t%16][K=(t//16)*8+e] -> CONTIGUOUS in K (vectorized)

D[M=16][N=16]:

thread t, elem l -> D[M=(t//16)*8+l][N=t%16]
Forward: (M, N) -> (thread=(M//8)*16+N, local=M%8)
Reverse: (thread, local) -> (M=(thread//16)*8+local, N=thread%16)
Store: D[M=(t//16)*8+l][N=t%16] = d_vec[l]
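The store rule can be simulated on the host to see where each accumulator element lands. The sketch below is an illustration only: `d_vec` holds made-up per-lane values, and `d_reverse_gfx12` is a local helper that restates the Reverse formula above.

```python
# Standalone restatement of the gfx12 D reverse map; not the tilelang source.
def d_reverse_gfx12(thread, local):
    """(thread, local) -> (M, N)."""
    return (thread // 16) * 8 + local, thread % 16


# Hypothetical accumulators: lane t holds 8 values tagged t*8 + l.
d_vec = [[t * 8 + l for l in range(8)] for t in range(32)]

# Scatter every lane's 8 values into a 16x16 tile using the store rule.
D = [[None] * 16 for _ in range(16)]
for t in range(32):
    for l in range(8):
        m, n = d_reverse_gfx12(t, l)
        D[m][n] = d_vec[t][l]

# Lane 19 (upper half-wave, column 3) writes rows 8..15 of column 3.
print(D[8][3], D[15][3])  # 152 159
```

Lanes 0~15 fill rows 0~7 and lanes 16~31 fill rows 8~15, one column per lane.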

EMPIRICALLY VERIFIED hardware layouts for wmma_f32_16x16x16_f16_w32 (gfx11):

A[M=16][K=16]:

thread t, elem e -> A[M=t%16][K=e]
Forward: (M, K) -> (thread=M, local=K%16) [Mapping to tid=0~15]
Reverse: (thread, local) -> (M=thread%16, K=local)
Memory load: A[M=t%16][K=0..+15] -> CONTIGUOUS in K (vectorized)

B[K=16][N=16] (non-transposed, K x N storage):

thread t, elem e -> B[K=e][N=t%16]
Forward: (K, N) -> (thread=N, local=K%16) [Mapping to tid=0~15]
Reverse: (thread, local) -> (K=local, N=thread%16)

B_T[N=16][K=16] (transposed storage of B):

B_T[N=t%16][K=e] -> CONTIGUOUS in K (vectorized)

D[M=16][N=16]:

thread t, elem l -> D[M=(t//16)+l*2][N=t%16]
Forward: (M, N) -> (thread=(M%2)*16+N, local=M//2)
Reverse: (thread, local) -> (M=(thread//16)+local*2, N=thread%16)
Store: D[M=(t//16)+l*2][N=t%16] = d_vec[l]
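Unlike gfx12, the gfx11 D layout interleaves rows between the two half-waves. A small sketch restating the formulas above makes this visible; the helper names are illustrative, not the module's own.

```python
# Standalone restatement of the gfx11 D formulas; not the tilelang source.
def d_forward_gfx11(m, n):
    """(M, N) -> (thread, local)."""
    return (m % 2) * 16 + n, m // 2


def d_reverse_gfx11(thread, local):
    """(thread, local) -> (M, N)."""
    return thread // 16 + local * 2, thread % 16


# Rows held by lane 0 (lower half-wave) and lane 16 (upper half-wave).
rows_lane0 = [d_reverse_gfx11(0, l)[0] for l in range(8)]
rows_lane16 = [d_reverse_gfx11(16, l)[0] for l in range(8)]

round_trip_ok = all(
    d_reverse_gfx11(*d_forward_gfx11(m, n)) == (m, n)
    for m in range(16)
    for n in range(16)
)

print(rows_lane0)   # [0, 2, 4, 6, 8, 10, 12, 14]
print(rows_lane16)  # [1, 3, 5, 7, 9, 11, 13, 15]
```

Lower-half lanes hold the even rows and upper-half lanes hold the odd rows of their column.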

NOTE:

  1. A and D have DIFFERENT layouts (e.g. for gfx12, A uses t%16 for M, while D uses (t//16)*8+l for M). This means they cannot be used interchangeably without a layout change.

  2. For gfx11, lanes 16~31 share the same A/B data as lanes 0~15.

local_size = 8 (gfx12) | 16 (gfx11) for the A/B operands. Per the D formulas above, the D fragment holds 8 elements per lane on both targets.
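Both points in the note can be checked numerically. The sketch below restates the reverse maps from this page as local helpers (illustrative names, not the module's functions) and inspects a single lane.

```python
# Standalone restatements of the reverse maps on this page; not the tilelang source.
def a_reverse_gfx12(thread, local):
    return thread % 16, (thread // 16) * 8 + local


def d_reverse_gfx12(thread, local):
    return (thread // 16) * 8 + local, thread % 16


def a_reverse_gfx11(thread, local):
    return thread % 16, local


# Point 1: on gfx12, lane 5 holds a single A row but eight different D rows.
lane = 5
a_rows = {a_reverse_gfx12(lane, e)[0] for e in range(8)}
d_rows = {d_reverse_gfx12(lane, l)[0] for l in range(8)}
print(sorted(a_rows), sorted(d_rows))  # [5] [0, 1, 2, 3, 4, 5, 6, 7]

# Point 2: on gfx11, lane t and lane t+16 address the same A elements.
halves_match = all(
    a_reverse_gfx11(t, e) == a_reverse_gfx11(t + 16, e)
    for t in range(16)
    for e in range(16)
)
print(halves_match)  # True
```

Because a lane's A registers cover one row while its D registers cover one column segment, a D result cannot be fed back as an A operand without going through a layout change.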

Functions¶

shared_16x16_to_local_32x8_layout_A_gfx12(i, j)

Forward: A[i=M, j=K] -> (thread=(j//8)*16+i, local=j%8).

thread_id_shared_access_32x8_to_16x16_layout_A_gfx12(...)

Reverse: (thread, local) -> (i=M=thread%16, j=K=(thread//16)*8+local).

shared_16x16_to_local_32x8_layout_A_colmajor_gfx12(i, j)

Forward: A_T[i=K, j=M] -> (thread=(i//8)*16+j, local=i%8).

thread_id_shared_access_32x8_to_16x16_layout_A_colmajor_gfx12(...)

Reverse: (thread, local) -> (i=K=(thread//16)*8+local, j=M=thread%16).

shared_16x16_to_local_32x8_layout_B_gfx12(i, j)

Forward: B[i=K, j=N] -> (thread=(i//8)*16+j, local=i%8).

thread_id_shared_access_32x8_to_16x16_layout_B_gfx12(...)

Reverse: (thread, local) -> (i=K=(thread//16)*8+local, j=N=thread%16).

shared_16x16_to_local_32x8_layout_B_colmajor_gfx12(i, j)

Forward: B_T[i=N, j=K] -> (thread=(j//8)*16+i, local=j%8).

thread_id_shared_access_32x8_to_16x16_layout_B_colmajor_gfx12(...)

Reverse: (thread, local) -> (i=N=thread%16, j=K=(thread//16)*8+local).

shared_16x16_to_local_32x8_layout_C_gfx12(i, j)

Forward: D[i=M, j=N] -> (thread=(i//8)*16+j, local=i%8).

thread_id_shared_access_32x8_to_16x16_layout_C_gfx12(...)

Reverse: (thread, local) -> (i=M=(thread//16)*8+local, j=N=thread%16).

wmma_store_index_map_gfx12(thread_id, local_id)

(thread, local) -> (M, N) in D. Hardware D layout.

shared_16x16_to_local_32x16_layout_A_gfx11(i, j)

Forward: A[i=M, j=K] -> (thread=i, local=j%16).

thread_id_shared_access_32x16_to_16x16_layout_A_gfx11(...)

Reverse: (thread, local) -> (i=M=thread%16, j=K=local)

shared_16x16_to_local_32x16_layout_A_colmajor_gfx11(i, j)

Forward: A_T[i=K, j=M] -> (thread=j, local=i%16).

thread_id_shared_access_32x16_to_16x16_layout_A_colmajor_gfx11(...)

Reverse: (thread, local) -> (i=K=local, j=M=thread%16)

shared_16x16_to_local_32x16_layout_B_gfx11(i, j)

Forward: B[i=K, j=N] -> (thread=j, local=i%16).

thread_id_shared_access_32x16_to_16x16_layout_B_gfx11(...)

Reverse: (thread, local) -> (i=K=local, j=N=thread%16)

shared_16x16_to_local_32x16_layout_B_colmajor_gfx11(i, j)

Forward: B_T[i=N, j=K] -> (thread=i, local=j%16).

thread_id_shared_access_32x16_to_16x16_layout_B_colmajor_gfx11(...)

Reverse: (thread, local) -> (i=N=thread%16, j=K=local)

shared_16x16_to_local_32x8_layout_C_gfx11(i, j)

Forward: D[i=M, j=N] -> (thread=(i%2)*16+j, local=i//2).

thread_id_shared_access_32x8_to_16x16_layout_C_gfx11(...)

Reverse: (thread, local) -> (i=M=(thread//16)+local*2, j=N=thread%16)

wmma_store_index_map_gfx11(thread_id, local_id)

(thread, local) -> (M, N) in D. Hardware D layout.

fragment_forward_A_gfx11(i, j, rep)

Replicated fragment forward map for gfx11 A.

fragment_forward_A_colmajor_gfx11(i, j, rep)

Replicated fragment forward map for gfx11 transposed A.

fragment_forward_B_gfx11(i, j, rep)

Replicated fragment forward map for gfx11 B.

fragment_forward_B_colmajor_gfx11(i, j, rep)

Replicated fragment forward map for gfx11 transposed B.

get_wmma_a_layout_funcs(rdna_gen, transposed)

Return (forward_map, reverse_map) for A layout.

get_wmma_b_layout_funcs(rdna_gen, transposed)

Return (forward_map, reverse_map) for B layout.

get_wmma_c_layout_funcs(rdna_gen)

Return (forward_map, reverse_map) for C/D layout.

get_wmma_store_index_map_func(rdna_gen)

Return the (thread_id, local_id) -> (row, col) store map.

get_wmma_a_fragment_forward_func(rdna_gen, transposed)

Return the fragment forward function for A layout.

get_wmma_b_fragment_forward_func(rdna_gen, transposed)

Return the fragment forward function for B layout.

get_wmma_fragment_replicate_count(rdna_gen)

Return the fragment replicate count used for logical one-to-many owners.

Module Contents¶

tilelang.intrinsics.wmma_layout.shared_16x16_to_local_32x8_layout_A_gfx12(i, j)¶

Forward: A[i=M, j=K] -> (thread=(j//8)*16+i, local=j%8).

tilelang.intrinsics.wmma_layout.thread_id_shared_access_32x8_to_16x16_layout_A_gfx12(thread_id, local_id)¶

Reverse: (thread, local) -> (i=M=thread%16, j=K=(thread//16)*8+local).

tilelang.intrinsics.wmma_layout.shared_16x16_to_local_32x8_layout_A_colmajor_gfx12(i, j)¶

Forward: A_T[i=K, j=M] -> (thread=(i//8)*16+j, local=i%8).

tilelang.intrinsics.wmma_layout.thread_id_shared_access_32x8_to_16x16_layout_A_colmajor_gfx12(thread_id, local_id)¶

Reverse: (thread, local) -> (i=K=(thread//16)*8+local, j=M=thread%16).

tilelang.intrinsics.wmma_layout.shared_16x16_to_local_32x8_layout_B_gfx12(i, j)¶

Forward: B[i=K, j=N] -> (thread=(i//8)*16+j, local=i%8).

tilelang.intrinsics.wmma_layout.thread_id_shared_access_32x8_to_16x16_layout_B_gfx12(thread_id, local_id)¶

Reverse: (thread, local) -> (i=K=(thread//16)*8+local, j=N=thread%16).

tilelang.intrinsics.wmma_layout.shared_16x16_to_local_32x8_layout_B_colmajor_gfx12(i, j)¶

Forward: B_T[i=N, j=K] -> (thread=(j//8)*16+i, local=j%8).

tilelang.intrinsics.wmma_layout.thread_id_shared_access_32x8_to_16x16_layout_B_colmajor_gfx12(thread_id, local_id)¶

Reverse: (thread, local) -> (i=N=thread%16, j=K=(thread//16)*8+local).

tilelang.intrinsics.wmma_layout.shared_16x16_to_local_32x8_layout_C_gfx12(i, j)¶

Forward: D[i=M, j=N] -> (thread=(i//8)*16+j, local=i%8).

tilelang.intrinsics.wmma_layout.thread_id_shared_access_32x8_to_16x16_layout_C_gfx12(thread_id, local_id)¶

Reverse: (thread, local) -> (i=M=(thread//16)*8+local, j=N=thread%16).

tilelang.intrinsics.wmma_layout.wmma_store_index_map_gfx12(thread_id, local_id)¶

(thread, local) -> (M, N) in D. Hardware D layout.

tilelang.intrinsics.wmma_layout.shared_16x16_to_local_32x16_layout_A_gfx11(i, j)¶

Forward: A[i=M, j=K] -> (thread=i, local=j%16). ATTN: This maps (i, j) only onto the lower half of the wave (lanes 0~15).

tilelang.intrinsics.wmma_layout.thread_id_shared_access_32x16_to_16x16_layout_A_gfx11(thread_id, local_id)¶

Reverse: (thread, local) -> (i=M=thread%16, j=K=local)

tilelang.intrinsics.wmma_layout.shared_16x16_to_local_32x16_layout_A_colmajor_gfx11(i, j)¶

Forward: A_T[i=K, j=M] -> (thread=j, local=i%16). ATTN: This maps (i, j) only onto the lower half of the wave (lanes 0~15).

tilelang.intrinsics.wmma_layout.thread_id_shared_access_32x16_to_16x16_layout_A_colmajor_gfx11(thread_id, local_id)¶

Reverse: (thread, local) -> (i=K=local, j=M=thread%16)

tilelang.intrinsics.wmma_layout.shared_16x16_to_local_32x16_layout_B_gfx11(i, j)¶

Forward: B[i=K, j=N] -> (thread=j, local=i%16). ATTN: This maps (i, j) only onto the lower half of the wave (lanes 0~15).

tilelang.intrinsics.wmma_layout.thread_id_shared_access_32x16_to_16x16_layout_B_gfx11(thread_id, local_id)¶

Reverse: (thread, local) -> (i=K=local, j=N=thread%16)

tilelang.intrinsics.wmma_layout.shared_16x16_to_local_32x16_layout_B_colmajor_gfx11(i, j)¶

Forward: B_T[i=N, j=K] -> (thread=i, local=j%16). ATTN: This maps (i, j) only onto the lower half of the wave (lanes 0~15).

tilelang.intrinsics.wmma_layout.thread_id_shared_access_32x16_to_16x16_layout_B_colmajor_gfx11(thread_id, local_id)¶

Reverse: (thread, local) -> (i=N=thread%16, j=K=local)

tilelang.intrinsics.wmma_layout.shared_16x16_to_local_32x8_layout_C_gfx11(i, j)¶

Forward: D[i=M, j=N] -> (thread=(i%2)*16+j, local=i//2).

tilelang.intrinsics.wmma_layout.thread_id_shared_access_32x8_to_16x16_layout_C_gfx11(thread_id, local_id)¶

Reverse: (thread, local) -> (i=M=(thread//16)+local*2, j=N=thread%16)

tilelang.intrinsics.wmma_layout.wmma_store_index_map_gfx11(thread_id, local_id)¶

(thread, local) -> (M, N) in D. Hardware D layout.

tilelang.intrinsics.wmma_layout.fragment_forward_A_gfx11(i, j, rep)¶

Replicated fragment forward map for gfx11 A.

The canonical owner lives in the lower half-wave and rep selects whether the logical element is materialized in the lower or upper half-wave copy.
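One way to picture the replicated map is sketched below. The formula is an assumption made for illustration: the page does not give it, so `fragment_forward_a_sketch` simply takes rep == 0 to mean the lower half-wave and rep == 1 the upper half-wave, then checks that both copies resolve to the same logical element under the gfx11 A reverse map.

```python
# Hypothetical sketch; the real fragment_forward_A_gfx11 may be written differently.
def fragment_forward_a_sketch(i, j, rep):
    # Assumption: rep == 0 -> lanes 0..15, rep == 1 -> lanes 16..31.
    return rep * 16 + i, j % 16


# Restatement of the documented gfx11 A reverse map.
def a_reverse_gfx11(thread, local):
    return thread % 16, local


# Both replicas of every logical element must map back to the same (i, j).
replicas_agree = all(
    a_reverse_gfx11(*fragment_forward_a_sketch(i, j, rep)) == (i, j)
    for i in range(16)
    for j in range(16)
    for rep in range(2)
)

print(fragment_forward_a_sketch(3, 9, 0))  # (3, 9)
print(fragment_forward_a_sketch(3, 9, 1))  # (19, 9)
print(replicas_agree)                      # True
```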

tilelang.intrinsics.wmma_layout.fragment_forward_A_colmajor_gfx11(i, j, rep)¶

Replicated fragment forward map for gfx11 transposed A.

tilelang.intrinsics.wmma_layout.fragment_forward_B_gfx11(i, j, rep)¶

Replicated fragment forward map for gfx11 B.

tilelang.intrinsics.wmma_layout.fragment_forward_B_colmajor_gfx11(i, j, rep)¶

Replicated fragment forward map for gfx11 transposed B.

tilelang.intrinsics.wmma_layout.get_wmma_a_layout_funcs(rdna_gen, transposed)¶

Return (forward_map, reverse_map) for A layout.

Parameters:
  • rdna_gen (int)

  • transposed (bool)
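A selector of this shape can be sketched as a table lookup. Everything below is an assumption for illustration: the page does not state which integers `rdna_gen` accepts (11 and 12 are used here as stand-ins for gfx11 and gfx12), nor that `transposed` picks the colmajor variant, and `pick_a_layout_sketch` is not the library function. The lambdas restate the A formulas documented on this page.

```python
# Hypothetical selector; keys and semantics are assumptions, not the tilelang API.
A_LAYOUTS = {
    # (generation, transposed): (forward_map, reverse_map)
    (12, False): (lambda i, j: ((j // 8) * 16 + i, j % 8),
                  lambda t, l: (t % 16, (t // 16) * 8 + l)),
    (12, True): (lambda i, j: ((i // 8) * 16 + j, i % 8),
                 lambda t, l: ((t // 16) * 8 + l, t % 16)),
    (11, False): (lambda i, j: (i, j % 16),
                  lambda t, l: (t % 16, l)),
    (11, True): (lambda i, j: (j, i % 16),
                 lambda t, l: (l, t % 16)),
}


def pick_a_layout_sketch(gen, transposed):
    """Return (forward_map, reverse_map) for the assumed generation key."""
    try:
        return A_LAYOUTS[(gen, transposed)]
    except KeyError:
        raise ValueError(f"unsupported generation: {gen}") from None


# Each pair should be mutually inverse over the 16x16 tile.
all_invertible = all(
    rev(*fwd(i, j)) == (i, j)
    for fwd, rev in A_LAYOUTS.values()
    for i in range(16)
    for j in range(16)
)
print(all_invertible)  # True
```

Returning the forward and reverse maps as a pair keeps the two directions in sync at the call site.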

tilelang.intrinsics.wmma_layout.get_wmma_b_layout_funcs(rdna_gen, transposed)¶

Return (forward_map, reverse_map) for B layout.

Parameters:
  • rdna_gen (int)

  • transposed (bool)

tilelang.intrinsics.wmma_layout.get_wmma_c_layout_funcs(rdna_gen)¶

Return (forward_map, reverse_map) for C/D layout.

Parameters:

rdna_gen (int)

tilelang.intrinsics.wmma_layout.get_wmma_store_index_map_func(rdna_gen)¶

Return the (thread_id, local_id) -> (row, col) store map.

Parameters:

rdna_gen (int)

tilelang.intrinsics.wmma_layout.get_wmma_a_fragment_forward_func(rdna_gen, transposed)¶

Return the fragment forward function for A layout.

Parameters:
  • rdna_gen (int)

  • transposed (bool)

tilelang.intrinsics.wmma_layout.get_wmma_b_fragment_forward_func(rdna_gen, transposed)¶

Return the fragment forward function for B layout.

Parameters:
  • rdna_gen (int)

  • transposed (bool)

tilelang.intrinsics.wmma_layout.get_wmma_fragment_replicate_count(rdna_gen)¶

Return the fragment replicate count used for logical one-to-many owners.

Parameters:

rdna_gen (int)