@@ -1019,6 +1019,54 @@ void CodeGenLLVM::visit(RangeForStmt *for_stmt) {
10191019 create_naive_range_for (for_stmt);
10201020}
10211021
1022+ llvm::Value *CodeGenLLVM::bitcast_from_u64 (llvm::Value *val, DataType type) {
1023+ llvm::Type *dest_ty = nullptr ;
1024+ TI_ASSERT (!type->is <PointerType>());
1025+ if (auto cit = type->cast <CustomIntType>()) {
1026+ if (cit->get_is_signed ())
1027+ dest_ty = tlctx->get_data_type (PrimitiveType::i32 );
1028+ else
1029+ dest_ty = tlctx->get_data_type (PrimitiveType::u32 );
1030+ } else {
1031+ dest_ty = tlctx->get_data_type (type);
1032+ }
1033+ auto dest_bits = dest_ty->getPrimitiveSizeInBits ();
1034+ if (dest_ty == llvm::Type::getHalfTy (*llvm_context)) {
1035+ // if dest_ty == half, CreateTrunc will only keep low 16bits of mantissa
1036+ // which doesn't mean anything.
1037+ // So we truncate to 32 bits first and then fptrunc to half if applicable
1038+ auto truncated =
1039+ builder->CreateTrunc (val, llvm::Type::getIntNTy (*llvm_context, 32 ));
1040+ auto casted = builder->CreateBitCast (truncated,
1041+ llvm::Type::getFloatTy (*llvm_context));
1042+ return builder->CreateFPTrunc (casted, llvm::Type::getHalfTy (*llvm_context));
1043+ } else {
1044+ auto truncated = builder->CreateTrunc (
1045+ val, llvm::Type::getIntNTy (*llvm_context, dest_bits));
1046+
1047+ return builder->CreateBitCast (truncated, dest_ty);
1048+ }
1049+ }
1050+
1051+ llvm::Value *CodeGenLLVM::bitcast_to_u64 (llvm::Value *val, DataType type) {
1052+ auto intermediate_bits = 0 ;
1053+ if (auto cit = type->cast <CustomIntType>()) {
1054+ intermediate_bits = data_type_bits (cit->get_compute_type ());
1055+ } else {
1056+ intermediate_bits = tlctx->get_data_type (type)->getPrimitiveSizeInBits ();
1057+ }
1058+ llvm::Type *dest_ty = tlctx->get_data_type <int64>();
1059+ llvm::Type *intermediate_type = nullptr ;
1060+ if (val->getType () == llvm::Type::getHalfTy (*llvm_context)) {
1061+ val = builder->CreateFPExt (val, tlctx->get_data_type <float >());
1062+ intermediate_type = tlctx->get_data_type <int32>();
1063+ } else {
1064+ intermediate_type = llvm::Type::getIntNTy (*llvm_context, intermediate_bits);
1065+ }
1066+ return builder->CreateZExt (builder->CreateBitCast (val, intermediate_type),
1067+ dest_ty);
1068+ }
1069+
10221070void CodeGenLLVM::visit (ArgLoadStmt *stmt) {
10231071 auto raw_arg = call (builder.get (), " RuntimeContext_get_args" , get_context (),
10241072 tlctx->get_constant (stmt->arg_id ));
@@ -1029,32 +1077,7 @@ void CodeGenLLVM::visit(ArgLoadStmt *stmt) {
10291077 llvm::PointerType::get (tlctx->get_data_type (PrimitiveType::i32 ), 0 );
10301078 llvm_val[stmt] = builder->CreateIntToPtr (raw_arg, dest_ty);
10311079 } else {
1032- TI_ASSERT (!stmt->ret_type ->is <PointerType>());
1033- if (auto cit = stmt->ret_type ->cast <CustomIntType>()) {
1034- if (cit->get_is_signed ())
1035- dest_ty = tlctx->get_data_type (PrimitiveType::i32 );
1036- else
1037- dest_ty = tlctx->get_data_type (PrimitiveType::u32 );
1038- } else {
1039- dest_ty = tlctx->get_data_type (stmt->ret_type );
1040- }
1041- auto dest_bits = dest_ty->getPrimitiveSizeInBits ();
1042- if (dest_ty == llvm::Type::getHalfTy (*llvm_context)) {
1043- // if dest_ty == half, CreateTrunc will only keep low 16bits of mantissa
1044- // which doesn't mean anything.
1045- // So we truncate to 32 bits first and then fptrunc to half if applicable
1046- auto truncated = builder->CreateTrunc (
1047- raw_arg, llvm::Type::getIntNTy (*llvm_context, 32 ));
1048- auto casted = builder->CreateBitCast (
1049- truncated, llvm::Type::getFloatTy (*llvm_context));
1050- llvm_val[stmt] =
1051- builder->CreateFPTrunc (casted, llvm::Type::getHalfTy (*llvm_context));
1052- } else {
1053- auto truncated = builder->CreateTrunc (
1054- raw_arg, llvm::Type::getIntNTy (*llvm_context, dest_bits));
1055-
1056- llvm_val[stmt] = builder->CreateBitCast (truncated, dest_ty);
1057- }
1080+ llvm_val[stmt] = bitcast_from_u64 (raw_arg, stmt->ret_type );
10581081 }
10591082}
10601083
@@ -1067,27 +1090,10 @@ void CodeGenLLVM::visit(ReturnStmt *stmt) {
10671090 TI_ASSERT (stmt->values .size () <= taichi_max_num_ret_value);
10681091 int idx{0 };
10691092 for (auto &value : stmt->values ) {
1070- auto intermediate_bits = 0 ;
1071- if (auto cit = value->ret_type ->cast <CustomIntType>()) {
1072- intermediate_bits = data_type_bits (cit->get_compute_type ());
1073- } else {
1074- intermediate_bits =
1075- tlctx->get_data_type (value->ret_type )->getPrimitiveSizeInBits ();
1076- }
1077- llvm::Type *dest_ty = tlctx->get_data_type <int64>();
1078- llvm::Type *intermediate_type = nullptr ;
1079- if (llvm_val[value]->getType () == llvm::Type::getHalfTy (*llvm_context)) {
1080- llvm_val[value] = builder->CreateFPExt (llvm_val[value],
1081- tlctx->get_data_type <float >());
1082- intermediate_type = tlctx->get_data_type <int32>();
1083- } else {
1084- intermediate_type =
1085- llvm::Type::getIntNTy (*llvm_context, intermediate_bits);
1086- }
1087- auto extended = builder->CreateZExt (
1088- builder->CreateBitCast (llvm_val[value], intermediate_type), dest_ty);
1089- create_call (" LLVMRuntime_store_result" ,
1090- {get_runtime (), extended, tlctx->get_constant <int32>(idx++)});
1093+ create_call (
1094+ " RuntimeContext_store_result" ,
1095+ {get_context (), bitcast_to_u64 (llvm_val[value], value->ret_type ),
1096+ tlctx->get_constant <int32>(idx++)});
10911097 }
10921098 }
10931099}
@@ -2387,17 +2393,22 @@ void CodeGenLLVM::visit(FuncCallStmt *stmt) {
23872393 auto *new_ctx = builder->CreateAlloca (get_runtime_type (" RuntimeContext" ));
23882394 call (" RuntimeContext_set_runtime" , new_ctx, get_runtime ());
23892395 for (int i = 0 ; i < stmt->args .size (); i++) {
2390- auto *original = llvm_val[stmt->args [i]];
2391- int src_bits = original->getType ()->getPrimitiveSizeInBits ();
2392- auto *cast = builder->CreateBitCast (
2393- original, llvm::Type::getIntNTy (*llvm_context, src_bits));
23942396 auto *val =
2395- builder-> CreateZExt (cast, llvm::Type::getInt64Ty (*llvm_context) );
2397+ bitcast_to_u64 (llvm_val[stmt-> args [i]], stmt-> args [i]-> ret_type );
23962398 call (" RuntimeContext_set_args" , new_ctx,
23972399 llvm::ConstantInt::get (*llvm_context, llvm::APInt (32 , i, true )), val);
23982400 }
2399-
2400- llvm_val[stmt] = create_call (llvm_func, {new_ctx});
2401+ llvm::Value *result_buffer = nullptr ;
2402+ if (stmt->ret_type ->is <PrimitiveType>() &&
2403+ !stmt->ret_type ->is_primitive (PrimitiveTypeID::unknown)) {
2404+ result_buffer = builder->CreateAlloca (tlctx->get_data_type <uint64>());
2405+ call (" RuntimeContext_set_result_buffer" , new_ctx, result_buffer);
2406+ create_call (llvm_func, {new_ctx});
2407+ auto *ret_val_u64 = builder->CreateLoad (result_buffer);
2408+ llvm_val[stmt] = bitcast_from_u64 (ret_val_u64, stmt->ret_type );
2409+ } else {
2410+ create_call (llvm_func, {new_ctx});
2411+ }
24012412}
24022413
24032414TLANG_NAMESPACE_END
0 commit comments