Skip to content

Commit 66340aa

Browse files
committed
Add unit tests and refactor duplicated code
1 parent afd3929 commit 66340aa

File tree

3 files changed

+71
-132
lines changed

3 files changed

+71
-132
lines changed

xla/backends/gpu/codegen/triton/fusion_emitter_device_legacy_test.cc

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4202,6 +4202,36 @@ ENTRY main {
42024202
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1.0, /*arel=*/1e-3}));
42034203
}
42044204

4205+
TEST_F(TritonTest, FP8ToFP8EndToEnd) {
4206+
if (!GetCudaComputeCapability().IsAtLeastHopper()) {
4207+
GTEST_SKIP() << "Doesn't pass on pre-Hopper GPUs.";
4208+
}
4209+
4210+
const std::string hlo_text = R"(
4211+
HloModule t
4212+
4213+
triton_dot {
4214+
parameter_0 = f8e5m2[32,32]{1,0} parameter(0)
4215+
parameter_1 = f8e4m3fn[32,32]{1,0} parameter(1)
4216+
convert = f8e4m3fn[32,32]{1,0} convert(parameter_0)
4217+
ROOT dot = f32[32,32]{1,0} dot(convert, parameter_1),
4218+
lhs_contracting_dims={1}, rhs_contracting_dims={1}
4219+
}
4220+
4221+
ENTRY main {
4222+
parameter_0 = f8e5m2[32,32]{1,0} parameter(0)
4223+
parameter_1 = f8e4m3fn[32,32]{1,0} parameter(1)
4224+
ROOT gemm_fusion_dot = f32[32,32]{1,0} fusion(parameter_0, parameter_1),
4225+
kind=kCustom, calls=triton_dot,
4226+
backend_config={
4227+
"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":
4228+
{"block_m":"32","block_n":"32","block_k":"32","split_k":"1",
4229+
"num_stages":"1","num_warps":"4","num_ctas":"1"}}}
4230+
})";
4231+
4232+
EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{/*aabs=*/1.0, /*arel=*/1e-3}));
4233+
}
4234+
42054235
// Test PreventMmaV3LoopUnrolling pass in order to keep compile time low.
42064236
// See b/344841434.
42074237
TEST_F(TritonGemmTest, TestPreventMMAV3LoopUnrolling) {

xla/backends/gpu/codegen/triton/fusion_emitter_device_test.cc

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1766,6 +1766,39 @@ ENTRY entry_computation {
17661766
EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), kExactMatch));
17671767
}
17681768

