/* Steve, the jobserver
 * (c) 2025 Michał Górny
 * SPDX-License-Identifier: GPL-2.0-or-later
 *
 * Inspired by CUSE example, nixos-jobserver (draft) and guildmaster:
 * https://github.com/libfuse/libfuse/blob/f58d4c5b0d56116d8870753f6b9d1620ee082709/example/cuse.c
 * https://github.com/RaitoBezarius/nixpkgs/blob/e97220ecf1e8887b949e4e16547bf0334826d076/pkgs/by-name/ni/nixos-jobserver/nixos-jobserver.cpp#L213
 * https://codeberg.org/amonakov/guildmaster/
 */

#define FUSE_USE_VERSION 31

#include <algorithm>
#include <cassert>
#include <cerrno>
#include <climits>
#include <csignal>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <deque>
#include <format>
#include <functional>
#include <memory>
#include <print>
#include <string>
#include <unordered_map>

#include <fcntl.h>
#include <getopt.h>
#include <sys/poll.h>
#include <sys/syscall.h>
#include <unistd.h>

#include <event2/event.h>

#include <cuse_lowlevel.h>
#include <fuse.h>
#include <fuse_opt.h>

/* A blocking read() that found no free job token.  The request is left
 * unanswered until steve_wake_waiters() hands it a token, or it is removed
 * by steve_interrupt(). */
struct steve_read_waiter {
	/* pending FUSE read request */
	fuse_req_t req;
	/* PID stored in fi->fh by steve_open() for the file being read */
	uint64_t pid;
};

struct steve_poll_waiter {
	fuse_pollhandle *poll_handle;
	uint64_t pid;

	steve_poll_waiter(fuse_pollhandle *new_poll_handle, uint64_t new_pid)
		: poll_handle(new_poll_handle), pid(new_pid) {}

	steve_poll_waiter(const steve_poll_waiter &) = delete;
	steve_poll_waiter& operator=(const steve_poll_waiter &) = delete;

	~steve_poll_waiter() {
		fuse_pollhandle_destroy(poll_handle);
	}
};

/* Per-PID bookkeeping for a process that opened the device.
 *
 * pid_fd      - pidfd watched by `event`, or -1 when not tracked
 * tokens_held - tokens handed out minus tokens returned; may go negative
 *               when a process returns a token it never acquired
 * event       - libevent watcher on pid_fd (may be empty)
 *
 * Owns both the pidfd and the event, so it is non-copyable. */
struct steve_process {
	int pid_fd{-1};
	ssize_t tokens_held{0};
	std::unique_ptr<struct event, std::function<void(struct event*)>> event;

	steve_process() = default;
	steve_process(const steve_process &) = delete;
	steve_process& operator=(const steve_process &) = delete;

	~steve_process() {
		/* unregister and free the watcher while the fd it watches is
		 * still open; members would otherwise be destroyed only after
		 * this body has already closed pid_fd */
		event.reset();
		if (pid_fd != -1)
			close(pid_fd);
	}
};

/* Global daemon state, passed as userdata to every CUSE and libevent
 * callback.  Every member has a default initializer so a steve_state is
 * never left with indeterminate counters or pointers. */
struct steve_state {
	/* log per-operation diagnostics to stderr */
	bool verbose{false};
	/* total job tokens (from --jobs, or the number of online CPUs) */
	size_t jobs{0};
	/* tokens currently available for handing out */
	size_t tokens{0};
	/* blocked read() requests, served front to back */
	std::deque<steve_read_waiter> read_waiters;
	/* poll() callers to notify once a token is available */
	std::deque<steve_poll_waiter> poll_waiters;
	/* processes that opened the device, keyed by the PID stored in fi->fh */
	std::unordered_map<uint64_t, steve_process> processes;
	/* event loop driving the CUSE fd, pidfds and SIGUSR1 */
	struct event_base *evb{nullptr};

	struct fuse_session *session{nullptr};
	/* to workaround lack of fuse_buf_free(), keep a global buffer */
	/* https://github.com/libfuse/libfuse/issues/1373 */
	struct fuse_buf buf{};
};

/* Take one token from the pool, charge it to `pid` and answer the read
 * request `req` with a single "+" byte. */
