Skip to content

Commit 2c04cb2

Browse files
authored
[llvm] Support real function with single scalar return value (#4452)
* [llvm] Support real function with single scalar return value * add comment
1 parent f714872 commit 2c04cb2

File tree

6 files changed

+116
-97
lines changed

6 files changed

+116
-97
lines changed

taichi/codegen/codegen_llvm.cpp

Lines changed: 65 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1019,6 +1019,54 @@ void CodeGenLLVM::visit(RangeForStmt *for_stmt) {
10191019
create_naive_range_for(for_stmt);
10201020
}
10211021

1022+
llvm::Value *CodeGenLLVM::bitcast_from_u64(llvm::Value *val, DataType type) {
1023+
llvm::Type *dest_ty = nullptr;
1024+
TI_ASSERT(!type->is<PointerType>());
1025+
if (auto cit = type->cast<CustomIntType>()) {
1026+
if (cit->get_is_signed())
1027+
dest_ty = tlctx->get_data_type(PrimitiveType::i32);
1028+
else
1029+
dest_ty = tlctx->get_data_type(PrimitiveType::u32);
1030+
} else {
1031+
dest_ty = tlctx->get_data_type(type);
1032+
}
1033+
auto dest_bits = dest_ty->getPrimitiveSizeInBits();
1034+
if (dest_ty == llvm::Type::getHalfTy(*llvm_context)) {
1035+
// if dest_ty == half, CreateTrunc will only keep low 16bits of mantissa
1036+
// which doesn't mean anything.
1037+
// So we truncate to 32 bits first and then fptrunc to half if applicable
1038+
auto truncated =
1039+
builder->CreateTrunc(val, llvm::Type::getIntNTy(*llvm_context, 32));
1040+
auto casted = builder->CreateBitCast(truncated,
1041+
llvm::Type::getFloatTy(*llvm_context));
1042+
return builder->CreateFPTrunc(casted, llvm::Type::getHalfTy(*llvm_context));
1043+
} else {
1044+
auto truncated = builder->CreateTrunc(
1045+
val, llvm::Type::getIntNTy(*llvm_context, dest_bits));
1046+
1047+
return builder->CreateBitCast(truncated, dest_ty);
1048+
}
1049+
}
1050+
1051+
llvm::Value *CodeGenLLVM::bitcast_to_u64(llvm::Value *val, DataType type) {
1052+
auto intermediate_bits = 0;
1053+
if (auto cit = type->cast<CustomIntType>()) {
1054+
intermediate_bits = data_type_bits(cit->get_compute_type());
1055+
} else {
1056+
intermediate_bits = tlctx->get_data_type(type)->getPrimitiveSizeInBits();
1057+
}
1058+
llvm::Type *dest_ty = tlctx->get_data_type<int64>();
1059+
llvm::Type *intermediate_type = nullptr;
1060+
if (val->getType() == llvm::Type::getHalfTy(*llvm_context)) {
1061+
val = builder->CreateFPExt(val, tlctx->get_data_type<float>());
1062+
intermediate_type = tlctx->get_data_type<int32>();
1063+
} else {
1064+
intermediate_type = llvm::Type::getIntNTy(*llvm_context, intermediate_bits);
1065+
}
1066+
return builder->CreateZExt(builder->CreateBitCast(val, intermediate_type),
1067+
dest_ty);
1068+
}
1069+
10221070
void CodeGenLLVM::visit(ArgLoadStmt *stmt) {
10231071
auto raw_arg = call(builder.get(), "RuntimeContext_get_args", get_context(),
10241072
tlctx->get_constant(stmt->arg_id));
@@ -1029,32 +1077,7 @@ void CodeGenLLVM::visit(ArgLoadStmt *stmt) {
10291077
llvm::PointerType::get(tlctx->get_data_type(PrimitiveType::i32), 0);
10301078
llvm_val[stmt] = builder->CreateIntToPtr(raw_arg, dest_ty);
10311079
} else {
1032-
TI_ASSERT(!stmt->ret_type->is<PointerType>());
1033-
if (auto cit = stmt->ret_type->cast<CustomIntType>()) {
1034-
if (cit->get_is_signed())
1035-
dest_ty = tlctx->get_data_type(PrimitiveType::i32);
1036-
else
1037-
dest_ty = tlctx->get_data_type(PrimitiveType::u32);
1038-
} else {
1039-
dest_ty = tlctx->get_data_type(stmt->ret_type);
1040-
}
1041-
auto dest_bits = dest_ty->getPrimitiveSizeInBits();
1042-
if (dest_ty == llvm::Type::getHalfTy(*llvm_context)) {
1043-
// if dest_ty == half, CreateTrunc will only keep low 16bits of mantissa
1044-
// which doesn't mean anything.
1045-
// So we truncate to 32 bits first and then fptrunc to half if applicable
1046-
auto truncated = builder->CreateTrunc(
1047-
raw_arg, llvm::Type::getIntNTy(*llvm_context, 32));
1048-
auto casted = builder->CreateBitCast(
1049-
truncated, llvm::Type::getFloatTy(*llvm_context));
1050-
llvm_val[stmt] =
1051-
builder->CreateFPTrunc(casted, llvm::Type::getHalfTy(*llvm_context));
1052-
} else {
1053-
auto truncated = builder->CreateTrunc(
1054-
raw_arg, llvm::Type::getIntNTy(*llvm_context, dest_bits));
1055-
1056-
llvm_val[stmt] = builder->CreateBitCast(truncated, dest_ty);
1057-
}
1080+
llvm_val[stmt] = bitcast_from_u64(raw_arg, stmt->ret_type);
10581081
}
10591082
}
10601083

@@ -1067,27 +1090,10 @@ void CodeGenLLVM::visit(ReturnStmt *stmt) {
10671090
TI_ASSERT(stmt->values.size() <= taichi_max_num_ret_value);
10681091
int idx{0};
10691092
for (auto &value : stmt->values) {
1070-
auto intermediate_bits = 0;
1071-
if (auto cit = value->ret_type->cast<CustomIntType>()) {
1072-
intermediate_bits = data_type_bits(cit->get_compute_type());
1073-
} else {
1074-
intermediate_bits =
1075-
tlctx->get_data_type(value->ret_type)->getPrimitiveSizeInBits();
1076-
}
1077-
llvm::Type *dest_ty = tlctx->get_data_type<int64>();
1078-
llvm::Type *intermediate_type = nullptr;
1079-
if (llvm_val[value]->getType() == llvm::Type::getHalfTy(*llvm_context)) {
1080-
llvm_val[value] = builder->CreateFPExt(llvm_val[value],
1081-
tlctx->get_data_type<float>());
1082-
intermediate_type = tlctx->get_data_type<int32>();
1083-
} else {
1084-
intermediate_type =
1085-
llvm::Type::getIntNTy(*llvm_context, intermediate_bits);
1086-
}
1087-
auto extended = builder->CreateZExt(
1088-
builder->CreateBitCast(llvm_val[value], intermediate_type), dest_ty);
1089-
create_call("LLVMRuntime_store_result",
1090-
{get_runtime(), extended, tlctx->get_constant<int32>(idx++)});
1093+
create_call(
1094+
"RuntimeContext_store_result",
1095+
{get_context(), bitcast_to_u64(llvm_val[value], value->ret_type),
1096+
tlctx->get_constant<int32>(idx++)});
10911097
}
10921098
}
10931099
}
@@ -2387,17 +2393,22 @@ void CodeGenLLVM::visit(FuncCallStmt *stmt) {
23872393
auto *new_ctx = builder->CreateAlloca(get_runtime_type("RuntimeContext"));
23882394
call("RuntimeContext_set_runtime", new_ctx, get_runtime());
23892395
for (int i = 0; i < stmt->args.size(); i++) {
2390-
auto *original = llvm_val[stmt->args[i]];
2391-
int src_bits = original->getType()->getPrimitiveSizeInBits();
2392-
auto *cast = builder->CreateBitCast(
2393-
original, llvm::Type::getIntNTy(*llvm_context, src_bits));
23942396
auto *val =
2395-
builder->CreateZExt(cast, llvm::Type::getInt64Ty(*llvm_context));
2397+
bitcast_to_u64(llvm_val[stmt->args[i]], stmt->args[i]->ret_type);
23962398
call("RuntimeContext_set_args", new_ctx,
23972399
llvm::ConstantInt::get(*llvm_context, llvm::APInt(32, i, true)), val);
23982400
}
2399-
2400-
llvm_val[stmt] = create_call(llvm_func, {new_ctx});
2401+
llvm::Value *result_buffer = nullptr;
2402+
if (stmt->ret_type->is<PrimitiveType>() &&
2403+
!stmt->ret_type->is_primitive(PrimitiveTypeID::unknown)) {
2404+
result_buffer = builder->CreateAlloca(tlctx->get_data_type<uint64>());
2405+
call("RuntimeContext_set_result_buffer", new_ctx, result_buffer);
2406+
create_call(llvm_func, {new_ctx});
2407+
auto *ret_val_u64 = builder->CreateLoad(result_buffer);
2408+
llvm_val[stmt] = bitcast_from_u64(ret_val_u64, stmt->ret_type);
2409+
} else {
2410+
create_call(llvm_func, {new_ctx});
2411+
}
24012412
}
24022413

