diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index e7073aad05263a..e85cb1232ae8d3 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -111,6 +111,7 @@ filegroup( "runtime_matmul_f64.cc", "runtime_matmul_s32.cc", "runtime_fork_join.cc", + #"runtime_handle_ffi_call.cc", ], visibility = internal_visibility([":friends"]), ) @@ -138,6 +139,7 @@ filegroup( "runtime_fork_join.h", "runtime_lightweight_check.h", "runtime_matmul.h", + #"runtime_handle_ffi_call.h", ], visibility = internal_visibility([":friends"]), ) @@ -492,6 +494,7 @@ cc_library( ":runtime_fft", ":runtime_fork_join", ":runtime_fp16", + ":runtime_handle_ffi_call", ":runtime_key_value_sort", ":runtime_matmul", ":runtime_matmul_acl", @@ -651,6 +654,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -1184,6 +1188,33 @@ cc_library( ], ) +cc_library( + name = "runtime_handle_ffi_call", + srcs = ["runtime_handle_ffi_call.cc"], + hdrs = ["runtime_handle_ffi_call.h"], + copts = runtime_copts(), + visibility = ["//visibility:public"], + deps = [ + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/ffi:call_frame", + "//xla/ffi:ffi_api", + "//xla/service:custom_call_status_public_headers", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + ], +) + xla_cc_test( name = "cpu_runtime_test", srcs = ["cpu_runtime_test.cc"], diff --git a/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc b/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc new file mode 100644 index 00000000000000..a962c998d3d827 --- /dev/null +++ b/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc @@ -0,0 +1,244 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/cpu/runtime_handle_ffi_call.h" + +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/attributes.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "llvm/ADT/TypeSwitch.h" +#include "mlir/AsmParser/AsmParser.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "xla/ffi/call_frame.h" +#include "xla/ffi/ffi_api.h" +#include "xla/primitive_util.h" +#include "xla/service/custom_call_status.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" + +namespace ffi = xla::ffi; + +namespace { + +using Attribute = ffi::CallFrameBuilder::FlatAttribute; +using AttributesMap = ffi::CallFrameBuilder::FlatAttributesMap; + +// TODO(heinsaar): This BuildAttributesMap() is originally an identical +// copy-paste of the same function in custom_call_thunk.cc +// May make sense to have one in a common place & reuse. +absl::StatusOr BuildAttributesMap(mlir::DictionaryAttr dict) { + AttributesMap attributes; + for (auto& kv : dict) { + std::string_view name = kv.getName().strref(); + + auto integer = [&](mlir::IntegerAttr integer) { + switch (integer.getType().getIntOrFloatBitWidth()) { + case 32: + attributes[name] = static_cast(integer.getInt()); + return absl::OkStatus(); + case 64: + attributes[name] = static_cast(integer.getInt()); + return absl::OkStatus(); + default: + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported integer attribute bit width for attribute: ", name)); + } + }; + + auto fp = [&](mlir::FloatAttr fp) { + switch (fp.getType().getIntOrFloatBitWidth()) { + case 32: + attributes[name] = static_cast(fp.getValue().convertToFloat()); + return absl::OkStatus(); + default: + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported float attribute bit width for attribute: ", name)); + } + }; + + auto str = [&](mlir::StringAttr str) { + attributes[name] = str.getValue().str(); + return absl::OkStatus(); + }; + + TF_RETURN_IF_ERROR( + llvm::TypeSwitch(kv.getValue()) + .Case(integer) + .Case(fp) + .Case(str) + .Default([&](mlir::Attribute) { + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported attribute type for attribute: ", name)); + })); + } + return attributes; +} + +absl::Span DecodeDims(int64_t* encoded_dims_data) { + auto dims_count = encoded_dims_data[0]; + auto dims_begin = encoded_dims_data + 1; + return absl::MakeSpan(dims_begin, dims_begin + dims_count); +} + +// TODO(heinsaar): Once on C++20, this can (and should) be a local lambda with +// an explicit template parameter list. +class ArgInserter { + public: + template + explicit ArgInserter(B&& b) : b_(std::forward(b)) {} + + template + void operator()(Args&&... args) const { + b_.AddBufferArg(std::forward(args)...); + } + + private: + ffi::CallFrameBuilder& b_; +}; + +// TODO(heinsaar): Once on C++20, this can (and should) be a local lambda with +// an explicit template parameter list. +class RetInserter { + public: + template + explicit RetInserter(B&& b) : b_(std::forward(b)) {} + + template + void operator()(Args&&... args) const { + b_.AddBufferRet(std::forward(args)...); + } + + private: + ffi::CallFrameBuilder& b_; +}; + +template +void BuildBuffers(absl::Span types, int64_t* encoded_dims, + absl::Span address_space, Builder&& builder) { + int64_t dim_pos = 0; + for (int64_t i = 0; i < types.size(); ++i) { + auto dtype = static_cast(types[i]); + auto dims = DecodeDims(encoded_dims + dim_pos); + auto elem_count = absl::c_accumulate(dims, 1, std::multiplies()); + auto data_width = xla::primitive_util::ByteWidth(dtype) * elem_count; + + builder(tensorflow::se::DeviceMemoryBase(address_space[i], data_width), + /*type = */ dtype, + /*dims = */ dims); + dim_pos += 1; // Jumps over count value + dim_pos += dims.size(); // Jumps over all dimensions in a shape + } +} + +inline absl::Status BuildAndCallFfi( + std::string_view target_name, std::string_view backend_config, + absl::Span outputs, absl::Span inputs, + absl::Span result_types, int64_t* result_dims, + absl::Span operand_types, int64_t* operand_dims) { + CHECK_EQ(outputs.size(), result_types.size()); + CHECK_EQ(inputs.size(), operand_types.size()); + + if (absl::c_any_of(operand_types, [](int32_t type) { + return static_cast(type) == + xla::PrimitiveType::TUPLE; + })) { + return absl::InternalError( + "Tuple operands are not supported yet in typed FFI custom calls."); + } + + // Find the registered FFI handler for this custom call target. + absl::StatusOr registration = + ffi::FindHandler(target_name, "Host"); + + if (!registration.ok()) { + return absl::UnimplementedError( + absl::StrCat("No registered implementation for custom call to ", + target_name, " for Host.")); + } + + // For FFI handlers backend config must be a compatible MLIR dictionary. + mlir::MLIRContext mlir_context; + ffi::CallFrameBuilder::FlatAttributesMap attributes; + if (!backend_config.empty()) { + // Backend config not empty, so proceed to parse it into an MLIR attribute + // and build an MLIR compatible map of attributes out of it. + mlir::Attribute attr = mlir::parseAttribute(backend_config, &mlir_context); + if (auto dict = attr.dyn_cast_or_null()) { + TF_ASSIGN_OR_RETURN(attributes, BuildAttributesMap(dict)); + } else { + return absl::InternalError( + "Unsupported backend config. Expected a string parsable into " + "dictionary attribute"); + } + } + + ffi::CallFrameBuilder builder; + + // Forward the constructed attributes to the call frame + ffi::CallFrameBuilder::AttributesBuilder attrs; + attrs.Append(std::move(attributes)); + builder.AddAttributes(attrs.Build()); + + // Decode dimensions metadata into shapes and build operand & result buffers + BuildBuffers(operand_types, operand_dims, inputs, ArgInserter(builder)); + BuildBuffers(result_types, result_dims, outputs, RetInserter(builder)); + + ffi::CallFrame call_frame = builder.Build(); + return ffi::Call(registration->handler, call_frame); // Status +} + +} // namespace + +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_HandleFfiCall( + const char* target_name_ptr, int64_t target_name_len, void* output, + void** inputs, const char* opaque_str_ptr, int64_t opaque_str_len, + void* status_opaque, int32_t* operand_types, int64_t operand_count, + int64_t* operand_dims, int32_t* result_types, int64_t result_count, + int64_t* result_dims) { + auto target_name = absl::string_view(target_name_ptr, target_name_len); + auto backend_config = absl::string_view(opaque_str_ptr, opaque_str_len); + auto xla_status = reinterpret_cast(status_opaque); + + void** outputs = &output; + if (result_count > 1) { // output is a tuple + outputs = reinterpret_cast(output); + } + + absl::Status status = BuildAndCallFfi( + target_name, backend_config, absl::MakeSpan(outputs, result_count), + absl::MakeSpan(inputs, operand_count), + absl::MakeSpan(result_types, result_count), result_dims, + absl::MakeSpan(operand_types, operand_count), operand_dims); + + if (!status.ok()) { + // In the future, status propagation will likely be possible. + // However, currently this has to pass through XlaCustomCallStatus + // which lacks functionality for status codes (it is fixed on INTERNAL) + XlaCustomCallStatusSetFailure(xla_status, status.message().data(), + status.message().size()); + } +} diff --git a/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.h b/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.h new file mode 100644 index 00000000000000..e8afb236b2bb26 --- /dev/null +++ b/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.h @@ -0,0 +1,32 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_CPU_RUNTIME_HANDLE_FFI_CALL_H_ +#define XLA_SERVICE_CPU_RUNTIME_HANDLE_FFI_CALL_H_ + +#include + +extern "C" { + +extern void __xla_cpu_runtime_HandleFfiCall( + const char* target_name_ptr, int64_t target_name_len, void* output, + void** inputs, const char* opaque_str_ptr, int64_t opaque_str_len, + void* status_opaque, int32_t* operand_types, int64_t operand_count, + int64_t* operand_dims, int32_t* result_types, int64_t result_count, + int64_t* result_dims); + +} // extern "C" + +#endif // XLA_SERVICE_CPU_RUNTIME_HANDLE_FFI_CALL_H_