static void steve_give_token(steve_state *state, fuse_req_t req, uint64_t pid)
{
	--state->tokens;
	steve_process &holder = state->processes[pid];
	++holder.tokens_held;

	if (state->verbose) {
		std::print(stderr, "Giving job token to PID {}, {} left, {} tokens held by process\n",
				pid, state->tokens, holder.tokens_held);
	}
	fuse_reply_buf(req, "+", 1);
}

/* Serve queued blocking readers in order while tokens remain.  If tokens
 * are still left afterwards, notify every queued poller about POLLIN and
 * drop them; destroying a waiter releases its poll handle. */
static void steve_wake_waiters(steve_state *state)
{
	auto &readers = state->read_waiters;
	while (!readers.empty() && state->tokens != 0) {
		const steve_read_waiter &next = readers.front();
		steve_give_token(state, next.req, next.pid);
		readers.pop_front();
	}

	if (state->tokens == 0)
		return;

	for (const steve_poll_waiter &waiter : state->poll_waiters) {
		if (state->verbose)
			std::print(stderr, "Notifying PID {} about POLLIN, {} tokens left, {} tokens held by process\n",
					waiter.pid, state->tokens, state->processes[waiter.pid].tokens_held);
		fuse_lowlevel_notify_poll(waiter.poll_handle);
	}
	state->poll_waiters.clear();
}

/* libevent callback for a tracked process's pidfd, which fires when the
 * process exits.  Settles the process's token balance with the pool,
 * forgets the process (closing its pidfd and freeing this very event) and
 * wakes waiters that may now be served. */
static void steve_handle_pidfd(evutil_socket_t pid_fd, short, void *userdata) {
	steve_state *state = static_cast<steve_state *>(userdata);

	auto it = std::find_if(state->processes.begin(), state->processes.end(),
			[pid_fd](const auto &entry) { return entry.second.pid_fd == pid_fd; });
	if (it == state->processes.end()) {
		assert(0 && "pidfd triggered for unknown process");
		return;
	}

	ssize_t held = it->second.tokens_held;
	if (held >= 0) {
		state->tokens += held;
	} else {
		/* The process returned tokens it never acquired (see the
		 * pre-release workaround in steve_write()), so the pool holds
		 * that many extra tokens.  Take them back, but never let the
		 * unsigned counter wrap around, which would effectively hand
		 * out unlimited tokens. */
		size_t owed = static_cast<size_t>(-held);
		if (state->tokens >= owed) {
			state->tokens -= owed;
		} else {
			std::print(stderr, "Warning: process {} exited owing {} tokens, only {} could be reclaimed\n",
					it->first, owed, state->tokens);
			state->tokens = 0;
		}
	}

	if (state->verbose || held > 0) {
		std::print(stderr, "Process {} exited while holding {} tokens, "
				"{} tokens available after returning them\n",
				it->first, held, state->tokens);
	}
	state->processes.erase(it);
	steve_wake_waiters(state);
}

/* CUSE init callback: fill the token pool with all jobs and announce the
 * device on stderr. */
static void steve_init(void *userdata, struct fuse_conn_info *)
{
	auto *state = static_cast<steve_state *>(userdata);
	const size_t job_count = state->jobs;

	state->tokens = job_count;
	std::print(stderr, "steve running on /dev/steve for {} jobs\n", job_count);
}

/* CUSE destroy callback: drop all queued waiters and tracked processes.
 * Their destructors release poll handles, libevent events and pidfds. */
static void steve_destroy(void *userdata)
{
	auto &state = *static_cast<steve_state *>(userdata);

	state.read_waiters.clear();
	state.poll_waiters.clear();
	state.processes.clear();
}

/* Read /proc/PID/cmdline for diagnostics, with the NULs separating
 * arguments replaced by spaces.  The result is truncated to 127 bytes;
 * an empty string is returned if the file cannot be read or is empty. */
