
Commit 06ec46f

kasper0406 authored and Google-ML-Automation committed
PR #24114: Triton/Nvidia: Fix fused fp8 <-> fp8 conversions
Imported from GitHub PR #24114

Converting FP8 <-> FP8 fails because the Triton compiler does not support such conversions directly. The proposed fix makes the conversion go through FP16.

Two questions:

1) Are there better approaches to solving this?
2) I could not find a place to put unit tests for this, and in the code there is a comment saying:

```
// TODO(b/266862493): Add end-to-end test once FP8 support lands in XLA as
// we can't test the code below without patching the feature.
```

Is there a place where I can add a test?

### Details

When converting FP8 types, the XLA compiler emits a Triton `fp_to_fp` instruction. If the source type is FP8, no rounding strategy is specified. Concretely, this causes the following Triton IR to be emitted:

<details>
<summary><code>%24 = tt.fp_to_fp %20 : tensor<32x64xf8E5M2> -> tensor<32x64xf8E4M3FN></code></summary>

```
module {
  tt.func @gemm_fusion_dot_320_impl(%arg0: !tt.ptr<f8E4M3FN> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f8E5M2> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f8E4M3FN> {tt.divisibility = 16 : i32}) {
    %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf8E4M3FN>
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x64xf8E4M3FN>
    %c90_i32 = arith.constant 90 : i32
    %c32000_i64 = arith.constant 32000 : i64
    %c64_i32 = arith.constant 64 : i32
    %c90_i64 = arith.constant 90 : i64
    %c768_i64 = arith.constant 768 : i64
    %c0_i32 = arith.constant 0 : i32
    %c1_i64 = arith.constant 1 : i64
    %c32_i32 = arith.constant 32 : i32
    %c24_i32 = arith.constant 24 : i32
    %c8_i32 = arith.constant 8 : i32
    %c4000_i32 = arith.constant 4000 : i32
    %cst_1 = arith.constant dense<0.000000e+00> : tensor<32x64xf32>
    %0 = tt.get_program_id x : i32
    %1 = arith.divsi %0, %c4000_i32 : i32
    %2 = arith.muli %1, %c8_i32 : i32
    %3 = arith.subi %c24_i32, %2 : i32
    %4 = arith.cmpi slt, %3, %c8_i32 : i32
    %5 = arith.select %4, %3, %c8_i32 : i32
    %6 = arith.remsi %0, %5 : i32
    %7 = arith.addi %2, %6 : i32
    %8 = arith.remsi %0, %c4000_i32 : i32
    %9 = arith.divsi %8, %5 : i32
    %10 = arith.muli %7, %c32_i32 : i32
    %11 = tt.make_tensor_ptr %arg1, [%c768_i64, %c90_i64], [%c1_i64, %c768_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x64xf8E5M2>>
    %12 = tt.advance %11, [%10, %c0_i32] : <tensor<32x64xf8E5M2>>
    %13 = arith.muli %9, %c64_i32 : i32
    %14 = tt.make_tensor_ptr %arg0, [%c90_i64, %c32000_i64], [%c1_i64, %c90_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf8E4M3FN>>
    %15 = tt.advance %14, [%c0_i32, %13] : <tensor<64x64xf8E4M3FN>>
    %16:3 = scf.for %arg3 = %c0_i32 to %c90_i32 step %c64_i32 iter_args(%arg4 = %12, %arg5 = %15, %arg6 = %cst_1) -> (!tt.ptr<tensor<32x64xf8E5M2>>, !tt.ptr<tensor<64x64xf8E4M3FN>>, tensor<32x64xf32>) : i32 {
      %20 = tt.load %arg4 {boundaryCheck = array<i32: 1>, padding = 1 : i32} : !tt.ptr<tensor<32x64xf8E5M2>>
      %21 = tt.advance %arg4, [%c0_i32, %c64_i32] : <tensor<32x64xf8E5M2>>
      %22 = tt.load %arg5 {boundaryCheck = array<i32: 0>, padding = 1 : i32} : !tt.ptr<tensor<64x64xf8E4M3FN>>
      %23 = tt.advance %arg5, [%c64_i32, %c0_i32] : <tensor<64x64xf8E4M3FN>>
      %24 = tt.fp_to_fp %20 : tensor<32x64xf8E5M2> -> tensor<32x64xf8E4M3FN>
      %25 = arith.subi %c90_i32, %arg3 : i32
      %26 = arith.cmpi slt, %25, %c64_i32 : i32
      %27 = scf.if %26 -> (tensor<32x64xf8E4M3FN>) {
        %30 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
        %31 = tt.expand_dims %30 {axis = 0 : i32} : tensor<64xi32> -> tensor<1x64xi32>
        %32 = tt.splat %25 : i32 -> tensor<1x64xi32>
        %33 = arith.cmpi slt, %31, %32 : tensor<1x64xi32>
        %34 = tt.broadcast %33 : tensor<1x64xi1> -> tensor<32x64xi1>
        %35 = arith.select %34, %24, %cst_0 : tensor<32x64xi1>, tensor<32x64xf8E4M3FN>
        scf.yield %35 : tensor<32x64xf8E4M3FN>
      } else {
        scf.yield %24 : tensor<32x64xf8E4M3FN>
      }
      %28 = scf.if %26 -> (tensor<64x64xf8E4M3FN>) {
        %30 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
        %31 = tt.expand_dims %30 {axis = 1 : i32} : tensor<64xi32> -> tensor<64x1xi32>
        %32 = tt.splat %25 : i32 -> tensor<64x1xi32>
        %33 = arith.cmpi slt, %31, %32 : tensor<64x1xi32>
        %34 = tt.broadcast %33 : tensor<64x1xi1> -> tensor<64x64xi1>
        %35 = arith.select %34, %22, %cst : tensor<64x64xi1>, tensor<64x64xf8E4M3FN>
        scf.yield %35 : tensor<64x64xf8E4M3FN>
      } else {
        scf.yield %22 : tensor<64x64xf8E4M3FN>
      }
      %29 = tt.dot %27, %28, %arg6, inputPrecision = tf32 {maxNumImpreciseAcc = 2147483647 : i32} : tensor<32x64xf8E4M3FN> * tensor<64x64xf8E4M3FN> -> tensor<32x64xf32>
      scf.yield %21, %23, %29 : !tt.ptr<tensor<32x64xf8E5M2>>, !tt.ptr<tensor<64x64xf8E4M3FN>>, tensor<32x64xf32>
    }
    %17 = tt.fp_to_fp %16#2, rounding = rtne : tensor<32x64xf32> -> tensor<32x64xf8E4M3FN>
    %18 = tt.make_tensor_ptr %arg2, [%c768_i64, %c32000_i64], [%c1_i64, %c768_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x64xf8E4M3FN>>
    %19 = tt.advance %18, [%10, %13] : <tensor<32x64xf8E4M3FN>>
    tt.store %19, %17 : !tt.ptr<tensor<32x64xf8E4M3FN>>
    tt.return
  }
}
```
</details>

Which leads to a failing assertion:

```
#0 0x000073413786d9fc in pthread_kill () from /lib/x86_64-linux-gnu/libc.so.6
#1 0x0000734137819476 in raise () from /lib/x86_64-linux-gnu/libc.so.6
#2 0x00007341377ff7f3 in abort () from /lib/x86_64-linux-gnu/libc.so.6
#3 0x00007341377ff71b in ?? () from /lib/x86_64-linux-gnu/libc.so.6
#4 0x0000734137810e96 in __assert_fail () from /lib/x86_64-linux-gnu/libc.so.6
#5 0x000057d936b1777b in mlir::triton::gpu::(anonymous namespace)::FpToFpOpConversion::createDestOps (this=0x733d08425cc0, op=..., adaptor=..., rewriter=..., elemTy=..., operands=..., loc=...) at external/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp:500
#6 0x000057d936b17195 in mlir::triton::gpu::ElementwiseOpConversionBase<mlir::triton::FpToFpOp, mlir::triton::gpu::(anonymous namespace)::FpToFpOpConversion>::matchAndRewrite (this=0x733d08425cc0, op=..., adaptor=..., rewriter=...) at external/triton/include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h:188
[...]
#29 0x000057d93fa6cade in mlir::PassManager::run (this=0x733e80fba158, op=0x733d080bbc20) at external/llvm-project/mlir/lib/Pass/Pass.cpp:885
#30 0x000057d9363f6b1b in xla::gpu::CompileTritonToLLVM (hlo_config=..., hlo_module_name="gemm_fusion_dot.320", device_info=..., block_level_parameters=..., triton_module=..., llvm_module=0x733d0816d6a0, mlir_context=..., is_xla_fusion=true, emit_kernel=true) at xla/backends/gpu/codegen/triton/fusion_emitter.cc:1627
#31 0x000057d9363f5a5d in xla::gpu::TritonWrapper (fn_name="gemm_fusion_dot_320_impl", fusion=0x733d080a31c0, cc=std::variant<stream_executor::CudaComputeCapability, stream_executor::RocmComputeCapability> [index 0] = {...}, device_info=..., block_level_parameters=..., llvm_module=0x733d0816d6a0, mlir_context=...) at xla/backends/gpu/codegen/triton/fusion_emitter.cc:1531
```

Triton compilation fails for two reasons:

* First, it hits an assertion that a rounding strategy must be specified when the destination type is FP8.
* After adding a rounding strategy, it hits a second issue: Triton has no conversion method for FP8 <-> FP8.

To work around these two issues, I propose going through FP16 when both the source and destination types are FP8.

Copybara import of the project:

-- afd3929 by Kasper Nielsen <[email protected]>:
Fix fused fp8 <-> fp8 conversions

-- 66340aa by Kasper Nielsen <[email protected]>:
Add unit tests and refactor duplicated code

-- 07ae307 by Kasper Nielsen <[email protected]>:
Run clang-format

Merging this change closes #24114

FUTURE_COPYBARA_INTEGRATE_REVIEW=#24114 from kasper0406:kn/fp8-conversion-fix 07ae307
PiperOrigin-RevId: 741162069
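For readers skimming the diffs below, the core of the workaround is the new FP8 <-> FP8 branch added to the shared `Cast()` helper in `xla/backends/gpu/codegen/triton/emitter_helpers.cc`; a condensed excerpt of that branch:

```cpp
// When both the source and the destination element types are FP8, Triton has
// no direct conversion, so widen to f16 first (no rounding mode is needed when
// extending) and then round to the FP8 destination type with RTNE.
if (IsFp8Type(src_element_ty) && IsFp8Type(dst_element_ty)) {
  auto fp16_value = b.create<mt::FpToFpOp>(fp16_ty, value);
  return b.create<mt::FpToFpOp>(
      dst_ty, fp16_value,
      mt::RoundingModeAttr::get(b.getContext(), mt::RoundingMode::RTNE));
}
```

The net effect is that a single unsupported `tt.fp_to_fp` (FP8 -> FP8) is replaced by two supported ones: FP8 -> f16, then f16 -> FP8 with RTNE rounding.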
1 parent 6a48e7d commit 06ec46f

4 files changed: +83 -125 lines changed

xla/backends/gpu/codegen/triton/emitter_helpers.cc

Lines changed: 11 additions & 2 deletions
```diff
@@ -126,11 +126,13 @@ bool IsFp8Type(Type t) {
 Value Cast(EmitterLocOpBuilder& b, Value value, Type dst_element_ty) {
   Type src_ty = value.getType();
   Type src_element_ty = src_ty;
+  Type fp16_ty = b.getF16Type();
   Type fp32_ty = b.getF32Type();
   Type dst_ty = dst_element_ty;
   if (auto src_shaped_ty = mlir::dyn_cast<ShapedType>(src_ty)) {
     src_element_ty = src_shaped_ty.getElementType();
     dst_ty = src_shaped_ty.clone(src_shaped_ty.getShape(), dst_element_ty);
+    fp16_ty = src_shaped_ty.clone(src_shaped_ty.getShape(), b.getF16Type());
     fp32_ty = src_shaped_ty.clone(src_shaped_ty.getShape(), b.getF32Type());
   }
   if (src_ty == dst_ty) {
@@ -156,14 +158,21 @@ Value Cast(EmitterLocOpBuilder& b, Value value, Type dst_element_ty) {
     // because LLVM doesn't support casts from/to FP8.
     // TODO(b/266862493): Add end-to-end test once FP8 support lands in XLA as
     // we can't test the code below without patching the feature.
-    if (IsFp8Type(src_element_ty)) {
+    if (IsFp8Type(src_element_ty) && !IsFp8Type(dst_element_ty)) {
       return b.create<mt::FpToFpOp>(dst_ty, value);
     }
-    if (IsFp8Type(dst_element_ty)) {
+    if (IsFp8Type(dst_element_ty) && !IsFp8Type(src_element_ty)) {
       return b.create<mt::FpToFpOp>(
           dst_ty, value,
           mt::RoundingModeAttr::get(b.getContext(), mt::RoundingMode::RTNE));
     }
+    if (IsFp8Type(src_element_ty) && IsFp8Type(dst_element_ty)) {
+      // FP8 <-> FP8 conversion needs to go through FP16
+      auto fp16_value = b.create<mt::FpToFpOp>(fp16_ty, value);
+      return b.create<mt::FpToFpOp>(
+          dst_ty, fp16_value,
+          mt::RoundingModeAttr::get(b.getContext(), mt::RoundingMode::RTNE));
+    }
 
     if (src_fp_element_ty.getFPMantissaWidth() >
         dst_fp_element_ty.getFPMantissaWidth()) {
```
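For illustration, a hypothetical call site of the shared helper (a sketch, not part of this commit): `b` is assumed to be an `EmitterLocOpBuilder`, `value` an `f8e5m2` tensor value, and `TritonType` is used the same way as in `EmitElementwise` in the matmul emitter below.

```cpp
// Hypothetical usage sketch: convert an f8e5m2 value to f8e4m3fn through the
// shared Cast() helper. With this fix, the helper emits two tt.fp_to_fp ops:
// f8E5M2 -> f16, then f16 -> f8E4M3FN with RTNE rounding.
TF_ASSIGN_OR_RETURN(Type dst_ty, TritonType(b, F8E4M3FN));
Value converted = triton::Cast(b, value, dst_ty);
```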

xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_test.cc

Lines changed: 30 additions & 0 deletions
```diff
@@ -4202,6 +4202,36 @@ ENTRY main {
   EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1.0, /*arel=*/1e-3}));
 }
 
+TEST_F(TritonTest, FP8ToFP8EndToEnd) {
+  if (!GetCudaComputeCapability().IsAtLeastHopper()) {
+    GTEST_SKIP() << "Doesn't pass on pre-Hopper GPUs.";
+  }
+
+  const std::string hlo_text = R"(
+HloModule t
+
+triton_dot {
+  parameter_0 = f8e5m2[32,32]{1,0} parameter(0)
+  parameter_1 = f8e4m3fn[32,32]{1,0} parameter(1)
+  convert = f8e4m3fn[32,32]{1,0} convert(parameter_0)
+  ROOT dot = f32[32,32]{1,0} dot(convert, parameter_1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={1}
+}
+
+ENTRY main {
+  parameter_0 = f8e5m2[32,32]{1,0} parameter(0)
+  parameter_1 = f8e4m3fn[32,32]{1,0} parameter(1)
+  ROOT gemm_fusion_dot = f32[32,32]{1,0} fusion(parameter_0, parameter_1),
+    kind=kCustom, calls=triton_dot,
+    backend_config={
+      "fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":
+        {"block_m":"32","block_n":"32","block_k":"32","split_k":"1",
+        "num_stages":"1","num_warps":"4","num_ctas":"1"}}}
+})";
+
+  EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1.0, /*arel=*/1e-3}));
+}
+
 // Test PreventMmaV3LoopUnrolling pass in order to keep compile time low.
 // See b/344841434.
 TEST_F(TritonGemmTest, TestPreventMMAV3LoopUnrolling) {
```

xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc

Lines changed: 34 additions & 0 deletions
```diff
@@ -1806,6 +1806,40 @@ ENTRY entry_computation {
   EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), kExactMatch));
 }
 
+TEST_F(TritonEmitterTest, FP8ToFP8EndToEnd) {
+  if (auto cc =
+          std::get_if<se::CudaComputeCapability>(&GpuComputeCapability())) {
+    if (!cc->IsAtLeastHopper()) {
+      GTEST_SKIP() << "Doesn't pass on pre-Hopper GPUs.";
+    }
+  }
+
+  const std::string hlo_text = R"(
+HloModule t
+
+triton_dot {
+  parameter_0 = f8e5m2[32,32]{1,0} parameter(0)
+  parameter_1 = f8e4m3fn[32,32]{1,0} parameter(1)
+  convert = f8e4m3fn[32,32]{1,0} convert(parameter_0)
+  ROOT dot = f32[32,32]{1,0} dot(convert, parameter_1),
+    lhs_contracting_dims={1}, rhs_contracting_dims={1}
+}
+
+ENTRY main {
+  parameter_0 = f8e5m2[32,32]{1,0} parameter(0)
+  parameter_1 = f8e4m3fn[32,32]{1,0} parameter(1)
+  ROOT gemm_fusion_dot = f32[32,32]{1,0} fusion(parameter_0, parameter_1),
+    kind=kCustom, calls=triton_dot,
+    backend_config={
+      "fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":
+        {"block_m":"32","block_n":"32","block_k":"32","split_k":"1",
+        "num_stages":"1","num_warps":"4","num_ctas":"1"}}}
+})";
+
+  EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_text,
+                                       ErrorSpec{/*aabs=*/1.0, /*arel=*/1e-3}));
+}
+
 TEST_F(TritonEmitterTest, SingleTileDotWithNestedFusionsIsEmittedCorrectly) {
   // Simplest case when everything fits into one tile that is useful for
   // debugging. This also tests support for empty nested fusions.
```

xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc

Lines changed: 8 additions & 123 deletions
```diff
@@ -227,121 +227,6 @@ bool IsFp8Type(Type t) {
                    mlir::Float8E4M3B11FNUZType>(t);
 }
 
-Value Cast(EmitterLocOpBuilder b, Value value, Type dst_element_ty) {
-  Type src_ty = value.getType();
-  Type src_element_ty = src_ty;
-  Type fp32_ty = b.getF32Type();
-  Type dst_ty = dst_element_ty;
-  if (auto src_shaped_ty = mlir::dyn_cast<ShapedType>(src_ty)) {
-    src_element_ty = src_shaped_ty.getElementType();
-    dst_ty = src_shaped_ty.clone(src_shaped_ty.getShape(), dst_element_ty);
-    fp32_ty = src_shaped_ty.clone(src_shaped_ty.getShape(), b.getF32Type());
-  }
-  if (src_ty == dst_ty) {
-    return value;
-  }
-
-  // All operations on bf16 are done through f32.
-  if (src_element_ty.isBF16()) {
-    return Cast(b, b.create<ma::ExtFOp>(fp32_ty, value), dst_element_ty);
-  }
-  if (dst_element_ty.isBF16()) {
-    // S8 -> BF16 is directly supported and doesn't need to go through f32.
-    if (!src_element_ty.isInteger(8)) {
-      return b.create<ma::TruncFOp>(dst_ty, Cast(b, value, b.getF32Type()));
-    }
-  }
-
-  // float => float
-  auto src_fp_element_ty = mlir::dyn_cast<mlir::FloatType>(src_element_ty);
-  auto dst_fp_element_ty = mlir::dyn_cast<mlir::FloatType>(dst_element_ty);
-  if (src_fp_element_ty && dst_fp_element_ty) {
-    // F8 <-> FP16, BF16, FP32, FP64 need to be handled via Triton's tt.fp_to_fp
-    // because LLVM doesn't support casts from/to FP8.
-    // TODO(b/266862493): Add end-to-end test once FP8 support lands in XLA as
-    // we can't test the code below without patching the feature.
-    if (IsFp8Type(src_element_ty)) {
-      return b.create<mt::FpToFpOp>(dst_ty, value);
-    }
-    if (IsFp8Type(dst_element_ty)) {
-      return b.create<mt::FpToFpOp>(
-          dst_ty, value,
-          mt::RoundingModeAttr::get(b.getContext(), mt::RoundingMode::RTNE));
-    }
-
-    if (src_fp_element_ty.getFPMantissaWidth() >
-        dst_fp_element_ty.getFPMantissaWidth()) {
-      return b.create<ma::TruncFOp>(dst_ty, value);
-    } else {
-      return b.create<ma::ExtFOp>(dst_ty, value);
-    }
-  }
-  // int => int
-  if (mlir::isa<mlir::IntegerType>(src_element_ty) &&
-      mlir::isa<mlir::IntegerType>(dst_element_ty)) {
-    if (src_element_ty.getIntOrFloatBitWidth() <
-        dst_element_ty.getIntOrFloatBitWidth()) {
-      if (src_element_ty.isInteger(1)) {
-        return b.create<ma::ExtUIOp>(dst_ty, value);
-      }
-      return b.create<ma::ExtSIOp>(dst_ty, value);
-    }
-    return b.create<ma::TruncIOp>(dst_ty, value);
-  }
-  // int => float
-  if (mlir::isa<mlir::IntegerType>(src_element_ty) && dst_fp_element_ty) {
-    // TODO(b/266862493): Support unsigned integer types.
-    if (src_element_ty.isInteger(1)) {
-      return b.create<ma::UIToFPOp>(dst_ty, value);
-    }
-    return b.create<ma::SIToFPOp>(dst_ty, value);
-  }
-  // float => int
-  if (src_fp_element_ty && mlir::isa<mlir::IntegerType>(dst_element_ty)) {
-    if (dst_element_ty.isInteger(1)) {
-      return b.create<ma::CmpFOp>(ma::CmpFPredicate::UNE, value,
-                                  ZerosLike(b, value));
-    }
-    // TODO(b/266862493): Support unsigned integer types.
-    // The current logic handles signed integer types only. Additional handling
-    // is needed for unsigned integer types.
-    auto cst_int = [&](EmitterLocOpBuilder b, int64_t x) {
-      if (auto src_shaped_ty = mlir::dyn_cast<ShapedType>(src_ty)) {
-        return CreateConst(b, dst_element_ty, x, src_shaped_ty.getShape());
-      } else {
-        return CreateConst(b, dst_element_ty, x);
-      }
-    };
-    auto cst_float = [&](EmitterLocOpBuilder b, int64_t x) {
-      if (auto src_shaped_ty = mlir::dyn_cast<ShapedType>(src_ty)) {
-        return CreateConst(b, src_fp_element_ty, x, src_shaped_ty.getShape());
-      } else {
-        return CreateConst(b, src_fp_element_ty, x);
-      }
-    };
-    auto fptosi = b.create<ma::FPToSIOp>(dst_ty, value);
-    int64_t min = llvm::minIntN(dst_element_ty.getIntOrFloatBitWidth());
-    int64_t max = llvm::maxIntN(dst_element_ty.getIntOrFloatBitWidth());
-
-    // value <= static_cast<float>(INT_MIN) ? INT_MIN : ...
-    auto clamped = b.create<ma::SelectOp>(
-        b.create<ma::CmpFOp>(ma::CmpFPredicate::OLE, value, cst_float(b, min)),
-        cst_int(b, min), fptosi);
-    // value >= static_cast<float>(INT_MAX) ? INT_MAX : ...
-    clamped = b.create<ma::SelectOp>(
-        b.create<ma::CmpFOp>(ma::CmpFPredicate::OGE, value, cst_float(b, max)),
-        cst_int(b, max), clamped);
-    // isnan(value) ? 0 : ...
-    return b.create<ma::SelectOp>(
-        b.create<ma::CmpFOp>(ma::CmpFPredicate::UNO, value, value),
-        cst_int(b, 0), clamped);
-  }
-
-  LOG(FATAL) << "Type conversion not supported: "
-             << llvm_ir::DumpToString(src_element_ty) << " -> "
-             << llvm_ir::DumpToString(dst_element_ty);
-}
-
 Value Subtract(EmitterLocOpBuilder b, ValueRange values) {
   if (mlir::isa<mlir::IntegerType>(mlir::getElementTypeOrSelf(values[0]))) {
     return b.create<ma::SubIOp>(values[0], values[1]);
@@ -448,7 +333,7 @@ absl::StatusOr<Value> EmitElementwise(EmitterLocOpBuilder b,
     case HloOpcode::kConvert: {
       TF_ASSIGN_OR_RETURN(Type dst_ty,
                           TritonType(b, hlo.shape().element_type()));
-      return Cast(b, inputs[0], dst_ty);
+      return triton::Cast(b, inputs[0], dst_ty);
     }
     case HloOpcode::kAdd:
      if (is_integer) {
@@ -661,7 +546,7 @@ absl::StatusOr<Value> EmitScope(
     if (hlo->opcode() == HloOpcode::kConvert &&
         hlo->operand(0)->shape().element_type() == S4) {
       Value unpacked;
-      unpacked = Cast(b, values[hlo->operand(0)], b.getI8Type());
+      unpacked = triton::Cast(b, values[hlo->operand(0)], b.getI8Type());
       std::vector<Value> operands({unpacked});
       TF_ASSIGN_OR_RETURN(result, EmitElementwise(b, libdevice_path,
                                                   device_info, *hlo, operands));
@@ -817,7 +702,7 @@ ma::ConstantOp Cst64(EmitterLocOpBuilder b, int64_t v) {
 }
 
 Value RoundToBF16(EmitterLocOpBuilder b, Value input) {
-  return Cast(b, input, b.getBF16Type());
+  return triton::Cast(b, input, b.getBF16Type());
 };
 
 /*static*/ absl::StatusOr<MatMulDims> MatMulDims::Create(
@@ -1480,7 +1365,7 @@ class MatMulEmitterHelper {
             "64 bit dynamic-slice indices are not supported yet.");
       }
       majormost_dim_start_index_val =
-          Cast(b, majormost_dim_start_index_val, b.getI32Type());
+          triton::Cast(b, majormost_dim_start_index_val, b.getI32Type());
       majormost_dim_start_index_val =
           b.create<ma::MaxSIOp>(majormost_dim_start_index_val, Cst32(b, 0));
       majormost_dim_start_index_val =
@@ -2041,7 +1926,7 @@ class IterableInput {
     Value param_value = EmitParameterLoad(b, args.front(), boundary_checks_);
     if (type_ != storage_type_) {
       // For example cast i8 to i1.
-      param_value = Cast(b, param_value, type_);
+      param_value = triton::Cast(b, param_value, type_);
     }
     return param_value;
   }
@@ -2167,10 +2052,10 @@ Value EmitRegularMatmul(EmitterLocOpBuilder& b, Value lhs, Value rhs, Value acc,
   if (dot_instr->precision_config().algorithm() ==
       PrecisionConfig::ALG_DOT_BF16_BF16_F32) {
     if (dot_instr->operand(0)->shape().element_type() == F32) {
-      lhs = Cast(b, lhs, b.getBF16Type());
+      lhs = triton::Cast(b, lhs, b.getBF16Type());
     }
     if (dot_instr->operand(1)->shape().element_type() == F32) {
-      rhs = Cast(b, rhs, b.getBF16Type());
+      rhs = triton::Cast(b, rhs, b.getBF16Type());
     }
   }
 
@@ -2364,7 +2249,7 @@ absl::StatusOr<std::optional<stream_executor::gpu::TmaMetadata>> EmitMatMul(
   absl::flat_hash_map<const HloInstruction*, Value> values_out;
   TF_ASSIGN_OR_RETURN(Type acc_final_ty,
                       TritonType(b, dot_instr->shape().element_type()));
-  values_out[dot_instr] = Cast(b, acc_final, acc_final_ty);
+  values_out[dot_instr] = triton::Cast(b, acc_final, acc_final_ty);
 
   // Emit the output scope.
   if (std::vector<const HloInstruction*> to_emit =
```
