
Commit

Completed implementation for functions of mlx.cpp
Signed-off-by: Aryan Gupta <[email protected]>
guptaaryan16 committed May 22, 2024
1 parent ade0c3a commit 84da6d7
Showing 3 changed files with 120 additions and 10 deletions.
9 changes: 8 additions & 1 deletion plugins/wasi_nn/CMakeLists.txt
@@ -161,7 +161,7 @@ wasmedge_add_library(wasmedgePluginWasiNN
torch.cpp
tfl.cpp
ggml.cpp
# mlx.cpp
mlx.cpp
)

target_compile_options(wasmedgePluginWasiNN
@@ -178,6 +178,13 @@ target_include_directories(wasmedgePluginWasiNN
${CMAKE_CURRENT_SOURCE_DIR}
)

if(BACKEND STREQUAL "mlx")
  find_package(MLX CONFIG REQUIRED)
  target_link_libraries(wasmedgePluginWasiNN PUBLIC mlx)
  # TODO: Add the include path for mlx_llm.
endif()
if(BACKEND STREQUAL "ggml")
# Setup llava from llama.cpp
wasmedge_add_library(llava OBJECT
102 changes: 94 additions & 8 deletions plugins/wasi_nn/mlx.cpp
@@ -4,18 +4,16 @@

#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX
#include <mlx/mlx.h>
#include <mlx_llm/llm.h>
#endif

namespace WasmEdge::Host::WASINN::MLX {
#ifdef WASMEDGE_PLUGIN_WASI_NN_BACKEND_MLX

// TODO: Implement the MLX C++ NN API support.
// Track the GitHub branch: https://github.com/guptaaryan16/mlx/tree/Cpp_api
// TODO: Implement the MLX_LLM C++ API support.
// Track the GitHub branch: https://github.com/guptaaryan16/mlx_llm.cpp
// TODO: Add an API similar to llama.cpp for loading LLaMA models in MLX:
// struct LLM_MODEL : mlx::core::nn::Module {
//   StreamOrDevice device = metal::is_available() ? Device::gpu : Device::cpu;
//   // Dummy forward method for inference
// };
// For now, we use a test model to understand the usage of the API.
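// NOTE: Illustrative sketch only, not part of the plugin or of the tracked
// fork's API: a minimal stand-in for such a test model built purely from
// upstream mlx::core ops, assuming the safetensors file stores one dense
// layer under the hypothetical keys "weight" and "bias".
// struct SketchLinearModel {
//   std::unordered_map<std::string, mlx::core::array> Weights;
//   mlx::core::array forward(const mlx::core::array &Input) {
//     // Out = Input * W^T + b; MLX builds this lazily until evaluation.
//     return mlx::core::add(
//         mlx::core::matmul(Input,
//                           mlx::core::transpose(Weights.at("weight"))),
//         Weights.at("bias"));
//   }
// };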

Expect<ErrNo> load(WasiNNEnvironment &Env, Span<const Span<uint8_t>> Builders,
Device Device, uint32_t &GraphId) noexcept {
@@ -82,17 +80,105 @@ Expect<ErrNo> load(WasiNNEnvironment &Env, Span<const Span<uint8_t>> Builders,
GraphRef.ModelFilePath = ModelFilePath;

try {
  auto [ModelWeights, ModelMetadata] =
      mlx::core::load_safetensors(ModelFilePath);
  GraphRef.ModelWeights = ModelWeights;
  GraphRef.ModelMetadata = ModelMetadata;
  // TODO:: Replace the TestModel with the LLM API after mlx_llm gets
  // completed.
  GraphRef.Model = mlx::core::nn::TestModel();
  GraphRef.Model.load_weights(ModelFilePath);
} catch (const std::exception &e) {
  spdlog::error("[WASI-NN] Failed to load the MLX model: {}", e.what());
  Env.NNGraph.pop_back();
  return ErrNo::InvalidArgument;
}
// Store the loaded graph.
GraphId = Env.NNGraph.size() - 1;
GraphRef.GraphId = GraphId;
return ErrNo::Success;
}

Expect<ErrNo> initExecCtx(WasiNNEnvironment &Env, uint32_t GraphId,
uint32_t &ContextId) noexcept {
Env.NNContext.emplace_back(GraphId, Env.NNGraph[GraphId]);

ContextId = Env.NNContext.size() - 1;
return ErrNo::Success;
}

Expect<ErrNo> setInput(WasiNNEnvironment &Env, uint32_t ContextId,
                       uint32_t Index, const TensorData &Tensor) noexcept {
  auto &CxtRef = Env.NNContext[ContextId].get<Context>();
  if (Tensor.RType != TensorType::F32) {
    spdlog::error(
        "[WASI-NN] Only F32 inputs and outputs are supported for now.");
    return ErrNo::InvalidArgument;
  }
  // MLX shapes are vectors of int, while WASI-NN passes uint32_t dimensions.
  std::vector<int> Dims;
  for (size_t I = 0; I < Tensor.Dimension.size(); I++) {
    Dims.push_back(static_cast<int>(Tensor.Dimension[I]));
  }
  // Copy the input buffer into a new MLX array.
  mlx::core::array InTensor = mlx::core::array(
      reinterpret_cast<const float *>(Tensor.Tensor.data()), Dims,
      mlx::core::float32);
  if (Index >= CxtRef.MLXInputs.size()) {
    // mlx::core::array has no default constructor, so grow the vector with
    // the new tensor as the fill value.
    CxtRef.MLXInputs.resize(Index + 1, InTensor);
  }
  CxtRef.MLXInputs[Index] = InTensor;
return ErrNo::Success;
}
Expect<ErrNo> getOutput(WasiNNEnvironment &Env, uint32_t ContextId,
uint32_t Index, Span<uint8_t> OutBuffer,
uint32_t &BytesWritten) noexcept {
auto &CxtRef = Env.NNContext[ContextId].get<Context>();
if (CxtRef.MLXOutputs.size() <= Index) {
  spdlog::error(
      "[WASI-NN] The output index {} exceeds the outputs number {}.", Index,
      CxtRef.MLXOutputs.size());
  return ErrNo::InvalidArgument;
}
// Cast the output to F32 and evaluate it before touching the data buffer,
// because MLX computes arrays lazily.
mlx::core::array OutTensor =
    mlx::core::astype(CxtRef.MLXOutputs[Index], mlx::core::float32);
OutTensor.eval();
const float *TensorBuffer = OutTensor.data<float>();

size_t BlobSize = OutTensor.size();
uint32_t BytesToWrite = static_cast<uint32_t>(
    std::min(BlobSize * sizeof(float), OutBuffer.size()));
std::copy_n(reinterpret_cast<const uint8_t *>(TensorBuffer), BytesToWrite,
OutBuffer.data());
BytesWritten = BytesToWrite;
return ErrNo::Success;
}
Expect<ErrNo> compute(WasiNNEnvironment &Env, uint32_t ContextId) noexcept {
auto &CxtRef = Env.NNContext[ContextId].get<Context>();
if (CxtRef.MLXInputs.empty()) {
  spdlog::error("[WASI-NN] Input is not set!");
  return ErrNo::InvalidArgument;
}
auto &GraphRef = Env.NNGraph[CxtRef.GraphId].get<Graph>();
// TODO:: Replace the TestModel forward call with the LLM API after mlx_llm
// gets completed.
mlx::core::array RawOutput = GraphRef.Model.forward(CxtRef.MLXInputs);
// Evaluate the lazy graph and keep the result for getOutput().
RawOutput.eval();
CxtRef.MLXOutputs.clear();
CxtRef.MLXOutputs.push_back(RawOutput);
// TODO: The output does not seem correct yet. Once the model API is stable,
// also support models that return a list or a tuple of arrays, as the
// PyTorch backend does for tensors.
return ErrNo::Success;
}
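
// NOTE: Illustrative sketch only: once the model API can also return several
// arrays, the outputs could be flattened into MLXOutputs like this (the
// forwardAll name is a hypothetical overload, not an existing API):
// std::vector<mlx::core::array> MultiOutputs =
//     GraphRef.Model.forwardAll(CxtRef.MLXInputs);
// for (auto &OneOf : MultiOutputs) {
//   OneOf.eval();
//   CxtRef.MLXOutputs.push_back(OneOf);
// }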

#else
namespace {
Expect<ErrNo> reportBackendNotSupported() noexcept {
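As a reference for the data flow implemented above, here is a minimal, self-contained sketch of the MLX round trip that setInput, compute, and getOutput rely on, written only against the upstream mlx::core API and mirroring the pointer/shape array construction used in setInput. The standalone main(), the 2x2 shape, and the matmul stand-in for the model are illustrative assumptions, not plugin code.

#include <mlx/mlx.h>

#include <cstdint>
#include <cstring>
#include <vector>

int main() {
  using namespace mlx::core;
  // setInput: copy a raw F32 buffer into an MLX array with an explicit shape.
  std::vector<float> InBuffer = {1.0f, 2.0f, 3.0f, 4.0f};
  array Input(InBuffer.data(), {2, 2}, float32);

  // compute: build the computation lazily; nothing runs yet.
  array Output = astype(matmul(Input, Input), float32);

  // getOutput: force evaluation, then copy the F32 data out as raw bytes.
  Output.eval();
  std::vector<uint8_t> OutBuffer(Output.size() * sizeof(float));
  std::memcpy(OutBuffer.data(), Output.data<float>(), OutBuffer.size());
  return 0;
}

The explicit eval() before reading the data pointer mirrors what the backend should do before filling output buffers, since MLX builds its computations lazily.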
19 changes: 18 additions & 1 deletion plugins/wasi_nn/tfl.cpp
@@ -28,7 +28,24 @@ Expect<WASINN::ErrNo> load(WASINN::WasiNNEnvironment &Env,
auto &GraphRef = Env.NNGraph.back().get<Graph>();

// Copy graph builder data to TfLiteModData and create a new TfLiteModel.
GraphRef.TfLiteModData.assign(Builders[0].begin(), Builders[0].end());
GraphRef.TFLiteMod = TfLiteModelCreate(GraphRef.TfLiteModData.data(),
GraphRef.TfLiteModData.size());
if (unlikely(GraphRef.TFLiteMod == nullptr)) {