static std::string steve_read_cmdline(uint64_t pid)
{
	char cmdline[128] = {};

	std::string path = std::format("/proc/{}/cmdline", pid);
	FILE *cmdline_file = fopen(path.c_str(), "r");
	if (!cmdline_file)
		return {};

	size_t rd = fread(cmdline, 1, sizeof(cmdline) - 1, cmdline_file);
	fclose(cmdline_file);

	/* replace all NULs with spaces, except for the final one; the buffer
	 * is zero-filled and at most sizeof - 1 bytes were read, so it stays
	 * NUL-terminated even if the command line was truncated */
	for (size_t i = 0; i + 1 < rd; ++i) {
		if (cmdline[i] == 0)
			cmdline[i] = ' ';
	}
	return cmdline;
}

/* CUSE open callback.  Records the opener's PID in fi->fh (release does not
 * provide it) and, the first time a PID is seen, starts watching a pidfd for
 * it so its tokens are returned when it exits.  Fails with EIO if the pidfd
 * or its event cannot be set up. */
static void steve_open(fuse_req_t req, struct fuse_file_info *fi)
{
	const struct fuse_ctx *context = fuse_req_ctx(req);
	steve_state *state = static_cast<steve_state *>(fuse_req_userdata(req));

	/* pid is not available in release, so store it here */
	static_assert(sizeof(fi->fh) >= sizeof(context->pid));
	fi->fh = context->pid;

	if (state->verbose) {
		std::string cmdline = steve_read_cmdline(fi->fh);
		if (!cmdline.empty())
			std::print(stderr, "Device open by PID {} ({})\n", fi->fh, cmdline);
		else
			std::print(stderr, "Device open by PID {} (process name unknown)\n", fi->fh);
	}

	auto existing = state->processes.find(fi->fh);
	if (existing != state->processes.end()) {
		/* process already tracked: its pidfd watcher is in place */
		assert(existing->second.pid_fd != -1);
		assert(existing->second.event);
		fuse_reply_open(req, fi);
		return;
	}

	int pid_fd = static_cast<int>(syscall(SYS_pidfd_open, context->pid, 0));
	if (pid_fd == -1) {
		perror("unable to open pidfd, rejecting to open");
		fuse_reply_err(req, EIO);
		return;
	}

	std::unique_ptr<struct event, std::function<void(struct event*)>>
		pidfd_event{event_new(state->evb, pid_fd, EV_READ|EV_PERSIST, steve_handle_pidfd, state), event_free};
	if (!pidfd_event) {
		std::print(stderr, "unable to allocate event for pidfd\n");
		close(pid_fd);
		fuse_reply_err(req, EIO);
		return;
	}
	if (event_add(pidfd_event.get(), nullptr) == -1) {
		std::print(stderr, "failed to enable pidfd handler\n");
		/* free the event before closing the fd it refers to */
		pidfd_event.reset();
		close(pid_fd);
		fuse_reply_err(req, EIO);
		return;
	}

	steve_process &process = state->processes[fi->fh];
	process.pid_fd = pid_fd;
	process.event = std::move(pidfd_event);

	fuse_reply_open(req, fi);
}

/* CUSE release callback: only logs the close.  Tokens are tracked per PID
 * and reclaimed by steve_handle_pidfd() when the process exits, not here. */
static void steve_release(fuse_req_t req, struct fuse_file_info *fi)
{
	const auto *state = static_cast<const steve_state *>(fuse_req_userdata(req));

	if (state->verbose)
		std::print(stderr, "Device closed by PID {}\n", fi->fh);
	fuse_reply_err(req, 0);
}

/* Interrupt callback registered by steve_read() for a queued blocking read.
 * Removes the request from the wait queue and fails it with EINTR.  If the
 * request is no longer queued, it has already been answered, and replying
 * again would be a double reply, so nothing is done. */
static void steve_interrupt(fuse_req_t req, void *userdata)
{
	steve_state *state = static_cast<steve_state *>(userdata);

	auto it = std::find_if(state->read_waiters.begin(), state->read_waiters.end(),
			[req](const steve_read_waiter &waiter) { return waiter.req == req; });
	if (it == state->read_waiters.end())
		return;

	if (state->verbose)
		std::print(stderr, "Passed EINTR to PID {}\n", it->pid);
	state->read_waiters.erase(it);
	fuse_reply_err(req, EINTR);
}

/* CUSE read callback: each read of at least one byte at offset 0 hands out
 * a single job token ("+").  With no token free, O_NONBLOCK readers get
 * EAGAIN and blocking readers are queued for steve_wake_waiters(). */
