1 一些用于调试的实用函数
设置环境变量 TRITON_INTERPRET = 1
,可以像调试任何CPU程序一样调试Triton kernel。然后Triton在CPU上运行,但模拟它在GPU上运行。
check_tensors_gpu_ready
:断言所有张量在内存中是连续的;仅在非模拟情况下,断言所有张量在GPU上breakpoint_if
:根据 pids 的条件设置断点print_if
:根据pids的条件打印内容
import os
from IPython.core.debugger import set_trace
os.environ['TRITON_INTERPRET'] = '1' # needs to be set *before* triton is imported
def check_tensors_gpu_ready(*tensors):
"""检查所有张量是否在GPU上并且是连续的"""
for t in tensors:
assert t.is_contiguous, "A tensor is not contiguous" # 断言张量是连续的
if not os.environ.get('TRITON_INTERPRET') == '1': assert t.is_cuda, "A tensor is not on cuda" # 如果不是模拟模式,断言张量在GPU上
def test_pid_conds(conds, pid_0=[0], pid_1=[0], pid_2=[0]):
"""测试pid条件是否满足
例如:
'=0' 检查pid_0 == 0
',>1' 检查pid_1 > 1
'>1,=0' 检查pid_0 > 1 且 pid_1 == 0
"""
pids = pid_0[0], pid_1[0], pid_2[0] # 获取pid值
conds = conds.replace(' ','').split(',') # 去除空格并分割条件
for i, (cond, pid) in enumerate(zip(conds, pids)):
if cond=='': continue # 如果条件为空,跳过
op, threshold = cond[0], int(cond[1:]) # 获取操作符和阈值
if op not in ['<','>','>=','<=','=', '!=']: raise ValueError(f"Rules may only use these ops: '<','>','>=','<=','=', '!='. Invalid rule: '{condition}'.") # 检查操作符是否合法
op = '==' if op == '=' else op # 将'='替换为'=='
if not eval(f'{pid} {op} {threshold}'): return False # 评估条件是否满足
return True
assert test_pid_conds('') # 测试空条件
assert test_pid_conds('>0', [1], [1]) # 测试pid_0 > 0
assert not test_pid_conds('>0', [0], [1]) # 测试pid_0 > 0不满足
assert test_pid_conds('=0,=1', [0], [1], [0]) # 测试pid_0 = 0 且 pid_1 = 1
def breakpoint_if(conds, pid_0=[0], pid_1=[0], pid_2=[0]):
"""如果任何pid条件满足,停止kernel"""
if test_pid_conds(conds, pid_0, pid_1, pid_2): set_trace() # 如果条件满足,设置断点
def print_if(txt, conds, pid_0=[0], pid_1=[0], pid_2=[0]):
"""如果任何pid条件满足,打印txt"""
if test_pid_conds(conds, pid_0, pid_1, pid_2): print(txt) # 如果条件满足,打印文本
def cdiv(a,b):
"""计算a除以b的上限值"""
return (a + b - 1) // b # 计算a除以b的上限值
assert cdiv(10,2)==5 # 测试cdiv函数
assert cdiv(10,3)==4 # 测试cdiv函数
import torch
import triton
import triton.language as tl
print("troton version: ", triton.__version__)