24032414
TLANG_NAMESPACE_END

taichi/codegen/codegen_llvm.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,9 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
382382

383383
void visit(FuncCallStmt *stmt) override;
384384

385+
llvm::Value *bitcast_from_u64(llvm::Value *val, DataType type);
386+
llvm::Value *bitcast_to_u64(llvm::Value *val, DataType type);
387+
385388
~CodeGenLLVM() override = default;
386389
};
387390

taichi/program/context.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@ struct RuntimeContext {
2424
int32 cpu_thread_id;
2525
// |is_device_allocation| is true iff args[i] is a DeviceAllocation*.
2626
bool is_device_allocation[taichi_max_num_args_total]{false};
27+
// We move the pointer of result buffer from LLVMRuntime to RuntimeContext
28+
// because each real function need a place to store its result, but
29+
// LLVMRuntime is shared among functions. So we moved the pointer to
30+
// RuntimeContext which each function have one.
31+
uint64 *result_buffer;
2732

2833
static constexpr size_t extra_args_size = sizeof(extra_args);
2934

taichi/program/kernel.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,7 @@ RuntimeContext &Kernel::LaunchContextBuilder::get_context() {
285285
ctx_->runtime = llvm_program_impl->get_llvm_runtime();
286286
}
287287
#endif
288+
ctx_->result_buffer = kernel_->program->result_buffer;
288289
return *ctx_;
289290
}
290291