static void steve_read(
	fuse_req_t req, size_t size, off_t off, struct fuse_file_info *fi)
{
	auto *state = static_cast<steve_state *>(fuse_req_userdata(req));
	const uint64_t pid = fi->fh;

	if (off != 0) {
		fuse_reply_err(req, EIO);
	} else if (size == 0) {
		fuse_reply_buf(req, "", 0);
	} else if (state->tokens != 0) {
		/* no need to support reading more than one token at a time */
		steve_give_token(state, req, pid);
	} else if (fi->flags & O_NONBLOCK) {
		fuse_reply_err(req, EAGAIN);
	} else {
		state->read_waiters.push_back(steve_read_waiter{req, pid});
		if (state->verbose)
			std::print(stderr, "No free job token for PID {}, waiting, {} tokens held by process\n",
					pid, state->processes[pid].tokens_held);
		/* registered only after queueing: if the request was already
		 * interrupted, libfuse invokes the handler immediately, and it
		 * must find the request in read_waiters */
		fuse_req_interrupt_func(req, steve_interrupt, state);
	}
}

/* CUSE write callback: each byte written at offset 0 returns one token to
 * the pool.  Returning more than the process holds is capped (EFBIG above
 * SSIZE_MAX, ENOSPC if nothing can be returned).  A single unacquired token
 * is tolerated as a workaround for nasm-rs.  Wakes waiters afterwards. */
static void steve_write(
	fuse_req_t req, const char *, size_t size, off_t off,
	struct fuse_file_info *fi)
{
	auto *state = static_cast<steve_state *>(fuse_req_userdata(req));
	const uint64_t pid = fi->fh;

	if (off != 0) {
		fuse_reply_err(req, EIO);
		return;
	}
	if (size > SSIZE_MAX) {
		std::print(stderr, "Warning: process {} tried to return more than SSIZE_MAX tokens\n",
				pid);
		fuse_reply_err(req, EFBIG);
		return;
	}

	steve_process &process = state->processes[pid];
	const ssize_t held = process.tokens_held;

	if (held == 0 && size == 1) {
		/* workaround for https://github.com/medek/nasm-rs/issues/44 */
		std::print(stderr, "Warning: process {} pre-released an unacquired token, please report a bug upstream\n",
				pid);
	} else if (held < static_cast<ssize_t>(size)) {
		std::print(stderr, "Warning: process {} tried to return {} tokens while holding only {} tokens, capping\n",
				pid, size, held);
		size = held < 0 ? 0 : static_cast<size_t>(held);
	}

	if (size == 0) {
		fuse_reply_err(req, ENOSPC);
		return;
	}

	state->tokens += size;
	process.tokens_held -= size;
	if (state->verbose)
		std::print(stderr, "PID {} returned {} tokens, {} available now, {} tokens held by process\n",
				pid, size, state->tokens, process.tokens_held);
	fuse_reply_write(req, size);

	/* Since we have jobs now, see if anyone's waiting */
	steve_wake_waiters(state);
}

/* CUSE poll callback.  POLLOUT is always possible; POLLIN only while a token
 * is available.  Without a free token, the poll handle (if the kernel asked
 * for a notification) is queued for steve_wake_waiters().  Ownership of a
 * non-NULL handle passes to this callback, so one that is not queued is
 * destroyed here. */
static void steve_poll(
	fuse_req_t req, struct fuse_file_info *fi, struct fuse_pollhandle *ph)
{
	steve_state *state = static_cast<steve_state *>(fuse_req_userdata(req));
	int events = fi->poll_events & (POLLIN | POLLOUT);

	if (state->verbose)
		std::print(stderr, "PID {} requested poll, {} tokens available, {} tokens held by process\n",
				fi->fh, state->tokens, state->processes[fi->fh].tokens_held);

	/* POLLOUT is always possible, POLLIN only if we have any tokens */
	if (state->tokens == 0) {
		events &= ~POLLIN;
		/* ph is NULL when no notification was requested; queueing it
		 * would later crash in fuse_lowlevel_notify_poll() */
		if (ph) {
			state->poll_waiters.emplace_back(ph, fi->fh);
			ph = nullptr;
		}
	}

	/* a handle that was not queued must still be released */
	if (ph)
		fuse_pollhandle_destroy(ph);

	fuse_reply_poll(req, events);
}

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wmissing-field-initializers"
/* CUSE callback table passed to cuse_lowlevel_new() in main().  Operations
 * not listed are left zero-initialized (NULL), hence the suppressed
 * -Wmissing-field-initializers warning.  C++ requires the designated
 * initializers to follow the declaration order of the struct. */
