triton_related

2025-06-16

triton语义

Triton 在大多数情况下遵守 NumPy 的语义,但也有一些例外。

1 类型提升

类型提升 (Type Promotion) 是在不同数据类型的张量参与运算时发生的。对于与双下划线方法相关的二元运算和三元函数tl.where的最后两个参数,Triton 会自动将输入张量转换为一个通用的数据类型,这一转换遵循数据类型种类的层级顺序: {bool} < {integral dypes} < {floating point dtypes}:

  • 类型:如果一个张量的数据类型属于更高级的类型,则另一个张量将被提升至该数据类型,例如(int32, bfloat16) -> bfloat16

  • 宽度:如果两个张量的数据类型属于同一类,但其中一个张量具有更大的宽度,则另一个张量将被提升至此数据类型,例如:(float32, float16) -> float32

  • 上限:如果两个张量的宽度和符号相同,但数据类型不同,则它们都将被提升为下一个更大的数据类型,例如:(float16, bfloat16) -> float32
    • 如果两个张量的数据类型都是不同的 fp8 类型,它们将都被统一转换为 float16。
  • 无符号类型优先:在其他情况下(相同宽度,不同符号),它们将被提升为相应的无符号数据类型:(int32, uint32) -> uint32。

当涉及标量时,规则会有所不同。在此处,标量是指数值字面量、标记为tl.constexpr的变量或这些的组合。它们由 NumPy 标量表示,并具有boolintfloat等类型。

  • 当一个操作涉及张量和标量时
    • 如果标量的类型低于或等于张量的类型,则标量不会参与类型提升:(uint8, int) -> uint8。
    • 如果标量的类型高于张量,将选择最适合标量的数据类型,在 int32 < uint32 < int64 < uint64 中为整数选择,在 float32 < float64 中选择浮点数。然后,张量和标量都将提升为这种数据类型:(int16, 4.0) -> float32

2 broadcast

广播 (Broadcasting)允许对不同形状的张量进行操作,它会自动将张量的形状扩展至兼容的大小,且在此过程中无需复制数据。其遵循以下规则:

  • 如果其中一个张量的形状较短,则在左侧填充 1,直到两个张量的维数相同,例如:((3, 4), (5, 3, 4)) -> ((1, 3, 4), (5, 3, 4))
  • 如果两个维度相等,或者其中之一为 1,那么这两个维度是兼容的。维度为 1 的那个维度将会被扩展以匹配另一个张量的维度,例如:((1, 3, 4), (5, 3, 4)) -> ((5, 3, 4), (5, 3, 4))

3 与 NumPy 语义不同之处

为了提高效率,Triton 的整数除法运算符遵守 C 语义,而不是 Python 语义。

因此,混合符号的整数的int // int运算实现为 C 语言中的向零舍入,而不是 Python 中的向负无穷舍入。同样,取模运算符 int % int(定义为 a % b = a - b * (a // b))也遵守 C 语义,而不是 Python 语义。

所有输入都是标量的情况下,整数除法和取模运算遵循 Python 语义,这点可能会造成迷惑。