@@ -227,130 +227,6 @@ bool IsFp8Type(Type t) {
227
227
mlir::Float8E4M3B11FNUZType>(t);
228
228
}
229
229
230
- Value Cast (EmitterLocOpBuilder b, Value value, Type dst_element_ty) {
231
- Type src_ty = value.getType ();
232
- Type src_element_ty = src_ty;
233
- Type fp16_ty = b.getF16Type ();
234
- Type fp32_ty = b.getF32Type ();
235
- Type dst_ty = dst_element_ty;
236
- if (auto src_shaped_ty = mlir::dyn_cast<ShapedType>(src_ty)) {
237
- src_element_ty = src_shaped_ty.getElementType ();
238
- dst_ty = src_shaped_ty.clone (src_shaped_ty.getShape (), dst_element_ty);
239
- fp16_ty = src_shaped_ty.clone (src_shaped_ty.getShape (), b.getF16Type ());
240
- fp32_ty = src_shaped_ty.clone (src_shaped_ty.getShape (), b.getF32Type ());
241
- }
242
- if (src_ty == dst_ty) {
243
- return value;
244
- }
245
-
246
- // All operations on bf16 are done through f32.
247
- if (src_element_ty.isBF16 ()) {
248
- return Cast (b, b.create <ma::ExtFOp>(fp32_ty, value), dst_element_ty);
249
- }
250
- if (dst_element_ty.isBF16 ()) {
251
- // S8 -> BF16 is directly supported and doesn't need to go through f32.
252
- if (!src_element_ty.isInteger (8 )) {
253
- return b.create <ma::TruncFOp>(dst_ty, Cast (b, value, b.getF32Type ()));
254
- }
255
- }
256
-
257
- // float => float
258
- auto src_fp_element_ty = mlir::dyn_cast<mlir::FloatType>(src_element_ty);
259
- auto dst_fp_element_ty = mlir::dyn_cast<mlir::FloatType>(dst_element_ty);
260
- if (src_fp_element_ty && dst_fp_element_ty) {
261
- // F8 <-> FP16, BF16, FP32, FP64 need to be handled via Triton's tt.fp_to_fp
262
- // because LLVM doesn't support casts from/to FP8.
263
- // TODO(b/266862493): Add end-to-end test once FP8 support lands in XLA as
264
- // we can't test the code below without patching the feature.
265
- if (IsFp8Type (src_element_ty) && !IsFp8Type (dst_element_ty)) {
266
- return b.create <mt::FpToFpOp>(dst_ty, value);
267
- }
268
- if (IsFp8Type (dst_element_ty) && !IsFp8Type (src_element_ty)) {
269
- return b.create <mt::FpToFpOp>(
270
- dst_ty, value,
271
- mt::RoundingModeAttr::get (b.getContext (), mt::RoundingMode::RTNE));
272
- }
273
- if (IsFp8Type (src_element_ty) && IsFp8Type (dst_element_ty)) {
274
- // FP8 <-> FP8 conversion needs to go through FP16
275
- auto fp16_value = b.create <mt::FpToFpOp>(fp16_ty, value);
276
- return b.create <mt::FpToFpOp>(
277
- dst_ty, fp16_value,
278
- mt::RoundingModeAttr::get (b.getContext (), mt::RoundingMode::RTNE));
279
- }
280
-
281
- if (src_fp_element_ty.getFPMantissaWidth () >
282
- dst_fp_element_ty.getFPMantissaWidth ()) {
283
- return b.create <ma::TruncFOp>(dst_ty, value);
284
- } else {
285
- return b.create <ma::ExtFOp>(dst_ty, value);
286
- }
287
- }
288
- // int => int
289
- if (mlir::isa<mlir::IntegerType>(src_element_ty) &&
290
- mlir::isa<mlir::IntegerType>(dst_element_ty)) {
291
- if (src_element_ty.getIntOrFloatBitWidth () <
292
- dst_element_ty.getIntOrFloatBitWidth ()) {
293
- if (src_element_ty.isInteger (1 )) {
294
- return b.create <ma::ExtUIOp>(dst_ty, value);
295
- }
296
- return b.create <ma::ExtSIOp>(dst_ty, value);
297
- }
298
- return b.create <ma::TruncIOp>(dst_ty, value);
299
- }
300
- // int => float
301
- if (mlir::isa<mlir::IntegerType>(src_element_ty) && dst_fp_element_ty) {
302
- // TODO(b/266862493): Support unsigned integer types.
303
- if (src_element_ty.isInteger (1 )) {
304
- return b.create <ma::UIToFPOp>(dst_ty, value);
305
- }
306
- return b.create <ma::SIToFPOp>(dst_ty, value);
307
- }
308
- // float => int
309
- if (src_fp_element_ty && mlir::isa<mlir::IntegerType>(dst_element_ty)) {
310
- if (dst_element_ty.isInteger (1 )) {
311
- return b.create <ma::CmpFOp>(ma::CmpFPredicate::UNE, value,
312
- ZerosLike (b, value));
313
- }
314
- // TODO(b/266862493): Support unsigned integer types.
315
- // The current logic handles signed integer types only. Additional handling
316
- // is needed for unsigned integer types.
317
- auto cst_int = [&](EmitterLocOpBuilder b, int64_t x) {
318
- if (auto src_shaped_ty = mlir::dyn_cast<ShapedType>(src_ty)) {
319
- return CreateConst (b, dst_element_ty, x, src_shaped_ty.getShape ());
320
- } else {
321
- return CreateConst (b, dst_element_ty, x);
322
- }
323
- };
324
- auto cst_float = [&](EmitterLocOpBuilder b, int64_t x) {
325
- if (auto src_shaped_ty = mlir::dyn_cast<ShapedType>(src_ty)) {
326
- return CreateConst (b, src_fp_element_ty, x, src_shaped_ty.getShape ());
327
- } else {
328
- return CreateConst (b, src_fp_element_ty, x);
329
- }
330
- };
331
- auto fptosi = b.create <ma::FPToSIOp>(dst_ty, value);
332
- int64_t min = llvm::minIntN (dst_element_ty.getIntOrFloatBitWidth ());
333
- int64_t max = llvm::maxIntN (dst_element_ty.getIntOrFloatBitWidth ());
334
-
335
- // value <= static_cast<float>(INT_MIN) ? INT_MIN : ...
336
- auto clamped = b.create <ma::SelectOp>(
337
- b.create <ma::CmpFOp>(ma::CmpFPredicate::OLE, value, cst_float (b, min)),
338
- cst_int (b, min), fptosi);
339
- // value >= static_cast<float>(INT_MAX) ? INT_MAX : ...
340
- clamped = b.create <ma::SelectOp>(
341
- b.create <ma::CmpFOp>(ma::CmpFPredicate::OGE, value, cst_float (b, max)),
342
- cst_int (b, max), clamped);
343
- // isnan(value) ? 0 : ...
344
- return b.create <ma::SelectOp>(
345
- b.create <ma::CmpFOp>(ma::CmpFPredicate::UNO, value, value),
346
- cst_int (b, 0 ), clamped);
347
- }
348
-
349
- LOG (FATAL) << " Type conversion not supported: "
350
- << llvm_ir::DumpToString (src_element_ty) << " -> "
351
- << llvm_ir::DumpToString (dst_element_ty);
352
- }
353
-
354
230
Value Subtract (EmitterLocOpBuilder b, ValueRange values) {
355
231
if (mlir::isa<mlir::IntegerType>(mlir::getElementTypeOrSelf (values[0 ]))) {
356
232
return b.create <ma::SubIOp>(values[0 ], values[1 ]);
@@ -457,7 +333,7 @@ absl::StatusOr<Value> EmitElementwise(EmitterLocOpBuilder b,
457
333
case HloOpcode::kConvert : {
458
334
TF_ASSIGN_OR_RETURN (Type dst_ty,
459
335
TritonType (b, hlo.shape ().element_type ()));
460
- return Cast (b, inputs[0 ], dst_ty);
336
+ return triton:: Cast (b, inputs[0 ], dst_ty);
461
337
}
462
338
case HloOpcode::kAdd :
463
339
if (is_integer) {
@@ -670,7 +546,7 @@ absl::StatusOr<Value> EmitScope(
670
546
if (hlo->opcode () == HloOpcode::kConvert &&
671
547
hlo->operand (0 )->shape ().element_type () == S4) {
672
548
Value unpacked;
673
- unpacked = Cast (b, values[hlo->operand (0 )], b.getI8Type ());
549
+ unpacked = triton:: Cast (b, values[hlo->operand (0 )], b.getI8Type ());
674
550
std::vector<Value> operands ({unpacked});
675
551
TF_ASSIGN_OR_RETURN (result, EmitElementwise (b, libdevice_path,
676
552
device_info, *hlo, operands));
@@ -826,7 +702,7 @@ ma::ConstantOp Cst64(EmitterLocOpBuilder b, int64_t v) {
826
702
}
827
703
828
704
Value RoundToBF16 (EmitterLocOpBuilder b, Value input) {
829
- return Cast (b, input, b.getBF16Type ());
705
+ return triton:: Cast (b, input, b.getBF16Type ());
830
706
};
831
707
832
708
/* static*/ absl::StatusOr<MatMulDims> MatMulDims::Create (
@@ -1487,7 +1363,7 @@ class MatMulEmitterHelper {
1487
1363
" 64 bit dynamic-slice indices are not supported yet." );
1488
1364
}
1489
1365
majormost_dim_start_index_val =
1490
- Cast (b, majormost_dim_start_index_val, b.getI32Type ());
1366
+ triton:: Cast (b, majormost_dim_start_index_val, b.getI32Type ());
1491
1367
majormost_dim_start_index_val =
1492
1368
b.create <ma::MaxSIOp>(majormost_dim_start_index_val, Cst32 (b, 0 ));
1493
1369
majormost_dim_start_index_val =
@@ -2049,7 +1925,7 @@ class IterableInput {
2049
1925
Value param_value = EmitParameterLoad (b, args.front (), boundary_checks_);
2050
1926
if (type_ != storage_type_) {
2051
1927
// For example cast i8 to i1.
2052
- param_value = Cast (b, param_value, type_);
1928
+ param_value = triton:: Cast (b, param_value, type_);
2053
1929
}
2054
1930
return param_value;
2055
1931
}
@@ -2175,10 +2051,10 @@ Value EmitRegularMatmul(EmitterLocOpBuilder& b, Value lhs, Value rhs, Value acc,
2175
2051
if (dot_instr->precision_config ().algorithm () ==
2176
2052
PrecisionConfig::ALG_DOT_BF16_BF16_F32) {
2177
2053
if (dot_instr->operand (0 )->shape ().element_type () == F32) {
2178
- lhs = Cast (b, lhs, b.getBF16Type ());
2054
+ lhs = triton:: Cast (b, lhs, b.getBF16Type ());
2179
2055
}
2180
2056
if (dot_instr->operand (1 )->shape ().element_type () == F32) {
2181
- rhs = Cast (b, rhs, b.getBF16Type ());
2057
+ rhs = triton:: Cast (b, rhs, b.getBF16Type ());
2182
2058
}
2183
2059
}
2184
2060
@@ -2372,7 +2248,7 @@ absl::StatusOr<std::optional<stream_executor::gpu::TmaMetadata>> EmitMatMul(
2372
2248
absl::flat_hash_map<const HloInstruction*, Value> values_out;
2373
2249
TF_ASSIGN_OR_RETURN (Type acc_final_ty,
2374
2250
TritonType (b, dot_instr->shape ().element_type ()));
2375
- values_out[dot_instr] = Cast (b, acc_final, acc_final_ty);
2251
+ values_out[dot_instr] = triton:: Cast (b, acc_final, acc_final_ty);
2376
2252
2377
2253
// Emit the output scope.
2378
2254
if (std::vector<const HloInstruction*> to_emit =
0 commit comments