static const struct cuse_lowlevel_ops steve_ops = {
	.init = steve_init,
	.destroy = steve_destroy,
	.open = steve_open,
	.read = steve_read,
	.write = steve_write,
	.release = steve_release,
	.poll = steve_poll,
};
#pragma GCC diagnostic pop

/* SIGUSR1 handler: dump the token pool and per-process holdings to stderr. */
static void steve_handle_sigusr1(evutil_socket_t, short, void *userdata) {
	const auto &state = *static_cast<const steve_state *>(userdata);

	std::print(stderr, "steve: currently {} tokens available out of {}\n",
			state.tokens, state.jobs);
	for (const auto &[pid, process] : state.processes)
		std::print(stderr, "PID {} holds {} tokens\n", pid, process.tokens_held);
}

/* libevent callback for the CUSE fd: receive one request and dispatch it.
 * Mirrors libfuse's fuse_session_loop(): EINTR/EAGAIN are retried on the
 * next wakeup, while any other result <= 0 (device closed, read error) or
 * an exited session ends the event loop so main() can clean up, rather
 * than keep polling a dead fd. */
static void steve_handle_cuse(evutil_socket_t, short, void *userdata) {
	steve_state *state = static_cast<steve_state *>(userdata);

	int res = fuse_session_receive_buf(state->session, &state->buf);
	if (res > 0)
		fuse_session_process_buf(state->session, &state->buf);
	else if (res == -EINTR || res == -EAGAIN)
		return;

	if (res <= 0 || fuse_session_exited(state->session))
		event_base_loopbreak(state->evb);
}

/* Help text for --help (stdout) and for invalid options (stderr).  It is
 * used as a std::print format string, so {} is replaced with argv[0]. */
static constexpr char steve_usage[] =
"usage: {} [options]\n"
"\n"
"options:\n"
"    --help, -h             print this help message\n"
"    --version, -V          print version\n"
"    --jobs=JOBS, -j JOBS   jobs to use (default: nproc)\n"
"    --verbose, -v          enable verbose logging\n"
"    --debug, -d            enable FUSE debug output\n";

/* Long options for getopt_long(); their short forms are "hVj:vd".  The
 * table ends with an all-zero terminator entry. */
static const struct option steve_opts[] = {
	{ "help",    no_argument,       nullptr, 'h' },
	{ "version", no_argument,       nullptr, 'V' },
	{ "jobs",    required_argument, nullptr, 'j' },
	{ "verbose", no_argument,       nullptr, 'v' },
	{ "debug",   no_argument,       nullptr, 'd' },
	{ nullptr,   0,                 nullptr, 0 },
};

/* Closes the wrapped file descriptor when it goes out of scope.  Set fd to
 * -1 to give up ownership (nothing is closed then).  Non-copyable, since a
 * copy would close the same descriptor twice. */
struct fd_guard {
	int fd;

	explicit fd_guard(int new_fd) : fd(new_fd) {}
	fd_guard(const fd_guard &) = delete;
	fd_guard &operator=(const fd_guard &) = delete;

	~fd_guard() {
		if (fd != -1)
			close(fd);
	}
};

/* Parse options, set up libevent and the CUSE session for /dev/steve, then
 * run the event loop until the session ends. */
