//===-- LSPClient.cpp - Helper for ClangdLSPServer tests ------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "LSPClient.h" #include "Protocol.h" #include "TestFS.h" #include "Transport.h" #include "support/Logger.h" #include "support/Threading.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Error.h" #include "llvm/Support/JSON.h" #include "llvm/Support/Path.h" #include "llvm/Support/raw_ostream.h" #include "gtest/gtest.h" #include #include #include #include #include #include #include #include #include #include #include #include namespace clang { namespace clangd { llvm::Expected clang::clangd::LSPClient::CallResult::take() { std::unique_lock Lock(Mu); static constexpr size_t TimeoutSecs = 60; if (!clangd::wait(Lock, CV, timeoutSeconds(TimeoutSecs), [this] { return Value.has_value(); })) { ADD_FAILURE() << "No result from call after " << TimeoutSecs << " seconds!"; return llvm::json::Value(nullptr); } auto Res = std::move(*Value); Value.reset(); return Res; } llvm::json::Value LSPClient::CallResult::takeValue() { auto ExpValue = take(); if (!ExpValue) { ADD_FAILURE() << "takeValue(): " << llvm::toString(ExpValue.takeError()); return llvm::json::Value(nullptr); } return std::move(*ExpValue); } void LSPClient::CallResult::set(llvm::Expected V) { std::lock_guard Lock(Mu); if (Value) { ADD_FAILURE() << "Multiple replies"; llvm::consumeError(V.takeError()); return; } Value = std::move(V); CV.notify_all(); } LSPClient::CallResult::~CallResult() { if (Value && !*Value) { ADD_FAILURE() << llvm::toString(Value->takeError()); } } static void logBody(llvm::StringRef Method, llvm::json::Value V, bool Send) { // We invert <<< and >>> as the combined log is from the server's viewpoint. vlog("{0} {1}: {2:2}", Send ? "<<<" : ">>>", Method, V); } class LSPClient::TransportImpl : public Transport { public: std::pair addCallSlot() { std::lock_guard Lock(Mu); unsigned ID = CallResults.size(); CallResults.emplace_back(); return {ID, &CallResults.back()}; } // A null action causes the transport to shut down. void enqueue(std::function Action) { std::lock_guard Lock(Mu); Actions.push(std::move(Action)); CV.notify_all(); } std::vector takeNotifications(llvm::StringRef Method) { std::vector Result; { std::lock_guard Lock(Mu); std::swap(Result, Notifications[Method]); } return Result; } private: void reply(llvm::json::Value ID, llvm::Expected V) override { if (V) // Nothing additional to log for error. logBody("reply", *V, /*Send=*/false); std::lock_guard Lock(Mu); if (auto I = ID.getAsInteger()) { if (*I >= 0 && *I < static_cast(CallResults.size())) { CallResults[*I].set(std::move(V)); return; } } ADD_FAILURE() << "Invalid reply to ID " << ID; llvm::consumeError(std::move(V).takeError()); } void notify(llvm::StringRef Method, llvm::json::Value V) override { logBody(Method, V, /*Send=*/false); std::lock_guard Lock(Mu); Notifications[Method].push_back(std::move(V)); } void call(llvm::StringRef Method, llvm::json::Value Params, llvm::json::Value ID) override { logBody(Method, Params, /*Send=*/false); ADD_FAILURE() << "Unexpected server->client call " << Method; } llvm::Error loop(MessageHandler &H) override { std::unique_lock Lock(Mu); while (true) { CV.wait(Lock, [&] { return !Actions.empty(); }); if (!Actions.front()) // Stop! return llvm::Error::success(); auto Action = std::move(Actions.front()); Actions.pop(); Lock.unlock(); Action(H); Lock.lock(); } } std::mutex Mu; std::deque CallResults; std::queue> Actions; std::condition_variable CV; llvm::StringMap> Notifications; }; LSPClient::LSPClient() : T(std::make_unique()) {} LSPClient::~LSPClient() = default; LSPClient::CallResult &LSPClient::call(llvm::StringRef Method, llvm::json::Value Params) { auto Slot = T->addCallSlot(); T->enqueue([ID(Slot.first), Method(Method.str()), Params(std::move(Params))](Transport::MessageHandler &H) { logBody(Method, Params, /*Send=*/true); H.onCall(Method, std::move(Params), ID); }); return *Slot.second; } void LSPClient::notify(llvm::StringRef Method, llvm::json::Value Params) { T->enqueue([Method(Method.str()), Params(std::move(Params))](Transport::MessageHandler &H) { logBody(Method, Params, /*Send=*/true); H.onNotify(Method, std::move(Params)); }); } std::vector LSPClient::takeNotifications(llvm::StringRef Method) { return T->takeNotifications(Method); } void LSPClient::stop() { T->enqueue(nullptr); } Transport &LSPClient::transport() { return *T; } using Obj = llvm::json::Object; llvm::json::Value LSPClient::uri(llvm::StringRef Path) { std::string Storage; if (!llvm::sys::path::is_absolute(Path)) Path = Storage = testPath(Path); return toJSON(URIForFile::canonicalize(Path, Path)); } llvm::json::Value LSPClient::documentID(llvm::StringRef Path) { return Obj{{"uri", uri(Path)}}; } void LSPClient::didOpen(llvm::StringRef Path, llvm::StringRef Content) { notify( "textDocument/didOpen", Obj{{"textDocument", Obj{{"uri", uri(Path)}, {"text", Content}, {"languageId", "cpp"}}}}); } void LSPClient::didChange(llvm::StringRef Path, llvm::StringRef Content) { notify("textDocument/didChange", Obj{{"textDocument", documentID(Path)}, {"contentChanges", llvm::json::Array{Obj{{"text", Content}}}}}); } void LSPClient::didClose(llvm::StringRef Path) { notify("textDocument/didClose", Obj{{"textDocument", documentID(Path)}}); } void LSPClient::sync() { call("sync", nullptr).takeValue(); } std::optional> LSPClient::diagnostics(llvm::StringRef Path) { sync(); auto Notifications = takeNotifications("textDocument/publishDiagnostics"); for (const auto &Notification : llvm::reverse(Notifications)) { if (const auto *PubDiagsParams = Notification.getAsObject()) { auto U = PubDiagsParams->getString("uri"); auto *D = PubDiagsParams->getArray("diagnostics"); if (!U || !D) { ADD_FAILURE() << "Bad PublishDiagnosticsParams: " << PubDiagsParams; continue; } if (*U == uri(Path)) return std::vector(D->begin(), D->end()); } } return {}; } } // namespace clangd } // namespace clang