tilelang.carver.roller.shape_inference.tir

Classes

Statement

TensorDepNode

DependencyAnalysis

InputShapeInference

Functions

region_exist_in_list(a, list)

walk_indice(expr)

get_analyzer_by_tir(block_analyzer, args)

Module Contents

class tilelang.carver.roller.shape_inference.tir.Statement(block_analyzer, block)
Parameters:

block (tvm.tir.schedule.schedule.BlockRV)

block_analyzer
block
dep_name
dependent_region
reverse_bound_inference
make_reverse(input_name, input_iter)
Parameters:
  • input_name (str)

  • input_iter (list[tvm.tir.PrimExpr])

class tilelang.carver.roller.shape_inference.tir.TensorDepNode(name)

A node in the tensor dependency graph, used for tensor dependency analysis.

name
add_next(node)
add_prev(node)
deduplicate(lst)
__str__()
__repr__()
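The reference lists only the members of TensorDepNode. A minimal pure-Python sketch of how such a node could behave is shown here. The attribute names next_nodes and prev_nodes, and the choice to deduplicate in place while keeping first-seen order, are assumptions made for illustration, not tilelang's actual code.

```python
class TensorDepNode:
    """Hypothetical stand-in for a node in a tensor dependency graph."""

    def __init__(self, name):
        self.name = name
        self.next_nodes = []  # assumed attribute: nodes that consume this tensor
        self.prev_nodes = []  # assumed attribute: nodes this tensor is computed from

    def add_next(self, node):
        self.next_nodes.append(node)
        self.deduplicate(self.next_nodes)

    def add_prev(self, node):
        self.prev_nodes.append(node)
        self.deduplicate(self.prev_nodes)

    def deduplicate(self, lst):
        # Drop repeated nodes in place, keeping the first occurrence of each.
        # With no custom __eq__, "in" compares nodes by identity.
        unique = []
        for item in lst:
            if item not in unique:
                unique.append(item)
        lst[:] = unique

    def __str__(self):
        return self.name

    def __repr__(self):
        return f"TensorDepNode({self.name!r})"


# Usage: A feeds C; adding the same edge twice keeps a single entry.
a, c = TensorDepNode("A"), TensorDepNode("C")
a.add_next(c)
a.add_next(c)
c.add_prev(a)
print(a.next_nodes)  # [TensorDepNode('C')]
print(c.prev_nodes)  # [TensorDepNode('A')]
```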
class tilelang.carver.roller.shape_inference.tir.DependencyAnalysis(deps)
deps
name2dep
mapping
get_or_create_node(name)
traverse_dependencies(compute)
analyze()
print_dependencies()
find_path_from_source(start_name, target_name)

Finds the path (if it exists) from a starting node (source) to a target node. Returns the path as a list of nodes.
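find_path_from_source can be pictured as a graph search over nodes looked up by name, which appears to be the role of name2dep and get_or_create_node. The sketch below uses breadth-first search, so it returns a shortest path. The classes Node and DepGraph, the helper add_edge, and returning an empty list when no path exists are all hypothetical choices for illustration, not taken from tilelang.

```python
from collections import deque


class Node:
    """Hypothetical minimal graph node: a name plus its successor nodes."""

    def __init__(self, name):
        self.name = name
        self.next_nodes = []


class DepGraph:
    """Toy dependency graph keyed by tensor name (assumed structure)."""

    def __init__(self):
        self.name2dep = {}

    def get_or_create_node(self, name):
        # Reuse the node registered under this name, or create it on first use.
        if name not in self.name2dep:
            self.name2dep[name] = Node(name)
        return self.name2dep[name]

    def add_edge(self, src, dst):
        # Hypothetical helper: record that `dst` is computed from `src`.
        self.get_or_create_node(src).next_nodes.append(self.get_or_create_node(dst))

    def find_path_from_source(self, start_name, target_name):
        # Breadth-first search; returns the node list from start to target,
        # or an empty list when the target is unknown or unreachable.
        if start_name not in self.name2dep or target_name not in self.name2dep:
            return []
        start = self.name2dep[start_name]
        parent = {start.name: None}  # name -> predecessor node on the BFS tree
        queue = deque([start])
        while queue:
            node = queue.popleft()
            if node.name == target_name:
                path = []
                while node is not None:
                    path.append(node)
                    node = parent[node.name]
                return path[::-1]
            for nxt in node.next_nodes:
                if nxt.name not in parent:
                    parent[nxt.name] = node
                    queue.append(nxt)
        return []


# Usage: edges A->B, B->C, A->D.
g = DepGraph()
for src, dst in [("A", "B"), ("B", "C"), ("A", "D")]:
    g.add_edge(src, dst)
print([n.name for n in g.find_path_from_source("A", "C")])  # ['A', 'B', 'C']
print(g.find_path_from_source("C", "A"))  # []
```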

class tilelang.carver.roller.shape_inference.tir.InputShapeInference(deps)
Parameters:

deps (list[Statement])

deps
target_mapping
buffer_mapping
reduce_axes = []
dep_analysis
construct_dependency_target(targets)
Parameters:

targets (tuple[str])

infer(shape, rstep=None, targets=None)
Parameters:
  • shape (dict[str, list[tvm.arith.ConstIntBound]])

  • rstep (dict[str, int])
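The reference does not describe how infer maps an output tile to input regions. Conceptually, a Roller-style shape inference takes bounds on the output tile's spatial axes (shape) and a step size along each reduction axis (rstep), then propagates them through each input's index expressions to get the input tile extents. The toy below is an assumption-laden illustration of that idea only: it replaces TIR expressions and tvm.arith.ConstIntBound with plain dicts and inclusive (min, max) tuples, and the functions affine_bound and infer_input_shapes are hypothetical.

```python
def affine_bound(coeffs, const, bounds):
    """Inclusive interval of sum(c * axis) + const, given per-axis bounds."""
    lo = hi = const
    for axis, c in coeffs.items():
        a_lo, a_hi = bounds[axis]
        lo += min(c * a_lo, c * a_hi)
        hi += max(c * a_lo, c * a_hi)
    return lo, hi


def infer_input_shapes(accesses, shape, rstep):
    """Toy tile-shape inference (hypothetical, not tilelang's algorithm).

    accesses: {input_name: [(coeffs, const), ...]}, one affine index per dimension.
    shape:    {axis: (min, max)} inclusive bounds of the output tile's spatial axes.
    rstep:    {reduce_axis: step} tile size assumed along each reduction axis.
    """
    bounds = dict(shape)
    for axis, step in rstep.items():
        bounds[axis] = (0, step - 1)  # one reduction step covers [0, step)
    result = {}
    for name, dims in accesses.items():
        extents = []
        for coeffs, const in dims:
            lo, hi = affine_bound(coeffs, const, bounds)
            extents.append(hi - lo + 1)
        result[name] = extents
    return result


# Usage: tiled matmul C[i, j] += A[i, k] * B[k, j] with a 16x32 output tile.
accesses = {
    "A": [({"i": 1}, 0), ({"k": 1}, 0)],
    "B": [({"k": 1}, 0), ({"j": 1}, 0)],
}
print(infer_input_shapes(accesses, {"i": (0, 15), "j": (0, 31)}, {"k": 32}))
# {'A': [16, 32], 'B': [32, 32]}

# A 1-D convolution X[i + r] needs a halo: 16 outputs with a 3-tap window read 18 inputs.
print(infer_input_shapes({"X": [({"i": 1, "r": 1}, 0)]}, {"i": (0, 15)}, {"r": 3}))
# {'X': [18]}
```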

get_input_exprs(output_exprs)
tilelang.carver.roller.shape_inference.tir.region_exist_in_list(a, list)
Return type:

bool
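region_exist_in_list reports whether the region a already appears in the given list. A toy version is sketched below under the assumption that a region is a list of per-dimension (min, extent) pairs. The real function presumably operates on TVM TIR region expressions, where equality would need to be structural rather than plain ==. The helper region_is_same is hypothetical, and the sketch renames the parameter list to regions so it does not shadow Python's built-in list.

```python
def region_is_same(a, b):
    # Regions match when they have the same rank and identical
    # (min, extent) pairs in every dimension.
    if len(a) != len(b):
        return False
    return all(tuple(da) == tuple(db) for da, db in zip(a, b))


def region_exist_in_list(a, regions):
    """Return True if some region in `regions` equals `a`."""
    return any(region_is_same(a, r) for r in regions)


# Usage: the first query matches a stored region; the second differs in extent.
stored = [[(0, 16), (0, 32)], [(16, 16), (0, 32)]]
print(region_exist_in_list([(0, 16), (0, 32)], stored))  # True
print(region_exist_in_list([(0, 8), (0, 32)], stored))   # False
```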

tilelang.carver.roller.shape_inference.tir.walk_indice(expr)
tilelang.carver.roller.shape_inference.tir.get_analyzer_by_tir(block_analyzer, args)
Return type:

InputShapeInference