int main(int argc, char **argv)
{
	steve_state state{};

	int opt;
	bool debug = false;
	while ((opt = getopt_long(argc, argv, "hVj:vd", steve_opts, nullptr)) != -1) {
		switch (opt) {
			case 'h':
				std::print(steve_usage, argv[0]);
				return 0;
			case 'V':
				std::print("steve {}\n", STEVE_VERSION);
				return 0;
			case 'j':
				{
					char *endptr;
					errno = 0;
					long jobs_arg = strtol(optarg, &endptr, 10);
					/* endptr == optarg rejects an empty argument, which
					 * strtol() would otherwise silently parse as 0 */
					if (endptr == optarg || *endptr || errno == ERANGE ||
							jobs_arg < 0 || jobs_arg > INT_MAX) {
						std::print(stderr, "invalid job number: {}\n", optarg);
						return 1;
					}
					state.jobs = jobs_arg;
				}
				break;
			case 'v':
				state.verbose = true;
				break;
			case 'd':
				debug = true;
				break;
			default:
				std::print(stderr, steve_usage, argv[0]);
				return 1;
		}
	}

	/* no --jobs (or --jobs=0): default to the number of online CPUs */
	if (state.jobs == 0) {
		long nproc = sysconf(_SC_NPROCESSORS_ONLN);
		if (nproc < 1) {
			/* sysconf() returns -1 on failure, which must not become a
			 * huge unsigned job count */
			std::print(stderr, "unable to determine the number of CPUs, please specify --jobs\n");
			return 1;
		}
		state.jobs = nproc;
	}

	std::unique_ptr<struct event_base, std::function<void(struct event_base*)>>
		evb{event_base_new(), event_base_free};
	if (!evb) {
		std::print(stderr, "failed to initialize libevent\n");
		return 1;
	}
	state.evb = evb.get();

	int cuse_fd = open("/dev/cuse", O_RDWR);
	if (cuse_fd == -1) {
		perror("unable to open /dev/cuse");
		return 1;
	}
	fd_guard cuse_fd_guard{cuse_fd};

	const char *dev_name = "DEVNAME=steve";
	const char *dev_info_argv[] = { dev_name };
	struct cuse_info ci{};
	ci.dev_info_argc = 1;
	ci.dev_info_argv = dev_info_argv;

	struct fuse_args args = FUSE_ARGS_INIT(0, nullptr);
	std::unique_ptr<struct fuse_args, std::function<void(struct fuse_args*)>>
		args_ptr{&args, fuse_opt_free_args};
	if (fuse_opt_add_arg(args_ptr.get(), argv[0]) == -1 ||
			(debug && fuse_opt_add_arg(args_ptr.get(), "-d") == -1)) {
		std::print(stderr, "failed to build FUSE arguments\n");
		return 1;
	}

	std::unique_ptr<struct fuse_session, std::function<void(struct fuse_session*)>> session{
		cuse_lowlevel_new(args_ptr.get(), &ci, &steve_ops, &state), fuse_session_destroy};
	if (!session) {
		std::print(stderr, "failed to initialize FUSE\n");
		return 1;
	}
	state.session = session.get();

	std::unique_ptr<struct event, std::function<void(struct event*)>>
		cuse_event{event_new(evb.get(), cuse_fd, EV_READ|EV_PERSIST, steve_handle_cuse, &state), event_free};
	if (!cuse_event) {
		std::print(stderr, "failed to initialize CUSE handler\n");
		return 1;
	}
	if (event_add(cuse_event.get(), nullptr) == -1) {
		std::print(stderr, "failed to enable CUSE handler\n");
		return 1;
	}

	std::unique_ptr<struct event, std::function<void(struct event*)>>
		sigusr1_event{evsignal_new(evb.get(), SIGUSR1, steve_handle_sigusr1, &state), event_free};
	if (!sigusr1_event) {
		std::print(stderr, "failed to initialize SIGUSR1 handler\n");
		return 1;
	}
	if (event_add(sigusr1_event.get(), nullptr) == -1) {
		std::print(stderr, "failed to enable SIGUSR1 handler\n");
		return 1;
	}

	std::string mountpoint = std::format("/dev/fd/{}", cuse_fd);
	if (fuse_session_mount(session.get(), mountpoint.c_str()) == -1) {
		std::print(stderr, "failed to mount the filesystem\n");
		return 1;
	}
	/* libfuse adopts the fd of a /dev/fd/N mountpoint and closes it in
	 * fuse_session_destroy(); stop the guard from closing it a second time */
	cuse_fd_guard.fd = -1;

	event_base_dispatch(evb.get());
	fuse_session_unmount(session.get());

	/* tracked processes own libevent events, which must be freed while the
	 * event base still exists (state is destroyed only after evb) */
	state.processes.clear();
	return 0;
}