1769+
TEST_F(TritonEmitterTest, FP8ToFP8EndToEnd) {
1770+
if (auto cc =
1771+
std::get_if<se::CudaComputeCapability>(&GpuComputeCapability())) {
1772+
if (!cc->IsAtLeastHopper()) {
1773+
GTEST_SKIP() << "Doesn't pass on pre-Hopper GPUs.";
1774+
}
1775+
}
1776+
1777+
const std::string hlo_text = R"(
1778+
HloModule t
1779+
1780+
triton_dot {
1781+
parameter_0 = f8e5m2[32,32]{1,0} parameter(0)
1782+
parameter_1 = f8e4m3fn[32,32]{1,0} parameter(1)
1783+
convert = f8e4m3fn[32,32]{1,0} convert(parameter_0)
1784+
ROOT dot = f32[32,32]{1,0} dot(convert, parameter_1),
1785+
lhs_contracting_dims={1}, rhs_contracting_dims={1}
1786+
}
1787+
1788+
ENTRY main {
1789+
parameter_0 = f8e5m2[32,32]{1,0} parameter(0)
1790+
parameter_1 = f8e4m3fn[32,32]{1,0} parameter(1)
1791+
ROOT gemm_fusion_dot = f32[32,32]{1,0} fusion(parameter_0, parameter_1),
1792+
kind=kCustom, calls=triton_dot,
1793+
backend_config={
1794+
"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":
1795+
{"block_m":"32","block_n":"32","block_k":"32","split_k":"1",
1796+
"num_stages":"1","num_warps":"4","num_ctas":"1"}}}
1797+
})";
1798+
1799+
EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_text, ErrorSpec{/*aabs=*/1.0, /*arel=*/1e-3}));
1800+
}
1801+
17691802
TEST_F(TritonEmitterTest, SingleTileDotWithNestedFusionsIsEmittedCorrectly) {
17701803
// Simplest case when everything fits into one tile that is useful for
17711804
// debugging. This also tests support for empty nested fusions.

xla/backends/gpu/codegen/triton/fusion_emitter_legacy_matmul.cc

Lines changed: 8 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -227,130 +227,6 @@ bool IsFp8Type(Type t) {
227227
mlir::Float8E4M3B11FNUZType>(t);
228228
}
229229

230-
Value Cast(EmitterLocOpBuilder b, Value value, Type dst_element_ty) {
231-
Type src_ty = value.getType();
232-
Type src_element_ty = src_ty;
233-
Type fp16_ty = b.getF16Type();
234-
Type fp32_ty = b.getF32Type();
235-
Type dst_ty = dst_element_ty;
236-
if (auto src_shaped_ty = mlir::dyn_cast<ShapedType>(src_ty)) {
237-
src_element_ty = src_shaped_ty.getElementType();
238-
dst_ty = src_shaped_ty.clone(src_shaped_ty.getShape(), dst_element_ty);
239-
fp16_ty = src_shaped_ty.clone(src_shaped_ty.getShape(), b.getF16Type());
240-
fp32_ty = src_shaped_ty.clone(src_shaped_ty.getShape(), b.getF32Type());
241-
}
242-
if (src_ty == dst_ty) {
243-
return value;
244-
}
245-
246-
// All operations on bf16 are done through f32.
247-
if (src_element_ty.isBF16()) {
248-
return Cast(b, b.create<ma::ExtFOp>(fp32_ty, value), dst_element_ty);
249-
}
250-
if (dst_element_ty.isBF16()) {
251-
// S8 -> BF16 is directly supported and doesn't need to go through f32.
252-
if (!src_element_ty.isInteger(8)) {
253-
return b.create<ma::TruncFOp>(dst_ty, Cast(b, value, b.getF32Type()));
254-
}
255-
}
256-
257-
// float => float
258-
auto src_fp_element_ty = mlir::dyn_cast<mlir::FloatType>(src_element_ty);
259-
auto dst_fp_element_ty = mlir::dyn_cast<mlir::FloatType>(dst_element_ty);
260-
if (src_fp_element_ty && dst_fp_element_ty) {
261-
// F8 <-> FP16, BF16, FP32, FP64 need to be handled via Triton's tt.fp_to_fp
262-
// because LLVM doesn't support casts from/to FP8.
263-
// TODO(b/266862493): Add end-to-end test once FP8 support lands in XLA as
264-
// we can't test the code below without patching the feature.
265-
if (IsFp8Type(src_element_ty) && !IsFp8Type(dst_element_ty)) {
266-
return b.create<mt::FpToFpOp>(dst_ty, value);
267-
}
268-
if (IsFp8Type(dst_element_ty) && !IsFp8Type(src_element_ty)) {
269-
return b.create<mt::FpToFpOp>(
270-
dst_ty, value,
271-
mt::RoundingModeAttr::get(b.getContext(), mt::RoundingMode::RTNE));
272-
}
273-
if (IsFp8Type(src_element_ty) && IsFp8Type(dst_element_ty)) {
274-
// FP8 <-> FP8 conversion needs to go through FP16
275-
auto fp16_value = b.create<mt::FpToFpOp>(fp16_ty, value);
276-
return b.create<mt::FpToFpOp>(
277-
dst_ty, fp16_value,
278-
mt::RoundingModeAttr::get(b.getContext(), mt::RoundingMode::RTNE));
279-
}
280-
281-
if (src_fp_element_ty.getFPMantissaWidth() >
282-
dst_fp_element_ty.getFPMantissaWidth()) {
283-
return b.create<ma::TruncFOp>(dst_ty, value);
284-
} else {
285-
return b.create<ma::ExtFOp>(dst_ty, value);
286-
}
287-
}
288-
// int => int
289-
if (mlir::isa<mlir::IntegerType>(src_element_ty) &&
290-
mlir::isa<mlir::IntegerType>(dst_element_ty)) {
291-
if (src_element_ty.getIntOrFloatBitWidth() <
292-
dst_element_ty.getIntOrFloatBitWidth()) {
293-
if (src_element_ty.isInteger(1)) {
294-
return b.create<ma::ExtUIOp>(dst_ty, value);
295-
}
296-
return b.create<ma::ExtSIOp>(dst_ty, value);
297-
}
298-
return b.create<ma::TruncIOp>(dst_ty, value);
299-
}
300-
// int => float
301-
if (mlir::isa<mlir::IntegerType>(src_element_ty) && dst_fp_element_ty) {
302-
// TODO(b/266862493): Support unsigned integer types.
303-
if (src_element_ty.isInteger(1)) {
304-
return b.create<ma::UIToFPOp>(dst_ty, value);
305-
}
306-
return b.create<ma::SIToFPOp>(dst_ty, value);
307-
}
308-
// float => int
309-
if (src_fp_element_ty && mlir::isa<mlir::IntegerType>(dst_element_ty)) {
310-
if (dst_element_ty.isInteger(1)) {
311-
return b.create<ma::CmpFOp>(ma::CmpFPredicate::UNE, value,
312-
ZerosLike(b, value));
313-
}
314-
// TODO(b/266862493): Support unsigned integer types.
315-
// The current logic handles signed integer types only. Additional handling
316-
// is needed for unsigned integer types.
317-
auto cst_int = [&](EmitterLocOpBuilder b, int64_t x) {
318-
if (auto src_shaped_ty = mlir::dyn_cast<ShapedType>(src_ty)) {
319-
return CreateConst(b, dst_element_ty, x, src_shaped_ty.getShape());
320-
} else {
321-
return CreateConst(b, dst_element_ty, x);
322-
}
323-
};
324-
auto cst_float = [&](EmitterLocOpBuilder b, int64_t x) {
325-
if (auto src_shaped_ty = mlir::dyn_cast<ShapedType>(src_ty)) {
326-
return CreateConst(b, src_fp_element_ty, x, src_shaped_ty.getShape());
327-
} else {
328-
return CreateConst(b, src_fp_element_ty, x);
329-
}
330-
};
331-
auto fptosi = b.create<ma::FPToSIOp>(dst_ty, value);
332-
int64_t min = llvm::minIntN(dst_element_ty.getIntOrFloatBitWidth());
333-
int64_t max = llvm::maxIntN(dst_element_ty.getIntOrFloatBitWidth());
334-
335-
// value <= static_cast<float>(INT_MIN) ? INT_MIN : ...
336-
auto clamped = b.create<ma::SelectOp>(
337-
b.create<ma::CmpFOp>(ma::CmpFPredicate::OLE, value, cst_float(b, min)),
338-
cst_int(b, min), fptosi);
339-
// value >= static_cast<float>(INT_MAX) ? INT_MAX : ...
340-
clamped = b.create<ma::SelectOp>(
341-
b.create<ma::CmpFOp>(ma::CmpFPredicate::OGE, value, cst_float(b, max)),
342-
cst_int(b, max), clamped);
343-
// isnan(value) ? 0 : ...
344-
return b.create<ma::SelectOp>(
345-
b.create<ma::CmpFOp>(ma::CmpFPredicate::UNO, value, value),
346-
cst_int(b, 0), clamped);
347-
}
348-
349-
LOG(FATAL) << "Type conversion not supported: "
350-
<< llvm_ir::DumpToString(src_element_ty) << " -> "
351-
<< llvm_ir::DumpToString(dst_element_ty);
352-
}
353-
354230
Value Subtract(EmitterLocOpBuilder b, ValueRange values) {
355231
if (mlir::isa<mlir::IntegerType>(mlir::getElementTypeOrSelf(values[0]))) {
356232
return b.create<ma::SubIOp>(values[0], values[1]);
@@ -457,7 +333,7 @@ absl::StatusOr<Value> EmitElementwise(EmitterLocOpBuilder b,
457333
case HloOpcode::kConvert: {
458334
TF_ASSIGN_OR_RETURN(Type dst_ty,
459335
TritonType(b, hlo.shape().element_type()));
460-
return Cast(b, inputs[0], dst_ty);
336+
return triton::Cast(b, inputs[0], dst_ty);
461337
}
462338
case HloOpcode::kAdd:
463339
if (is_integer) {
@@ -670,7 +546,7 @@ absl::StatusOr<Value> EmitScope(
670546
if (hlo->opcode() == HloOpcode::kConvert &&
671547
hlo->operand(0)->shape().element_type() == S4) {
672548
Value unpacked;
673-
unpacked = Cast(b, values[hlo->operand(0)], b.getI8Type());
549+
unpacked = triton::Cast(b, values[hlo->operand(0)], b.getI8Type());
674550
std::vector<Value> operands({unpacked});
675551
TF_ASSIGN_OR_RETURN(result, EmitElementwise(b, libdevice_path,
676552
device_info, *hlo, operands));
@@ -826,7 +702,7 @@ ma::ConstantOp Cst64(EmitterLocOpBuilder b, int64_t v) {
826702
}
827703

828704
Value RoundToBF16(EmitterLocOpBuilder b, Value input) {
829-
return Cast(b, input, b.getBF16Type());
705+
return triton::Cast(b, input, b.getBF16Type());
830706
};
831707

832708
/*static*/ absl::StatusOr<MatMulDims> MatMulDims::Create(
@@ -1487,7 +1363,7 @@ class MatMulEmitterHelper {
14871363
"64 bit dynamic-slice indices are not supported yet.");
14881364
}
14891365
majormost_dim_start_index_val =
1490-
Cast(b, majormost_dim_start_index_val, b.getI32Type());
1366+
triton::Cast(b, majormost_dim_start_index_val, b.getI32Type());
14911367
majormost_dim_start_index_val =
14921368
b.create<ma::MaxSIOp>(majormost_dim_start_index_val, Cst32(b, 0));
14931369
majormost_dim_start_index_val =
@@ -2049,7 +1925,7 @@ class IterableInput {
20491925
Value param_value = EmitParameterLoad(b, args.front(), boundary_checks_);
20501926
if (type_ != storage_type_) {
20511927
// For example cast i8 to i1.
2052-
param_value = Cast(b, param_value, type_);
1928+
param_value = triton::Cast(b, param_value, type_);
20531929
}
20541930
return param_value;
20551931
}
@@ -2175,10 +2051,10 @@ Value EmitRegularMatmul(EmitterLocOpBuilder& b, Value lhs, Value rhs, Value acc,
21752051
if (dot_instr->precision_config().algorithm() ==
21762052
PrecisionConfig::ALG_DOT_BF16_BF16_F32) {
21772053
if (dot_instr->operand(0)->shape().element_type() == F32) {
2178-
lhs = Cast(b, lhs, b.getBF16Type());
2054+
lhs = triton::Cast(b, lhs, b.getBF16Type());
21792055
}
21802056
if (dot_instr->operand(1)->shape().element_type() == F32) {
2181-
rhs = Cast(b, rhs, b.getBF16Type());
2057+
rhs = triton::Cast(b, rhs, b.getBF16Type());
21822058
}
21832059
}
21842060

@@ -2372,7 +2248,7 @@ absl::StatusOr<std::optional<stream_executor::gpu::TmaMetadata>> EmitMatMul(
23722248
absl::flat_hash_map<const HloInstruction*, Value> values_out;
23732249
TF_ASSIGN_OR_RETURN(Type acc_final_ty,
23742250
TritonType(b, dot_instr->shape().element_type()));
2375-
values_out[dot_instr] = Cast(b, acc_final, acc_final_ty);
2251+
values_out[dot_instr] = triton::Cast(b, acc_final, acc_final_ty);
23762252

23772253
// Emit the output scope.
23782254
if (std::vector<const HloInstruction*> to_emit =

0 commit comments

Comments
 (0)