@@ -227,121 +227,6 @@ bool IsFp8Type(Type t) {
227
227
mlir::Float8E4M3B11FNUZType>(t);
228
228
}
229
229
230
- Value Cast (EmitterLocOpBuilder b, Value value, Type dst_element_ty) {
231
- Type src_ty = value.getType ();
232
- Type src_element_ty = src_ty;
233
- Type fp32_ty = b.getF32Type ();
234
- Type dst_ty = dst_element_ty;
235
- if (auto src_shaped_ty = mlir::dyn_cast<ShapedType>(src_ty)) {
236
- src_element_ty = src_shaped_ty.getElementType ();
237
- dst_ty = src_shaped_ty.clone (src_shaped_ty.getShape (), dst_element_ty);
238
- fp32_ty = src_shaped_ty.clone (src_shaped_ty.getShape (), b.getF32Type ());
239
- }
240
- if (src_ty == dst_ty) {
241
- return value;
242
- }
243
-
244
- // All operations on bf16 are done through f32.
245
- if (src_element_ty.isBF16 ()) {
246
- return Cast (b, b.create <ma::ExtFOp>(fp32_ty, value), dst_element_ty);
247
- }
248
- if (dst_element_ty.isBF16 ()) {
249
- // S8 -> BF16 is directly supported and doesn't need to go through f32.
250
- if (!src_element_ty.isInteger (8 )) {
251
- return b.create <ma::TruncFOp>(dst_ty, Cast (b, value, b.getF32Type ()));
252
- }
253
- }
254
-
255
- // float => float
256
- auto src_fp_element_ty = mlir::dyn_cast<mlir::FloatType>(src_element_ty);
257
- auto dst_fp_element_ty = mlir::dyn_cast<mlir::FloatType>(dst_element_ty);
258
- if (src_fp_element_ty && dst_fp_element_ty) {
259
- // F8 <-> FP16, BF16, FP32, FP64 need to be handled via Triton's tt.fp_to_fp
260
- // because LLVM doesn't support casts from/to FP8.
261
- // TODO(b/266862493): Add end-to-end test once FP8 support lands in XLA as
262
- // we can't test the code below without patching the feature.
263
- if (IsFp8Type (src_element_ty)) {
264
- return b.create <mt::FpToFpOp>(dst_ty, value);
265
- }
266
- if (IsFp8Type (dst_element_ty)) {
267
- return b.create <mt::FpToFpOp>(
268
- dst_ty, value,
269
- mt::RoundingModeAttr::get (b.getContext (), mt::RoundingMode::RTNE));
270
- }
271
-
272
- if (src_fp_element_ty.getFPMantissaWidth () >
273
- dst_fp_element_ty.getFPMantissaWidth ()) {
274
- return b.create <ma::TruncFOp>(dst_ty, value);
275
- } else {
276
- return b.create <ma::ExtFOp>(dst_ty, value);
277
- }
278
- }
279
- // int => int
280
- if (mlir::isa<mlir::IntegerType>(src_element_ty) &&
281
- mlir::isa<mlir::IntegerType>(dst_element_ty)) {
282
- if (src_element_ty.getIntOrFloatBitWidth () <
283
- dst_element_ty.getIntOrFloatBitWidth ()) {
284
- if (src_element_ty.isInteger (1 )) {
285
- return b.create <ma::ExtUIOp>(dst_ty, value);
286
- }
287
- return b.create <ma::ExtSIOp>(dst_ty, value);
288
- }
289
- return b.create <ma::TruncIOp>(dst_ty, value);
290
- }
291
- // int => float
292
- if (mlir::isa<mlir::IntegerType>(src_element_ty) && dst_fp_element_ty) {
293
- // TODO(b/266862493): Support unsigned integer types.
294
- if (src_element_ty.isInteger (1 )) {
295
- return b.create <ma::UIToFPOp>(dst_ty, value);
296
- }
297
- return b.create <ma::SIToFPOp>(dst_ty, value);
298
- }
299
- // float => int
300
- if (src_fp_element_ty && mlir::isa<mlir::IntegerType>(dst_element_ty)) {
301
- if (dst_element_ty.isInteger (1 )) {
302
- return b.create <ma::CmpFOp>(ma::CmpFPredicate::UNE, value,
303
- ZerosLike (b, value));
304
- }
305
- // TODO(b/266862493): Support unsigned integer types.
306
- // The current logic handles signed integer types only. Additional handling
307
- // is needed for unsigned integer types.
308
- auto cst_int = [&](EmitterLocOpBuilder b, int64_t x) {
309
- if (auto src_shaped_ty = mlir::dyn_cast<ShapedType>(src_ty)) {
310
- return CreateConst (b, dst_element_ty, x, src_shaped_ty.getShape ());
311
- } else {
312
- return CreateConst (b, dst_element_ty, x);
313
- }
314
- };
315
- auto cst_float = [&](EmitterLocOpBuilder b, int64_t x) {
316
- if (auto src_shaped_ty = mlir::dyn_cast<ShapedType>(src_ty)) {
317
- return CreateConst (b, src_fp_element_ty, x, src_shaped_ty.getShape ());
318
- } else {
319
- return CreateConst (b, src_fp_element_ty, x);
320
- }
321
- };
322
- auto fptosi = b.create <ma::FPToSIOp>(dst_ty, value);
323
- int64_t min = llvm::minIntN (dst_element_ty.getIntOrFloatBitWidth ());
324
- int64_t max = llvm::maxIntN (dst_element_ty.getIntOrFloatBitWidth ());
325
-
326
- // value <= static_cast<float>(INT_MIN) ? INT_MIN : ...
327
- auto clamped = b.create <ma::SelectOp>(
328
- b.create <ma::CmpFOp>(ma::CmpFPredicate::OLE, value, cst_float (b, min)),
329
- cst_int (b, min), fptosi);
330
- // value >= static_cast<float>(INT_MAX) ? INT_MAX : ...
331
- clamped = b.create <ma::SelectOp>(
332
- b.create <ma::CmpFOp>(ma::CmpFPredicate::OGE, value, cst_float (b, max)),
333
- cst_int (b, max), clamped);
334
- // isnan(value) ? 0 : ...
335
- return b.create <ma::SelectOp>(
336
- b.create <ma::CmpFOp>(ma::CmpFPredicate::UNO, value, value),
337
- cst_int (b, 0 ), clamped);
338
- }
339
-
340
- LOG (FATAL) << " Type conversion not supported: "
341
- << llvm_ir::DumpToString (src_element_ty) << " -> "
342
- << llvm_ir::DumpToString (dst_element_ty);
343
- }
344
-
345
230
Value Subtract (EmitterLocOpBuilder b, ValueRange values) {
346
231
if (mlir::isa<mlir::IntegerType>(mlir::getElementTypeOrSelf (values[0 ]))) {
347
232
return b.create <ma::SubIOp>(values[0 ], values[1 ]);
@@ -448,7 +333,7 @@ absl::StatusOr<Value> EmitElementwise(EmitterLocOpBuilder b,
448
333
case HloOpcode::kConvert : {
449
334
TF_ASSIGN_OR_RETURN (Type dst_ty,
450
335
TritonType (b, hlo.shape ().element_type ()));
451
- return Cast (b, inputs[0 ], dst_ty);
336
+ return triton:: Cast (b, inputs[0 ], dst_ty);
452
337
}
453
338
case HloOpcode::kAdd :
454
339
if (is_integer) {
@@ -661,7 +546,7 @@ absl::StatusOr<Value> EmitScope(
661
546
if (hlo->opcode () == HloOpcode::kConvert &&
662
547
hlo->operand (0 )->shape ().element_type () == S4) {
663
548
Value unpacked;
664
- unpacked = Cast (b, values[hlo->operand (0 )], b.getI8Type ());
549
+ unpacked = triton:: Cast (b, values[hlo->operand (0 )], b.getI8Type ());
665
550
std::vector<Value> operands ({unpacked});
666
551
TF_ASSIGN_OR_RETURN (result, EmitElementwise (b, libdevice_path,
667
552
device_info, *hlo, operands));
@@ -817,7 +702,7 @@ ma::ConstantOp Cst64(EmitterLocOpBuilder b, int64_t v) {
817
702
}
818
703
819
704
// Converts `input` to the bf16 element type by delegating to triton::Cast.
Value RoundToBF16(EmitterLocOpBuilder b, Value input) {
  return triton::Cast(b, input, b.getBF16Type());
}
822
707
823
708
/* static*/ absl::StatusOr<MatMulDims> MatMulDims::Create (
@@ -1480,7 +1365,7 @@ class MatMulEmitterHelper {
1480
1365
" 64 bit dynamic-slice indices are not supported yet." );
1481
1366
}
1482
1367
majormost_dim_start_index_val =
1483
- Cast (b, majormost_dim_start_index_val, b.getI32Type ());
1368
+ triton:: Cast (b, majormost_dim_start_index_val, b.getI32Type ());
1484
1369
majormost_dim_start_index_val =
1485
1370
b.create <ma::MaxSIOp>(majormost_dim_start_index_val, Cst32 (b, 0 ));
1486
1371
majormost_dim_start_index_val =
@@ -2041,7 +1926,7 @@ class IterableInput {
2041
1926
Value param_value = EmitParameterLoad (b, args.front (), boundary_checks_);
2042
1927
if (type_ != storage_type_) {
2043
1928
// For example cast i8 to i1.
2044
- param_value = Cast (b, param_value, type_);
1929
+ param_value = triton:: Cast (b, param_value, type_);
2045
1930
}
2046
1931
return param_value;
2047
1932
}
@@ -2167,10 +2052,10 @@ Value EmitRegularMatmul(EmitterLocOpBuilder& b, Value lhs, Value rhs, Value acc,
2167
2052
if (dot_instr->precision_config ().algorithm () ==
2168
2053
PrecisionConfig::ALG_DOT_BF16_BF16_F32) {
2169
2054
if (dot_instr->operand (0 )->shape ().element_type () == F32) {
2170
- lhs = Cast (b, lhs, b.getBF16Type ());
2055
+ lhs = triton:: Cast (b, lhs, b.getBF16Type ());
2171
2056
}
2172
2057
if (dot_instr->operand (1 )->shape ().element_type () == F32) {
2173
- rhs = Cast (b, rhs, b.getBF16Type ());
2058
+ rhs = triton:: Cast (b, rhs, b.getBF16Type ());
2174
2059
}
2175
2060
}
2176
2061
@@ -2364,7 +2249,7 @@ absl::StatusOr<std::optional<stream_executor::gpu::TmaMetadata>> EmitMatMul(
2364
2249
absl::flat_hash_map<const HloInstruction*, Value> values_out;
2365
2250
TF_ASSIGN_OR_RETURN (Type acc_final_ty,
2366
2251
TritonType (b, dot_instr->shape ().element_type ()));
2367
- values_out[dot_instr] = Cast (b, acc_final, acc_final_ty);
2252
+ values_out[dot_instr] = triton:: Cast (b, acc_final, acc_final_ty);
2368
2253
2369
2254
// Emit the output scope.
2370
2255
if (std::vector<const HloInstruction*> to_emit =
0 commit comments