458 lines
14 KiB
C++
458 lines
14 KiB
C++
|
//===-- Shared memory RPC server instantiation ------------------*- C++ -*-===//
|
||
|
//
|
||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||
|
//
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
#include "rpc_server.h"
|
||
|
|
||
|
#include "src/__support/RPC/rpc.h"
|
||
|
#include "src/stdio/gpu/file.h"
|
||
|
#include <atomic>
|
||
|
#include <cstdio>
|
||
|
#include <cstring>
|
||
|
#include <memory>
|
||
|
#include <mutex>
|
||
|
#include <unordered_map>
|
||
|
#include <variant>
|
||
|
#include <vector>
|
||
|
|
||
|
using namespace LIBC_NAMESPACE;
|
||
|
|
||
|
static_assert(sizeof(rpc_buffer_t) == sizeof(rpc::Buffer),
|
||
|
"Buffer size mismatch");
|
||
|
|
||
|
static_assert(RPC_MAXIMUM_PORT_COUNT == rpc::MAX_PORT_COUNT,
|
||
|
"Incorrect maximum port count");
|
||
|
|
||
|
// The client needs to support different lane sizes for the SIMT model. Because
|
||
|
// of this we need to select between the possible sizes that the client can use.
|
||
|
struct Server {
|
||
|
template <uint32_t lane_size>
|
||
|
Server(std::unique_ptr<rpc::Server<lane_size>> &&server)
|
||
|
: server(std::move(server)) {}
|
||
|
|
||
|
rpc_status_t handle_server(
|
||
|
const std::unordered_map<rpc_opcode_t, rpc_opcode_callback_ty> &callbacks,
|
||
|
const std::unordered_map<rpc_opcode_t, void *> &callback_data,
|
||
|
uint32_t &index) {
|
||
|
rpc_status_t ret = RPC_STATUS_SUCCESS;
|
||
|
std::visit(
|
||
|
[&](auto &server) {
|
||
|
ret = handle_server(*server, callbacks, callback_data, index);
|
||
|
},
|
||
|
server);
|
||
|
return ret;
|
||
|
}
|
||
|
|
||
|
private:
|
||
|
template <uint32_t lane_size>
|
||
|
rpc_status_t handle_server(
|
||
|
rpc::Server<lane_size> &server,
|
||
|
const std::unordered_map<rpc_opcode_t, rpc_opcode_callback_ty> &callbacks,
|
||
|
const std::unordered_map<rpc_opcode_t, void *> &callback_data,
|
||
|
uint32_t &index) {
|
||
|
auto port = server.try_open(index);
|
||
|
if (!port)
|
||
|
return RPC_STATUS_SUCCESS;
|
||
|
|
||
|
switch (port->get_opcode()) {
|
||
|
case RPC_WRITE_TO_STREAM:
|
||
|
case RPC_WRITE_TO_STDERR:
|
||
|
case RPC_WRITE_TO_STDOUT:
|
||
|
case RPC_WRITE_TO_STDOUT_NEWLINE: {
|
||
|
uint64_t sizes[lane_size] = {0};
|
||
|
void *strs[lane_size] = {nullptr};
|
||
|
FILE *files[lane_size] = {nullptr};
|
||
|
if (port->get_opcode() == RPC_WRITE_TO_STREAM) {
|
||
|
port->recv([&](rpc::Buffer *buffer, uint32_t id) {
|
||
|
files[id] = reinterpret_cast<FILE *>(buffer->data[0]);
|
||
|
});
|
||
|
} else if (port->get_opcode() == RPC_WRITE_TO_STDERR) {
|
||
|
std::fill(files, files + lane_size, stderr);
|
||
|
} else {
|
||
|
std::fill(files, files + lane_size, stdout);
|
||
|
}
|
||
|
|
||
|
port->recv_n(strs, sizes, [&](uint64_t size) { return new char[size]; });
|
||
|
port->send([&](rpc::Buffer *buffer, uint32_t id) {
|
||
|
flockfile(files[id]);
|
||
|
buffer->data[0] = fwrite_unlocked(strs[id], 1, sizes[id], files[id]);
|
||
|
if (port->get_opcode() == RPC_WRITE_TO_STDOUT_NEWLINE &&
|
||
|
buffer->data[0] == sizes[id])
|
||
|
buffer->data[0] += fwrite_unlocked("\n", 1, 1, files[id]);
|
||
|
funlockfile(files[id]);
|
||
|
delete[] reinterpret_cast<uint8_t *>(strs[id]);
|
||
|
});
|
||
|
break;
|
||
|
}
|
||
|
case RPC_READ_FROM_STREAM: {
|
||
|
uint64_t sizes[lane_size] = {0};
|
||
|
void *data[lane_size] = {nullptr};
|
||
|
port->recv([&](rpc::Buffer *buffer, uint32_t id) {
|
||
|
data[id] = new char[buffer->data[0]];
|
||
|
sizes[id] = fread(data[id], 1, buffer->data[0],
|
||
|
file::to_stream(buffer->data[1]));
|
||
|
});
|
||
|
port->send_n(data, sizes);
|
||
|
port->send([&](rpc::Buffer *buffer, uint32_t id) {
|
||
|
delete[] reinterpret_cast<uint8_t *>(data[id]);
|
||
|
std::memcpy(buffer->data, &sizes[id], sizeof(uint64_t));
|
||
|
});
|
||
|
break;
|
||
|
}
|
||
|
case RPC_READ_FGETS: {
|
||
|
uint64_t sizes[lane_size] = {0};
|
||
|
void *data[lane_size] = {nullptr};
|
||
|
port->recv([&](rpc::Buffer *buffer, uint32_t id) {
|
||
|
data[id] = new char[buffer->data[0]];
|
||
|
const char *str =
|
||
|
fgets(reinterpret_cast<char *>(data[id]), buffer->data[0],
|
||
|
file::to_stream(buffer->data[1]));
|
||
|
sizes[id] = !str ? 0 : std::strlen(str) + 1;
|
||
|
});
|
||
|
port->send_n(data, sizes);
|
||
|
for (uint32_t id = 0; id < lane_size; ++id)
|
||
|
if (data[id])
|
||
|
delete[] reinterpret_cast<uint8_t *>(data[id]);
|
||
|
break;
|
||
|
}
|
||
|
case RPC_OPEN_FILE: {
|
||
|
uint64_t sizes[lane_size] = {0};
|
||
|
void *paths[lane_size] = {nullptr};
|
||
|
port->recv_n(paths, sizes, [&](uint64_t size) { return new char[size]; });
|
||
|
port->recv_and_send([&](rpc::Buffer *buffer, uint32_t id) {
|
||
|
FILE *file = fopen(reinterpret_cast<char *>(paths[id]),
|
||
|
reinterpret_cast<char *>(buffer->data));
|
||
|
buffer->data[0] = reinterpret_cast<uintptr_t>(file);
|
||
|
});
|
||
|
break;
|
||
|
}
|
||
|
case RPC_CLOSE_FILE: {
|
||
|
port->recv_and_send([&](rpc::Buffer *buffer, uint32_t id) {
|
||
|
FILE *file = reinterpret_cast<FILE *>(buffer->data[0]);
|
||
|
buffer->data[0] = fclose(file);
|
||
|
});
|
||
|
break;
|
||
|
}
|
||
|
case RPC_EXIT: {
|
||
|
// Send a response to the client to signal that we are ready to exit.
|
||
|
port->recv_and_send([](rpc::Buffer *) {});
|
||
|
port->recv([](rpc::Buffer *buffer) {
|
||
|
int status = 0;
|
||
|
std::memcpy(&status, buffer->data, sizeof(int));
|
||
|
exit(status);
|
||
|
});
|
||
|
break;
|
||
|
}
|
||
|
case RPC_ABORT: {
|
||
|
// Send a response to the client to signal that we are ready to abort.
|
||
|
port->recv_and_send([](rpc::Buffer *) {});
|
||
|
port->recv([](rpc::Buffer *) {});
|
||
|
abort();
|
||
|
break;
|
||
|
}
|
||
|
case RPC_HOST_CALL: {
|
||
|
uint64_t sizes[lane_size] = {0};
|
||
|
void *args[lane_size] = {nullptr};
|
||
|
port->recv_n(args, sizes, [&](uint64_t size) { return new char[size]; });
|
||
|
port->recv([&](rpc::Buffer *buffer, uint32_t id) {
|
||
|
reinterpret_cast<void (*)(void *)>(buffer->data[0])(args[id]);
|
||
|
});
|
||
|
port->send([&](rpc::Buffer *, uint32_t id) {
|
||
|
delete[] reinterpret_cast<uint8_t *>(args[id]);
|
||
|
});
|
||
|
break;
|
||
|
}
|
||
|
case RPC_FEOF: {
|
||
|
port->recv_and_send([](rpc::Buffer *buffer) {
|
||
|
buffer->data[0] = feof(file::to_stream(buffer->data[0]));
|
||
|
});
|
||
|
break;
|
||
|
}
|
||
|
case RPC_FERROR: {
|
||
|
port->recv_and_send([](rpc::Buffer *buffer) {
|
||
|
buffer->data[0] = ferror(file::to_stream(buffer->data[0]));
|
||
|
});
|
||
|
break;
|
||
|
}
|
||
|
case RPC_CLEARERR: {
|
||
|
port->recv_and_send([](rpc::Buffer *buffer) {
|
||
|
clearerr(file::to_stream(buffer->data[0]));
|
||
|
});
|
||
|
break;
|
||
|
}
|
||
|
case RPC_FSEEK: {
|
||
|
port->recv_and_send([](rpc::Buffer *buffer) {
|
||
|
buffer->data[0] = fseek(file::to_stream(buffer->data[0]),
|
||
|
static_cast<long>(buffer->data[1]),
|
||
|
static_cast<int>(buffer->data[2]));
|
||
|
});
|
||
|
break;
|
||
|
}
|
||
|
case RPC_FTELL: {
|
||
|
port->recv_and_send([](rpc::Buffer *buffer) {
|
||
|
buffer->data[0] = ftell(file::to_stream(buffer->data[0]));
|
||
|
});
|
||
|
break;
|
||
|
}
|
||
|
case RPC_FFLUSH: {
|
||
|
port->recv_and_send([](rpc::Buffer *buffer) {
|
||
|
buffer->data[0] = fflush(file::to_stream(buffer->data[0]));
|
||
|
});
|
||
|
break;
|
||
|
}
|
||
|
case RPC_UNGETC: {
|
||
|
port->recv_and_send([](rpc::Buffer *buffer) {
|
||
|
buffer->data[0] = ungetc(static_cast<int>(buffer->data[0]),
|
||
|
file::to_stream(buffer->data[1]));
|
||
|
});
|
||
|
break;
|
||
|
}
|
||
|
case RPC_NOOP: {
|
||
|
port->recv([](rpc::Buffer *) {});
|
||
|
break;
|
||
|
}
|
||
|
default: {
|
||
|
auto handler =
|
||
|
callbacks.find(static_cast<rpc_opcode_t>(port->get_opcode()));
|
||
|
|
||
|
// We error out on an unhandled opcode.
|
||
|
if (handler == callbacks.end())
|
||
|
return RPC_STATUS_UNHANDLED_OPCODE;
|
||
|
|
||
|
// Invoke the registered callback with a reference to the port.
|
||
|
void *data =
|
||
|
callback_data.at(static_cast<rpc_opcode_t>(port->get_opcode()));
|
||
|
rpc_port_t port_ref{reinterpret_cast<uint64_t>(&*port), lane_size};
|
||
|
(handler->second)(port_ref, data);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Increment the index so we start the scan after this port.
|
||
|
index = port->get_index() + 1;
|
||
|
port->close();
|
||
|
return RPC_STATUS_CONTINUE;
|
||
|
}
|
||
|
|
||
|
std::variant<std::unique_ptr<rpc::Server<1>>,
|
||
|
std::unique_ptr<rpc::Server<32>>,
|
||
|
std::unique_ptr<rpc::Server<64>>>
|
||
|
server;
|
||
|
};
|
||
|
|
||
|
struct Device {
|
||
|
template <typename T>
|
||
|
Device(uint32_t num_ports, void *buffer, std::unique_ptr<T> &&server)
|
||
|
: buffer(buffer), server(std::move(server)), client(num_ports, buffer) {}
|
||
|
void *buffer;
|
||
|
Server server;
|
||
|
rpc::Client client;
|
||
|
std::unordered_map<rpc_opcode_t, rpc_opcode_callback_ty> callbacks;
|
||
|
std::unordered_map<rpc_opcode_t, void *> callback_data;
|
||
|
};
|
||
|
|
||
|
// A struct containing all the runtime state required to run the RPC server.
|
||
|
struct State {
|
||
|
State(uint32_t num_devices)
|
||
|
: num_devices(num_devices), devices(num_devices), reference_count(0u) {}
|
||
|
uint32_t num_devices;
|
||
|
std::vector<std::unique_ptr<Device>> devices;
|
||
|
std::atomic_uint32_t reference_count;
|
||
|
};
|
||
|
|
||
|
static std::mutex startup_mutex;
|
||
|
|
||
|
static State *state;
|
||
|
|
||
|
rpc_status_t rpc_init(uint32_t num_devices) {
|
||
|
std::scoped_lock<decltype(startup_mutex)> lock(startup_mutex);
|
||
|
if (!state)
|
||
|
state = new State(num_devices);
|
||
|
|
||
|
if (state->reference_count == std::numeric_limits<uint32_t>::max())
|
||
|
return RPC_STATUS_ERROR;
|
||
|
|
||
|
state->reference_count++;
|
||
|
|
||
|
return RPC_STATUS_SUCCESS;
|
||
|
}
|
||
|
|
||
|
rpc_status_t rpc_shutdown(void) {
|
||
|
if (state && state->reference_count-- == 1)
|
||
|
delete state;
|
||
|
|
||
|
return RPC_STATUS_SUCCESS;
|
||
|
}
|
||
|
|
||
|
template <uint32_t lane_size>
|
||
|
rpc_status_t server_init_impl(uint32_t device_id, uint64_t num_ports,
|
||
|
rpc_alloc_ty alloc, void *data) {
|
||
|
uint64_t size = rpc::Server<lane_size>::allocation_size(num_ports);
|
||
|
void *buffer = alloc(size, data);
|
||
|
|
||
|
if (!buffer)
|
||
|
return RPC_STATUS_ERROR;
|
||
|
|
||
|
state->devices[device_id] = std::make_unique<Device>(
|
||
|
num_ports, buffer,
|
||
|
std::make_unique<rpc::Server<lane_size>>(num_ports, buffer));
|
||
|
if (!state->devices[device_id])
|
||
|
return RPC_STATUS_ERROR;
|
||
|
|
||
|
return RPC_STATUS_SUCCESS;
|
||
|
}
|
||
|
|
||
|
rpc_status_t rpc_server_init(uint32_t device_id, uint64_t num_ports,
|
||
|
uint32_t lane_size, rpc_alloc_ty alloc,
|
||
|
void *data) {
|
||
|
if (!state)
|
||
|
return RPC_STATUS_NOT_INITIALIZED;
|
||
|
if (device_id >= state->num_devices)
|
||
|
return RPC_STATUS_OUT_OF_RANGE;
|
||
|
|
||
|
if (!state->devices[device_id]) {
|
||
|
switch (lane_size) {
|
||
|
case 1:
|
||
|
if (rpc_status_t err =
|
||
|
server_init_impl<1>(device_id, num_ports, alloc, data))
|
||
|
return err;
|
||
|
break;
|
||
|
case 32: {
|
||
|
if (rpc_status_t err =
|
||
|
server_init_impl<32>(device_id, num_ports, alloc, data))
|
||
|
return err;
|
||
|
break;
|
||
|
}
|
||
|
case 64:
|
||
|
if (rpc_status_t err =
|
||
|
server_init_impl<64>(device_id, num_ports, alloc, data))
|
||
|
return err;
|
||
|
break;
|
||
|
default:
|
||
|
return RPC_STATUS_INVALID_LANE_SIZE;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return RPC_STATUS_SUCCESS;
|
||
|
}
|
||
|
|
||
|
rpc_status_t rpc_server_shutdown(uint32_t device_id, rpc_free_ty dealloc,
|
||
|
void *data) {
|
||
|
if (!state)
|
||
|
return RPC_STATUS_NOT_INITIALIZED;
|
||
|
if (device_id >= state->num_devices)
|
||
|
return RPC_STATUS_OUT_OF_RANGE;
|
||
|
if (!state->devices[device_id])
|
||
|
return RPC_STATUS_ERROR;
|
||
|
|
||
|
dealloc(state->devices[device_id]->buffer, data);
|
||
|
if (state->devices[device_id])
|
||
|
state->devices[device_id].release();
|
||
|
|
||
|
return RPC_STATUS_SUCCESS;
|
||
|
}
|
||
|
|
||
|
rpc_status_t rpc_handle_server(uint32_t device_id) {
|
||
|
if (!state)
|
||
|
return RPC_STATUS_NOT_INITIALIZED;
|
||
|
if (device_id >= state->num_devices)
|
||
|
return RPC_STATUS_OUT_OF_RANGE;
|
||
|
if (!state->devices[device_id])
|
||
|
return RPC_STATUS_ERROR;
|
||
|
|
||
|
uint32_t index = 0;
|
||
|
for (;;) {
|
||
|
auto &device = *state->devices[device_id];
|
||
|
rpc_status_t status = device.server.handle_server(
|
||
|
device.callbacks, device.callback_data, index);
|
||
|
if (status != RPC_STATUS_CONTINUE)
|
||
|
return status;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
rpc_status_t rpc_register_callback(uint32_t device_id, rpc_opcode_t opcode,
|
||
|
rpc_opcode_callback_ty callback,
|
||
|
void *data) {
|
||
|
if (!state)
|
||
|
return RPC_STATUS_NOT_INITIALIZED;
|
||
|
if (device_id >= state->num_devices)
|
||
|
return RPC_STATUS_OUT_OF_RANGE;
|
||
|
if (!state->devices[device_id])
|
||
|
return RPC_STATUS_ERROR;
|
||
|
|
||
|
state->devices[device_id]->callbacks[opcode] = callback;
|
||
|
state->devices[device_id]->callback_data[opcode] = data;
|
||
|
return RPC_STATUS_SUCCESS;
|
||
|
}
|
||
|
|
||
|
const void *rpc_get_client_buffer(uint32_t device_id) {
|
||
|
if (!state || device_id >= state->num_devices || !state->devices[device_id])
|
||
|
return nullptr;
|
||
|
return &state->devices[device_id]->client;
|
||
|
}
|
||
|
|
||
|
uint64_t rpc_get_client_size() { return sizeof(rpc::Client); }
|
||
|
|
||
|
using ServerPort = std::variant<rpc::Server<1>::Port *, rpc::Server<32>::Port *,
|
||
|
rpc::Server<64>::Port *>;
|
||
|
|
||
|
ServerPort get_port(rpc_port_t ref) {
|
||
|
if (ref.lane_size == 1)
|
||
|
return reinterpret_cast<rpc::Server<1>::Port *>(ref.handle);
|
||
|
else if (ref.lane_size == 32)
|
||
|
return reinterpret_cast<rpc::Server<32>::Port *>(ref.handle);
|
||
|
else if (ref.lane_size == 64)
|
||
|
return reinterpret_cast<rpc::Server<64>::Port *>(ref.handle);
|
||
|
else
|
||
|
__builtin_unreachable();
|
||
|
}
|
||
|
|
||
|
void rpc_send(rpc_port_t ref, rpc_port_callback_ty callback, void *data) {
|
||
|
auto port = get_port(ref);
|
||
|
std::visit(
|
||
|
[=](auto &port) {
|
||
|
port->send([=](rpc::Buffer *buffer) {
|
||
|
callback(reinterpret_cast<rpc_buffer_t *>(buffer), data);
|
||
|
});
|
||
|
},
|
||
|
port);
|
||
|
}
|
||
|
|
||
|
void rpc_send_n(rpc_port_t ref, const void *const *src, uint64_t *size) {
|
||
|
auto port = get_port(ref);
|
||
|
std::visit([=](auto &port) { port->send_n(src, size); }, port);
|
||
|
}
|
||
|
|
||
|
void rpc_recv(rpc_port_t ref, rpc_port_callback_ty callback, void *data) {
|
||
|
auto port = get_port(ref);
|
||
|
std::visit(
|
||
|
[=](auto &port) {
|
||
|
port->recv([=](rpc::Buffer *buffer) {
|
||
|
callback(reinterpret_cast<rpc_buffer_t *>(buffer), data);
|
||
|
});
|
||
|
},
|
||
|
port);
|
||
|
}
|
||
|
|
||
|
void rpc_recv_n(rpc_port_t ref, void **dst, uint64_t *size, rpc_alloc_ty alloc,
|
||
|
void *data) {
|
||
|
auto port = get_port(ref);
|
||
|
auto alloc_fn = [=](uint64_t size) { return alloc(size, data); };
|
||
|
std::visit([=](auto &port) { port->recv_n(dst, size, alloc_fn); }, port);
|
||
|
}
|
||
|
|
||
|
void rpc_recv_and_send(rpc_port_t ref, rpc_port_callback_ty callback,
|
||
|
void *data) {
|
||
|
auto port = get_port(ref);
|
||
|
std::visit(
|
||
|
[=](auto &port) {
|
||
|
port->recv_and_send([=](rpc::Buffer *buffer) {
|
||
|
callback(reinterpret_cast<rpc_buffer_t *>(buffer), data);
|
||
|
});
|
||
|
},
|
||
|
port);
|
||
|
}
|