Skip to content

Commit

Permalink
Model function calls as initializing expressions (#3089)
Browse files Browse the repository at this point in the history
Start treating function calls as initializing expressions instead of as
value expressions.

This required adding support for expression categories. Value bindings
and temporary materialization conversions are created where necessary to
transition between expression categories. For a function call with a
return slot, we speculatively create a materialized temporary before the
call and either commit to it or replace it with something else later,
once we see how the function call expression is actually used.

This change follows the direction suggested in #3133 for initializing
expressions: depending on the return type of a function, the return
value will either be initialized in-place or returned directly. This is
visible in the semantics IR, which is a little unfortunate but is
probably necessary as this is part of the semantics of the program.

---------

Co-authored-by: Chandler Carruth <[email protected]>
  • Loading branch information
zygoloid and chandlerc committed Aug 24, 2023
1 parent f790a27 commit 1013d17
Show file tree
Hide file tree
Showing 101 changed files with 1,749 additions and 724 deletions.
55 changes: 43 additions & 12 deletions toolchain/lowering/lowering_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,39 +57,63 @@ auto LoweringContext::Run() -> std::unique_ptr<llvm::Module> {

auto LoweringContext::BuildFunctionDeclaration(SemIR::FunctionId function_id)
-> llvm::Function* {
auto function = semantics_ir().GetFunction(function_id);
const auto& function = semantics_ir().GetFunction(function_id);
const bool has_return_slot = function.return_slot_id.is_valid();
const int first_param = has_return_slot ? 1 : 0;

SemIR::InitializingRepresentation return_rep =
function.return_type_id.is_valid()
? SemIR::GetInitializingRepresentation(semantics_ir(),
function.return_type_id)
: SemIR::InitializingRepresentation{
.kind = SemIR::InitializingRepresentation::None};
CARBON_CHECK(return_rep.has_return_slot() == has_return_slot);

// TODO: Lower type information for the arguments prior to building args.
auto param_refs = semantics_ir().GetNodeBlock(function.param_refs_id);
llvm::SmallVector<llvm::Type*> args;
args.resize_for_overwrite(param_refs.size());
args.resize_for_overwrite(first_param + param_refs.size());
if (has_return_slot) {
args[0] = GetType(function.return_type_id)->getPointerTo();
}
for (auto [i, param_ref] : llvm::enumerate(param_refs)) {
args[i] = GetType(semantics_ir().GetNode(param_ref).type_id());
args[first_param + i] =
GetType(semantics_ir().GetNode(param_ref).type_id());
}

// If return type is not valid, the function does not have a return type.
// Hence, set return type to void.
llvm::Type* return_type = function.return_type_id.is_valid()
? GetType(function.return_type_id)
: llvm::Type::getVoidTy(llvm_context());
// If the initializing representation doesn't produce a value, set the return
// type to void.
llvm::Type* return_type =
return_rep.kind == SemIR::InitializingRepresentation::ByCopy
? GetType(function.return_type_id)
: llvm::Type::getVoidTy(llvm_context());

llvm::FunctionType* function_type =
llvm::FunctionType::get(return_type, args, /*isVarArg=*/false);
auto* llvm_function = llvm::Function::Create(
function_type, llvm::Function::ExternalLinkage,
semantics_ir().GetString(function.name_id), llvm_module());

if (has_return_slot) {
auto* return_slot = llvm_function->getArg(0);
return_slot->addAttr(llvm::Attribute::getWithStructRetType(
llvm_context(), GetType(function.return_type_id)));
return_slot->setName("return");
}

// Set parameter names.
for (auto [i, param_ref] : llvm::enumerate(param_refs)) {
auto name_id = semantics_ir().GetNode(param_ref).GetAsParameter();
llvm_function->getArg(i)->setName(semantics_ir().GetString(name_id));
llvm_function->getArg(first_param + i)
->setName(semantics_ir().GetString(name_id));
}

return llvm_function;
}

auto LoweringContext::BuildFunctionDefinition(SemIR::FunctionId function_id)
-> void {
auto function = semantics_ir().GetFunction(function_id);
const auto& function = semantics_ir().GetFunction(function_id);
const auto& body_block_ids = function.body_block_ids;
if (body_block_ids.empty()) {
// Function is probably defined in another file; not an error.
Expand All @@ -99,14 +123,21 @@ auto LoweringContext::BuildFunctionDefinition(SemIR::FunctionId function_id)
llvm::Function* llvm_function = GetFunction(function_id);
LoweringFunctionContext function_lowering(*this, llvm_function);

const bool has_return_slot = function.return_slot_id.is_valid();
const int first_param = has_return_slot ? 1 : 0;

// Add parameters to locals.
auto param_refs = semantics_ir().GetNodeBlock(function.param_refs_id);
if (has_return_slot) {
function_lowering.SetLocal(function.return_slot_id,
llvm_function->getArg(0));
}
for (auto [i, param_ref] : llvm::enumerate(param_refs)) {
function_lowering.SetLocal(param_ref, llvm_function->getArg(i));
function_lowering.SetLocal(param_ref,
llvm_function->getArg(first_param + i));
}

// Lower all blocks.
// TODO: Determine the set of reachable blocks, and only lower those ones.
for (auto block_id : body_block_ids) {
CARBON_VLOG() << "Lowering " << block_id << "\n";
auto* llvm_block = function_lowering.GetBlock(block_id);
Expand Down
39 changes: 28 additions & 11 deletions toolchain/lowering/lowering_handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,21 +172,32 @@ auto LoweringHandleBuiltin(LoweringFunctionContext& /*context*/,
auto LoweringHandleCall(LoweringFunctionContext& context, SemIR::NodeId node_id,
SemIR::Node node) -> void {
auto [refs_id, function_id] = node.GetAsCall();
auto* function = context.GetFunction(function_id);
auto* llvm_function = context.GetFunction(function_id);
const auto& function = context.semantics_ir().GetFunction(function_id);

std::vector<llvm::Value*> args;
for (auto ref_id : context.semantics_ir().GetNodeBlock(refs_id)) {
llvm::ArrayRef<SemIR::NodeId> arg_ids =
context.semantics_ir().GetNodeBlock(refs_id);

if (function.return_slot_id.is_valid()) {
args.push_back(context.GetLocal(arg_ids.back()));
arg_ids = arg_ids.drop_back();
}

for (auto ref_id : arg_ids) {
args.push_back(context.GetLocalLoaded(ref_id));
}
if (function->getReturnType()->isVoidTy()) {
context.builder().CreateCall(function, args);
// TODO: use empty tuple type.
// TODO: don't create the empty tuple if the call does not get assigned.
context.SetLocal(node_id, context.builder().CreateAlloca(
llvm::StructType::get(context.llvm_context()),
/*ArraySize=*/nullptr, "call.result"));

if (llvm_function->getReturnType()->isVoidTy()) {
context.builder().CreateCall(llvm_function, args);
// TODO: A function with a void return type shouldn't be referred to by
// other nodes.
context.SetLocal(node_id,
llvm::UndefValue::get(context.GetType(node.type_id())));
} else {
context.SetLocal(node_id, context.builder().CreateCall(
function, args, function->getName()));
context.SetLocal(node_id,
context.builder().CreateCall(llvm_function, args,
llvm_function->getName()));
}
}

Expand Down Expand Up @@ -222,6 +233,12 @@ auto LoweringHandleNamespace(LoweringFunctionContext& /*context*/,
// No action to take.
}

auto LoweringHandleNoOp(LoweringFunctionContext& /*context*/,
SemIR::NodeId /*node_id*/, SemIR::Node /*node*/)
-> void {
// No action to take.
}

auto LoweringHandleParameter(LoweringFunctionContext& /*context*/,
SemIR::NodeId /*node_id*/, SemIR::Node /*node*/)
-> void {
Expand Down
22 changes: 22 additions & 0 deletions toolchain/lowering/lowering_handle_expression_category.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
// Exceptions. See /LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "toolchain/lowering/lowering_function_context.h"

namespace Carbon {

auto LoweringHandleBindValue(LoweringFunctionContext& context,
SemIR::NodeId node_id, SemIR::Node node) -> void {
context.SetLocal(node_id, context.GetLocalLoaded(node.GetAsBindValue()));
}

auto LoweringHandleMaterializeTemporary(LoweringFunctionContext& context,
SemIR::NodeId node_id, SemIR::Node node)
-> void {
context.SetLocal(
node_id, context.builder().CreateAlloca(context.GetType(node.type_id()),
nullptr, "temp"));
}

} // namespace Carbon
25 changes: 14 additions & 11 deletions toolchain/lowering/testdata/array/assign_return_value.carbon
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,30 @@ fn Run() {
// CHECK:STDOUT: ; ModuleID = 'assign_return_value.carbon'
// CHECK:STDOUT: source_filename = "assign_return_value.carbon"
// CHECK:STDOUT:
// CHECK:STDOUT: define { i32, i32 } @F() {
// CHECK:STDOUT: define void @F(ptr sret({ i32, i32 }) %return) {
// CHECK:STDOUT: %tuple = alloca { i32, i32 }, align 8
// CHECK:STDOUT: %1 = getelementptr inbounds { i32, i32 }, ptr %tuple, i32 0, i32 0
// CHECK:STDOUT: store i32 12, ptr %1, align 4
// CHECK:STDOUT: %2 = getelementptr inbounds { i32, i32 }, ptr %tuple, i32 0, i32 1
// CHECK:STDOUT: store i32 24, ptr %2, align 4
// CHECK:STDOUT: %3 = load { i32, i32 }, ptr %tuple, align 4
// CHECK:STDOUT: ret { i32, i32 } %3
// CHECK:STDOUT: store { i32, i32 } %3, ptr %return, align 4
// CHECK:STDOUT: ret void
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: define void @Run() {
// CHECK:STDOUT: %t = alloca [2 x i32], align 4
// CHECK:STDOUT: %F = call { i32, i32 } @F()
// CHECK:STDOUT: %temp = alloca { i32, i32 }, align 8
// CHECK:STDOUT: call void @F(ptr %temp)
// CHECK:STDOUT: %1 = load { i32, i32 }, ptr %temp, align 4
// CHECK:STDOUT: %array = alloca [2 x i32], align 4
// CHECK:STDOUT: %array.element = extractvalue { i32, i32 } %F, 0
// CHECK:STDOUT: %1 = getelementptr inbounds [2 x i32], ptr %array, i32 0, i32 0
// CHECK:STDOUT: store i32 %array.element, ptr %1, align 4
// CHECK:STDOUT: %array.element1 = extractvalue { i32, i32 } %F, 1
// CHECK:STDOUT: %2 = getelementptr inbounds [2 x i32], ptr %array, i32 0, i32 1
// CHECK:STDOUT: store i32 %array.element1, ptr %2, align 4
// CHECK:STDOUT: %3 = load [2 x i32], ptr %array, align 4
// CHECK:STDOUT: store [2 x i32] %3, ptr %t, align 4
// CHECK:STDOUT: %array.element = extractvalue { i32, i32 } %1, 0
// CHECK:STDOUT: %2 = getelementptr inbounds [2 x i32], ptr %array, i32 0, i32 0
// CHECK:STDOUT: store i32 %array.element, ptr %2, align 4
// CHECK:STDOUT: %array.element1 = extractvalue { i32, i32 } %1, 1
// CHECK:STDOUT: %3 = getelementptr inbounds [2 x i32], ptr %array, i32 0, i32 1
// CHECK:STDOUT: store i32 %array.element1, ptr %3, align 4
// CHECK:STDOUT: %4 = load [2 x i32], ptr %array, align 4
// CHECK:STDOUT: store [2 x i32] %4, ptr %t, align 4
// CHECK:STDOUT: ret void
// CHECK:STDOUT: }
24 changes: 11 additions & 13 deletions toolchain/lowering/testdata/array/base.carbon
Original file line number Diff line number Diff line change
Expand Up @@ -105,20 +105,18 @@ fn Run() {
// CHECK:STDOUT: %34 = load { i32, i32, i32 }, ptr %tuple19, align 4
// CHECK:STDOUT: store { i32, i32, i32 } %34, ptr %d, align 4
// CHECK:STDOUT: %e = alloca [3 x i32], align 4
// CHECK:STDOUT: %35 = load { i32, i32, i32 }, ptr %d, align 4
// CHECK:STDOUT: %array20 = alloca [3 x i32], align 4
// CHECK:STDOUT: %array.element21 = getelementptr inbounds { i32, i32, i32 }, ptr %d, i32 0, i32 0
// CHECK:STDOUT: %35 = load i32, ptr %array.element21, align 4
// CHECK:STDOUT: %array.element21 = extractvalue { i32, i32, i32 } %35, 0
// CHECK:STDOUT: %36 = getelementptr inbounds [3 x i32], ptr %array20, i32 0, i32 0
// CHECK:STDOUT: store i32 %35, ptr %36, align 4
// CHECK:STDOUT: %array.element22 = getelementptr inbounds { i32, i32, i32 }, ptr %d, i32 0, i32 1
// CHECK:STDOUT: %37 = load i32, ptr %array.element22, align 4
// CHECK:STDOUT: %38 = getelementptr inbounds [3 x i32], ptr %array20, i32 0, i32 1
// CHECK:STDOUT: store i32 %37, ptr %38, align 4
// CHECK:STDOUT: %array.element23 = getelementptr inbounds { i32, i32, i32 }, ptr %d, i32 0, i32 2
// CHECK:STDOUT: %39 = load i32, ptr %array.element23, align 4
// CHECK:STDOUT: %40 = getelementptr inbounds [3 x i32], ptr %array20, i32 0, i32 2
// CHECK:STDOUT: store i32 %39, ptr %40, align 4
// CHECK:STDOUT: %41 = load [3 x i32], ptr %array20, align 4
// CHECK:STDOUT: store [3 x i32] %41, ptr %e, align 4
// CHECK:STDOUT: store i32 %array.element21, ptr %36, align 4
// CHECK:STDOUT: %array.element22 = extractvalue { i32, i32, i32 } %35, 1
// CHECK:STDOUT: %37 = getelementptr inbounds [3 x i32], ptr %array20, i32 0, i32 1
// CHECK:STDOUT: store i32 %array.element22, ptr %37, align 4
// CHECK:STDOUT: %array.element23 = extractvalue { i32, i32, i32 } %35, 2
// CHECK:STDOUT: %38 = getelementptr inbounds [3 x i32], ptr %array20, i32 0, i32 2
// CHECK:STDOUT: store i32 %array.element23, ptr %38, align 4
// CHECK:STDOUT: %39 = load [3 x i32], ptr %array20, align 4
// CHECK:STDOUT: store [3 x i32] %39, ptr %e, align 4
// CHECK:STDOUT: ret void
// CHECK:STDOUT: }
10 changes: 4 additions & 6 deletions toolchain/lowering/testdata/basics/type_values.carbon
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,10 @@ fn F64() -> type {
// CHECK:STDOUT: ; ModuleID = 'type_values.carbon'
// CHECK:STDOUT: source_filename = "type_values.carbon"
// CHECK:STDOUT:
// CHECK:STDOUT: %type = type {}
// CHECK:STDOUT:
// CHECK:STDOUT: define %type @I32() {
// CHECK:STDOUT: ret %type zeroinitializer
// CHECK:STDOUT: define void @I32() {
// CHECK:STDOUT: ret void
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: define %type @F64() {
// CHECK:STDOUT: ret %type zeroinitializer
// CHECK:STDOUT: define void @F64() {
// CHECK:STDOUT: ret void
// CHECK:STDOUT: }
8 changes: 3 additions & 5 deletions toolchain/lowering/testdata/function/call/empty_struct.carbon
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,16 @@ fn Main() {
// CHECK:STDOUT: ; ModuleID = 'empty_struct.carbon'
// CHECK:STDOUT: source_filename = "empty_struct.carbon"
// CHECK:STDOUT:
// CHECK:STDOUT: define {} @Echo({} %a) {
// CHECK:STDOUT: define void @Echo({} %a) {
// CHECK:STDOUT: %struct = alloca {}, align 8
// CHECK:STDOUT: %1 = load {}, ptr %struct, align 1
// CHECK:STDOUT: ret {} %1
// CHECK:STDOUT: ret void
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: define void @Main() {
// CHECK:STDOUT: %struct = alloca {}, align 8
// CHECK:STDOUT: %b = alloca {}, align 8
// CHECK:STDOUT: %struct1 = alloca {}, align 8
// CHECK:STDOUT: %1 = load {}, ptr %struct1, align 1
// CHECK:STDOUT: %Echo = call {} @Echo({} %1)
// CHECK:STDOUT: store {} %Echo, ptr %b, align 1
// CHECK:STDOUT: call void @Echo({} %1)
// CHECK:STDOUT: ret void
// CHECK:STDOUT: }
7 changes: 3 additions & 4 deletions toolchain/lowering/testdata/function/call/empty_tuple.carbon
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,15 @@ fn Main() {
// CHECK:STDOUT: ; ModuleID = 'empty_tuple.carbon'
// CHECK:STDOUT: source_filename = "empty_tuple.carbon"
// CHECK:STDOUT:
// CHECK:STDOUT: define {} @Echo({} %a) {
// CHECK:STDOUT: ret {} %a
// CHECK:STDOUT: define void @Echo({} %a) {
// CHECK:STDOUT: ret void
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: define void @Main() {
// CHECK:STDOUT: %tuple = alloca {}, align 8
// CHECK:STDOUT: %b = alloca {}, align 8
// CHECK:STDOUT: %tuple1 = alloca {}, align 8
// CHECK:STDOUT: %1 = load {}, ptr %tuple1, align 1
// CHECK:STDOUT: %Echo = call {} @Echo({} %1)
// CHECK:STDOUT: store {} %Echo, ptr %b, align 1
// CHECK:STDOUT: call void @Echo({} %1)
// CHECK:STDOUT: ret void
// CHECK:STDOUT: }
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,16 @@ fn Main() {
// CHECK:STDOUT: ret void
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: define {} @Bar({} %a) {
// CHECK:STDOUT: ret {} %a
// CHECK:STDOUT: define void @Bar({} %a) {
// CHECK:STDOUT: ret void
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: define void @Main() {
// CHECK:STDOUT: %tuple = alloca {}, align 8
// CHECK:STDOUT: %x = alloca {}, align 8
// CHECK:STDOUT: call void @Foo()
// CHECK:STDOUT: %call.result = alloca {}, align 8
// CHECK:STDOUT: %1 = load {}, ptr %call.result, align 1
// CHECK:STDOUT: %Bar = call {} @Bar({} %1)
// CHECK:STDOUT: store {} %Bar, ptr %x, align 1
// CHECK:STDOUT: %temp = alloca {}, align 8
// CHECK:STDOUT: %1 = load {}, ptr %temp, align 1
// CHECK:STDOUT: call void @Bar({} %1)
// CHECK:STDOUT: ret void
// CHECK:STDOUT: }
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,5 @@ fn Main() {
// CHECK:STDOUT:
// CHECK:STDOUT: define void @Main() {
// CHECK:STDOUT: call void @Foo(i32 1)
// CHECK:STDOUT: %call.result = alloca {}, align 8
// CHECK:STDOUT: ret void
// CHECK:STDOUT: }
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ fn Main() {
// CHECK:STDOUT:
// CHECK:STDOUT: define void @Main() {
// CHECK:STDOUT: call void @Foo(i32 1)
// CHECK:STDOUT: %call.result = alloca {}, align 8
// CHECK:STDOUT: call void @Foo(i32 1)
// CHECK:STDOUT: %call.result1 = alloca {}, align 8
// CHECK:STDOUT: ret void
// CHECK:STDOUT: }
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,5 @@ fn Main() {
// CHECK:STDOUT:
// CHECK:STDOUT: define void @Main() {
// CHECK:STDOUT: call void @Foo(i32 1, i32 2)
// CHECK:STDOUT: %call.result = alloca {}, align 8
// CHECK:STDOUT: ret void
// CHECK:STDOUT: }
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ fn Main() {
// CHECK:STDOUT:
// CHECK:STDOUT: define void @Main() {
// CHECK:STDOUT: call void @Foo(i32 1, i32 2)
// CHECK:STDOUT: %call.result = alloca {}, align 8
// CHECK:STDOUT: call void @Foo(i32 1, i32 2)
// CHECK:STDOUT: %call.result1 = alloca {}, align 8
// CHECK:STDOUT: ret void
// CHECK:STDOUT: }
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,5 @@ fn Main() {
// CHECK:STDOUT:
// CHECK:STDOUT: define void @Main() {
// CHECK:STDOUT: call void @Foo()
// CHECK:STDOUT: %call.result = alloca {}, align 8
// CHECK:STDOUT: ret void
// CHECK:STDOUT: }
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,5 @@ fn Main() {
// CHECK:STDOUT: %tuple = alloca {}, align 8
// CHECK:STDOUT: %b = alloca {}, align 8
// CHECK:STDOUT: call void @MakeImplicitEmptyTuple()
// CHECK:STDOUT: %call.result = alloca {}, align 8
// CHECK:STDOUT: %1 = load {}, ptr %call.result, align 1
// CHECK:STDOUT: store {} %1, ptr %b, align 1
// CHECK:STDOUT: ret void
// CHECK:STDOUT: }
1 change: 0 additions & 1 deletion toolchain/lowering/testdata/function/call/var_param.carbon
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,5 @@ fn Main() {
// CHECK:STDOUT: store i32 0, ptr %a, align 4
// CHECK:STDOUT: %1 = load i32, ptr %a, align 4
// CHECK:STDOUT: call void @DoNothing(i32 %1)
// CHECK:STDOUT: %call.result = alloca {}, align 8
// CHECK:STDOUT: ret void
// CHECK:STDOUT: }
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,5 @@ fn G(n: i32) { F(n); }
// CHECK:STDOUT:
// CHECK:STDOUT: define void @G(i32 %n) {
// CHECK:STDOUT: call void @F(i32 %n)
// CHECK:STDOUT: %call.result = alloca {}, align 8
// CHECK:STDOUT: ret void
// CHECK:STDOUT: }

0 comments on commit 1013d17

Please sign in to comment.