taichi/runtime/llvm/runtime.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,7 @@ STRUCT_FIELD_ARRAY(PhysicalCoordinates, val);
348348

349349
STRUCT_FIELD_ARRAY(RuntimeContext, args);
350350
STRUCT_FIELD(RuntimeContext, runtime);
351+
STRUCT_FIELD(RuntimeContext, result_buffer)
351352

352353
int32 RuntimeContext_get_extra_args(RuntimeContext *ctx, int32 i, int32 j) {
353354
return ctx->extra_args[i][j];
@@ -696,8 +697,8 @@ struct NodeManager {
696697

697698
extern "C" {
698699

699-
void LLVMRuntime_store_result(LLVMRuntime *runtime, u64 ret, u32 idx) {
700-
runtime->set_result(taichi_result_buffer_ret_value_id + idx, ret);
700+
void RuntimeContext_store_result(RuntimeContext *ctx, u64 ret, u32 idx) {
701+
ctx->result_buffer[taichi_result_buffer_ret_value_id + idx] = ret;
701702
}
702703

703704
void LLVMRuntime_profiler_start(LLVMRuntime *runtime, Ptr kernel_name) {

tests/python/test_function.py

Lines changed: 39 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -22,47 +22,45 @@ def run():
2222
assert x[None] == 42
2323

2424

25-
# @test_utils.test(arch=[ti.cpu, ti.gpu])
26-
# def test_function_with_return():
27-
# x = ti.field(ti.i32, shape=())
28-
#
29-
# @ti.experimental.real_func
30-
# def foo(val: ti.i32) -> ti.i32:
31-
# x[None] += val
32-
# return val
33-
#
34-
# @ti.kernel
35-
# def run():
36-
# a = foo(40)
37-
# foo(2)
38-
# assert a == 40
39-
#
40-
# x[None] = 0
41-
# run()
42-
# assert x[None] == 42
43-
#
44-
#
45-
# @test_utils.test(arch=[ti.cpu, ti.gpu])
46-
# def test_call_expressions():
47-
# x = ti.field(ti.i32, shape=())
48-
#
49-
# @ti.experimental.real_func
50-
# def foo(val: ti.i32) -> ti.i32:
51-
# if x[None] > 10:
52-
# x[None] += 1
53-
# x[None] += val
54-
# return 0
55-
#
56-
# @ti.kernel
57-
# def run():
58-
# assert foo(15) == 0
59-
# assert foo(10) == 0
60-
#
61-
# x[None] = 0
62-
# run()
63-
# assert x[None] == 26
64-
#
65-
#
25+
@test_utils.test(arch=[ti.cpu, ti.gpu], debug=True)
26+
def test_function_with_return():
27+
x = ti.field(ti.i32, shape=())
28+
29+
@ti.experimental.real_func
30+
def foo(val: ti.i32) -> ti.i32:
31+
x[None] += val
32+
return val
33+
34+
@ti.kernel
35+
def run():
36+
a = foo(40)
37+
foo(2)
38+
assert a == 40
39+
40+
x[None] = 0
41+
run()
42+
assert x[None] == 42
43+
44+
45+
@test_utils.test(arch=[ti.cpu, ti.gpu])
46+
def test_call_expressions():
47+
x = ti.field(ti.i32, shape=())
48+
49+
@ti.experimental.real_func
50+
def foo(val: ti.i32) -> ti.i32:
51+
if x[None] > 10:
52+
x[None] += 1
53+
x[None] += val
54+
return 0
55+
56+
@ti.kernel
57+
def run():
58+
assert foo(15) == 0
59+
assert foo(10) == 0
60+
61+
x[None] = 0
62+
run()
63+
assert x[None] == 26
6664

6765

6866
@test_utils.test(arch=[ti.cpu, ti.cuda], debug=True)

0 commit comments

Comments
 (0)