From 77c1f647cb759d320121cf6eef7a72e842abfef4 Mon Sep 17 00:00:00 2001 From: Weiyi Wang Date: Wed, 12 Sep 2018 12:22:48 -0400 Subject: [PATCH 1/2] web_service: stop using std::future + callback style async --- src/citra/citra.cpp | 3 + src/citra_qt/configuration/configure_web.cpp | 13 +- src/citra_qt/configuration/configure_web.h | 7 +- src/citra_qt/main.cpp | 6 +- src/citra_qt/multiplayer/lobby.cpp | 8 +- src/citra_qt/multiplayer/lobby.h | 8 +- src/common/CMakeLists.txt | 2 + src/common/announce_multiplayer_room.h | 21 +- src/common/detached_tasks.cpp | 41 ++ src/common/detached_tasks.h | 39 ++ src/core/announce_multiplayer_session.cpp | 20 +- src/core/announce_multiplayer_session.h | 2 +- src/core/telemetry_session.cpp | 16 +- src/core/telemetry_session.h | 5 +- src/dedicated_room/citra-room.cpp | 3 + src/web_service/announce_room_json.cpp | 26 +- src/web_service/announce_room_json.h | 13 +- src/web_service/telemetry_json.cpp | 8 +- src/web_service/telemetry_json.h | 9 +- src/web_service/verify_login.cpp | 30 +- src/web_service/verify_login.h | 8 +- src/web_service/web_backend.cpp | 374 +++++-------------- src/web_service/web_backend.h | 124 +++--- 23 files changed, 329 insertions(+), 457 deletions(-) create mode 100644 src/common/detached_tasks.cpp create mode 100644 src/common/detached_tasks.h diff --git a/src/citra/citra.cpp b/src/citra/citra.cpp index 29ce5f898..1b13d8589 100644 --- a/src/citra/citra.cpp +++ b/src/citra/citra.cpp @@ -26,6 +26,7 @@ #include "citra/config.h" #include "citra/emu_window/emu_window_sdl2.h" #include "common/common_paths.h" +#include "common/detached_tasks.h" #include "common/file_util.h" #include "common/logging/backend.h" #include "common/logging/filter.h" @@ -117,6 +118,7 @@ static void OnMessageReceived(const Network::ChatEntry& msg) { /// Application entry point int main(int argc, char** argv) { + Common::DetachedTasks detached_tasks; Config config; int option_index = 0; bool use_gdbstub = Settings::values.use_gdbstub; @@ -339,5 +341,6 @@ int main(int argc, char** argv) { Core::Movie::GetInstance().Shutdown(); + detached_tasks.WaitForAllTasks(); return 0; } diff --git a/src/citra_qt/configuration/configure_web.cpp b/src/citra_qt/configuration/configure_web.cpp index ff66e45d1..cf1daf76b 100644 --- a/src/citra_qt/configuration/configure_web.cpp +++ b/src/citra_qt/configuration/configure_web.cpp @@ -4,6 +4,7 @@ #include #include +#include #include "citra_qt/configuration/configure_web.h" #include "citra_qt/ui_settings.h" #include "core/settings.h" @@ -16,7 +17,7 @@ ConfigureWeb::ConfigureWeb(QWidget* parent) connect(ui->button_regenerate_telemetry_id, &QPushButton::clicked, this, &ConfigureWeb::RefreshTelemetryID); connect(ui->button_verify_login, &QPushButton::clicked, this, &ConfigureWeb::VerifyLogin); - connect(this, &ConfigureWeb::LoginVerified, this, &ConfigureWeb::OnLoginVerified); + connect(&verify_watcher, &QFutureWatcher::finished, this, &ConfigureWeb::OnLoginVerified); #ifndef USE_DISCORD_PRESENCE ui->discord_group->setVisible(false); @@ -89,17 +90,19 @@ void ConfigureWeb::OnLoginChanged() { } void ConfigureWeb::VerifyLogin() { - verified = - Core::VerifyLogin(ui->edit_username->text().toStdString(), - ui->edit_token->text().toStdString(), [&]() { emit LoginVerified(); }); ui->button_verify_login->setDisabled(true); ui->button_verify_login->setText(tr("Verifying")); + verify_watcher.setFuture( + QtConcurrent::run([this, username = ui->edit_username->text().toStdString(), + token = ui->edit_token->text().toStdString()]() { + return Core::VerifyLogin(username, token); + })); } void ConfigureWeb::OnLoginVerified() { ui->button_verify_login->setEnabled(true); ui->button_verify_login->setText(tr("Verify")); - if (verified.get()) { + if (verify_watcher.result()) { user_verified = true; ui->label_username_verified->setPixmap(QIcon::fromTheme("checked").pixmap(16)); ui->label_token_verified->setPixmap(QIcon::fromTheme("checked").pixmap(16)); diff --git a/src/citra_qt/configuration/configure_web.h b/src/citra_qt/configuration/configure_web.h index b8e71ffdd..7741ab95d 100644 --- a/src/citra_qt/configuration/configure_web.h +++ b/src/citra_qt/configuration/configure_web.h @@ -4,8 +4,8 @@ #pragma once -#include #include +#include #include namespace Ui { @@ -28,14 +28,11 @@ public slots: void VerifyLogin(); void OnLoginVerified(); -signals: - void LoginVerified(); - private: void setConfiguration(); bool user_verified = true; - std::future verified; + QFutureWatcher verify_watcher; std::unique_ptr ui; }; diff --git a/src/citra_qt/main.cpp b/src/citra_qt/main.cpp index 3f0388d0a..f7fdc16c7 100644 --- a/src/citra_qt/main.cpp +++ b/src/citra_qt/main.cpp @@ -43,6 +43,7 @@ #include "citra_qt/updater/updater.h" #include "citra_qt/util/clickable_label.h" #include "common/common_paths.h" +#include "common/detached_tasks.h" #include "common/logging/backend.h" #include "common/logging/filter.h" #include "common/logging/log.h" @@ -1660,6 +1661,7 @@ void GMainWindow::SetDiscordEnabled(bool state) { #endif int main(int argc, char* argv[]) { + Common::DetachedTasks detached_tasks; MicroProfileOnThreadCreate("Frontend"); SCOPE_EXIT({ MicroProfileShutdown(); }); @@ -1685,5 +1687,7 @@ int main(int argc, char* argv[]) { Frontend::RegisterSoftwareKeyboard(std::make_shared(main_window)); main_window.show(); - return app.exec(); + int result = app.exec(); + detached_tasks.WaitForAllTasks(); + return result; } diff --git a/src/citra_qt/multiplayer/lobby.cpp b/src/citra_qt/multiplayer/lobby.cpp index c40c00375..8c262bc5f 100644 --- a/src/citra_qt/multiplayer/lobby.cpp +++ b/src/citra_qt/multiplayer/lobby.cpp @@ -63,7 +63,8 @@ Lobby::Lobby(QWidget* parent, QStandardItemModel* list, connect(ui->room_list, &QTreeView::clicked, this, &Lobby::OnExpandRoom); // Actions - connect(this, &Lobby::LobbyRefreshed, this, &Lobby::OnRefreshLobby); + connect(&room_list_watcher, &QFutureWatcher::finished, this, + &Lobby::OnRefreshLobby); // manually start a refresh when the window is opening // TODO(jroweboy): if this refresh is slow for people with bad internet, then don't do it as @@ -149,16 +150,17 @@ void Lobby::ResetModel() { void Lobby::RefreshLobby() { if (auto session = announce_multiplayer_session.lock()) { ResetModel(); - room_list_future = session->GetRoomList([&]() { emit LobbyRefreshed(); }); ui->refresh_list->setEnabled(false); ui->refresh_list->setText(tr("Refreshing")); + room_list_watcher.setFuture( + QtConcurrent::run([session]() { return session->GetRoomList(); })); } else { // TODO(jroweboy): Display an error box about announce couldn't be started } } void Lobby::OnRefreshLobby() { - AnnounceMultiplayerRoom::RoomList new_room_list = room_list_future.get(); + AnnounceMultiplayerRoom::RoomList new_room_list = room_list_watcher.result(); for (auto room : new_room_list) { // find the icon for the game if this person owns that game. QPixmap smdh_icon; diff --git a/src/citra_qt/multiplayer/lobby.h b/src/citra_qt/multiplayer/lobby.h index e0f50a26f..f41bd0ceb 100644 --- a/src/citra_qt/multiplayer/lobby.h +++ b/src/citra_qt/multiplayer/lobby.h @@ -4,7 +4,6 @@ #pragma once -#include #include #include #include @@ -61,11 +60,6 @@ private slots: void OnJoinRoom(const QModelIndex&); signals: - /** - * Signalled when the latest lobby data is retrieved. - */ - void LobbyRefreshed(); - void StateChanged(const Network::RoomMember::State&); private: @@ -84,7 +78,7 @@ private: QStandardItemModel* game_list; LobbyFilterProxyModel* proxy; - std::future room_list_future; + QFutureWatcher room_list_watcher; std::weak_ptr announce_multiplayer_session; std::unique_ptr ui; QFutureWatcher* watcher; diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt index dfc0ea0c1..956315efb 100644 --- a/src/common/CMakeLists.txt +++ b/src/common/CMakeLists.txt @@ -42,6 +42,8 @@ add_library(common STATIC alignment.h announce_multiplayer_room.h assert.h + detached_tasks.cpp + detached_tasks.h bit_field.h bit_set.h chunk_file.h diff --git a/src/common/announce_multiplayer_room.h b/src/common/announce_multiplayer_room.h index 18b4696ba..811f78d8c 100644 --- a/src/common/announce_multiplayer_room.h +++ b/src/common/announce_multiplayer_room.h @@ -6,7 +6,6 @@ #include #include -#include #include #include #include "common/common_types.h" @@ -90,7 +89,7 @@ public: * Send the data to the announce service * @result The result of the announce attempt */ - virtual std::future Announce() = 0; + virtual Common::WebResult Announce() = 0; /** * Empties the stored players @@ -99,11 +98,9 @@ public: /** * Get the room information from the announce service - * @param func a function that gets exectued when the get finished. - * Can be used as a callback * @result A list of all rooms the announce service has */ - virtual std::future GetRoomList(std::function func) = 0; + virtual RoomList GetRoomList() = 0; /** * Sends a delete message to the announce service @@ -124,18 +121,12 @@ public: const u64 /*preferred_game_id*/) override {} void AddPlayer(const std::string& /*nickname*/, const MacAddress& /*mac_address*/, const u64 /*game_id*/, const std::string& /*game_name*/) override {} - std::future Announce() override { - return std::async(std::launch::deferred, []() { - return Common::WebResult{Common::WebResult::Code::NoWebservice, - "WebService is missing"}; - }); + Common::WebResult Announce() override { + return Common::WebResult{Common::WebResult::Code::NoWebservice, "WebService is missing"}; } void ClearPlayers() override {} - std::future GetRoomList(std::function func) override { - return std::async(std::launch::deferred, [func]() { - func(); - return RoomList{}; - }); + RoomList GetRoomList() override { + return RoomList{}; } void Delete() override {} diff --git a/src/common/detached_tasks.cpp b/src/common/detached_tasks.cpp new file mode 100644 index 000000000..a347d9e02 --- /dev/null +++ b/src/common/detached_tasks.cpp @@ -0,0 +1,41 @@ +// Copyright 2018 Citra Emulator Project +// Licensed under GPLv2 or any later version +// Refer to the license.txt file included. + +#include +#include "common/assert.h" +#include "common/detached_tasks.h" + +namespace Common { + +DetachedTasks* DetachedTasks::instance = nullptr; + +DetachedTasks::DetachedTasks() { + ASSERT(instance == nullptr); + instance = this; +} + +void DetachedTasks::WaitForAllTasks() { + std::unique_lock lock(mutex); + cv.wait(lock, [this]() { return count == 0; }); +} + +DetachedTasks::~DetachedTasks() { + std::unique_lock lock(mutex); + ASSERT(count == 0); + instance = nullptr; +} + +void DetachedTasks::AddTask(std::function task) { + std::unique_lock lock(instance->mutex); + ++instance->count; + std::thread([task{std::move(task)}]() { + task(); + std::unique_lock lock(instance->mutex); + --instance->count; + std::notify_all_at_thread_exit(instance->cv, std::move(lock)); + }) + .detach(); +} + +} // namespace Common diff --git a/src/common/detached_tasks.h b/src/common/detached_tasks.h new file mode 100644 index 000000000..eae27788d --- /dev/null +++ b/src/common/detached_tasks.h @@ -0,0 +1,39 @@ +// Copyright 2018 Citra Emulator Project +// Licensed under GPLv2 or any later version +// Refer to the license.txt file included. + +#pragma once +#include +#include + +namespace Common { + +/** + * A background manager which ensures that all detached task is finished before program exits. + * + * Some tasks, telemetry submission for example, prefer executing asynchronously and don't care + * about the result. These tasks are suitable for std::thread::detach(). However, this is unsafe if + * the task is launched just before the program exits (which is a common case for telemetry), so we + * need to block on these tasks on program exit. + * + * To make detached task safe, a single DetachedTasks object should be placed in the main(), and + * call WaitForAllTasks() after all program execution but before global/static variable destruction. + * Any potentially unsafe detached task should be executed via DetachedTasks::AddTask. + */ +class DetachedTasks { +public: + DetachedTasks(); + ~DetachedTasks(); + void WaitForAllTasks(); + + static void AddTask(std::function task); + +private: + static DetachedTasks* instance; + + std::condition_variable cv; + std::mutex mutex; + int count = 0; +}; + +} // namespace Common diff --git a/src/core/announce_multiplayer_session.cpp b/src/core/announce_multiplayer_session.cpp index efbe6cc5c..69dbf6fce 100644 --- a/src/core/announce_multiplayer_session.cpp +++ b/src/core/announce_multiplayer_session.cpp @@ -21,7 +21,7 @@ static constexpr std::chrono::seconds announce_time_interval(15); AnnounceMultiplayerSession::AnnounceMultiplayerSession() { #ifdef ENABLE_WEB_SERVICE - backend = std::make_unique(Settings::values.web_api_url + "/lobby", + backend = std::make_unique(Settings::values.web_api_url, Settings::values.citra_username, Settings::values.citra_token); #else @@ -87,22 +87,18 @@ void AnnounceMultiplayerSession::AnnounceMultiplayerLoop() { backend->AddPlayer(member.nickname, member.mac_address, member.game_info.id, member.game_info.name); } - future = backend->Announce(); - if (future.valid()) { - Common::WebResult result = future.get(); - if (result.result_code != Common::WebResult::Code::Success) { - std::lock_guard lock(callback_mutex); - for (auto callback : error_callbacks) { - (*callback)(result); - } + Common::WebResult result = backend->Announce(); + if (result.result_code != Common::WebResult::Code::Success) { + std::lock_guard lock(callback_mutex); + for (auto callback : error_callbacks) { + (*callback)(result); } } } } -std::future AnnounceMultiplayerSession::GetRoomList( - std::function func) { - return backend->GetRoomList(func); +AnnounceMultiplayerRoom::RoomList AnnounceMultiplayerSession::GetRoomList() { + return backend->GetRoomList(); } } // namespace Core diff --git a/src/core/announce_multiplayer_session.h b/src/core/announce_multiplayer_session.h index 0ea357e3a..b9ba4c48a 100644 --- a/src/core/announce_multiplayer_session.h +++ b/src/core/announce_multiplayer_session.h @@ -54,7 +54,7 @@ public: * @param func A function that gets executed when the async get finished, e.g. a signal * @return a list of rooms received from the web service */ - std::future GetRoomList(std::function func); + AnnounceMultiplayerRoom::RoomList GetRoomList(); private: Common::Event shutdown_event; diff --git a/src/core/telemetry_session.cpp b/src/core/telemetry_session.cpp index 04ab300f8..6519b200c 100644 --- a/src/core/telemetry_session.cpp +++ b/src/core/telemetry_session.cpp @@ -80,24 +80,20 @@ u64 RegenerateTelemetryId() { return new_telemetry_id; } -std::future VerifyLogin(std::string username, std::string token, std::function func) { +bool VerifyLogin(std::string username, std::string token) { #ifdef ENABLE_WEB_SERVICE - return WebService::VerifyLogin(username, token, Settings::values.web_api_url + "/profile", - func); + return WebService::VerifyLogin(Settings::values.web_api_url, username, token); #else - return std::async(std::launch::async, [func{std::move(func)}]() { - func(); - return false; - }); + return false; #endif } TelemetrySession::TelemetrySession() { #ifdef ENABLE_WEB_SERVICE if (Settings::values.enable_telemetry) { - backend = std::make_unique( - Settings::values.web_api_url + "/telemetry", Settings::values.citra_username, - Settings::values.citra_token); + backend = std::make_unique(Settings::values.web_api_url, + Settings::values.citra_username, + Settings::values.citra_token); } else { backend = std::make_unique(); } diff --git a/src/core/telemetry_session.h b/src/core/telemetry_session.h index 550c6ea2d..127b6fe5e 100644 --- a/src/core/telemetry_session.h +++ b/src/core/telemetry_session.h @@ -4,7 +4,6 @@ #pragma once -#include #include #include "common/telemetry.h" @@ -31,6 +30,8 @@ public: field_collection.AddField(type, name, std::move(value)); } + static void FinalizeAsyncJob(); + private: Telemetry::FieldCollection field_collection; ///< Tracks all added fields for the session std::unique_ptr backend; ///< Backend interface that logs fields @@ -55,6 +56,6 @@ u64 RegenerateTelemetryId(); * @param func A function that gets exectued when the verification is finished * @returns Future with bool indicating whether the verification succeeded */ -std::future VerifyLogin(std::string username, std::string token, std::function func); +bool VerifyLogin(std::string username, std::string token); } // namespace Core diff --git a/src/dedicated_room/citra-room.cpp b/src/dedicated_room/citra-room.cpp index 9d4dc2a92..7765336a2 100644 --- a/src/dedicated_room/citra-room.cpp +++ b/src/dedicated_room/citra-room.cpp @@ -25,6 +25,7 @@ #endif #include "common/common_types.h" +#include "common/detached_tasks.h" #include "common/scm_rev.h" #include "core/announce_multiplayer_session.h" #include "core/core.h" @@ -54,6 +55,7 @@ static void PrintVersion() { /// Application entry point int main(int argc, char** argv) { + Common::DetachedTasks detached_tasks; int option_index = 0; char* endarg; @@ -204,5 +206,6 @@ int main(int argc, char** argv) { room->Destroy(); } Network::Shutdown(); + detached_tasks.WaitForAllTasks(); return 0; } diff --git a/src/web_service/announce_room_json.cpp b/src/web_service/announce_room_json.cpp index b53e084df..6a6512491 100644 --- a/src/web_service/announce_room_json.cpp +++ b/src/web_service/announce_room_json.cpp @@ -3,6 +3,7 @@ // Refer to the license.txt file included. #include +#include "common/detached_tasks.h" #include "common/logging/log.h" #include "web_service/announce_room_json.h" #include "web_service/json.h" @@ -82,30 +83,31 @@ void RoomJson::AddPlayer(const std::string& nickname, room.members.push_back(member); } -std::future RoomJson::Announce() { +Common::WebResult RoomJson::Announce() { nlohmann::json json = room; - return PostJson(endpoint_url, json.dump(), false); + return client.PostJson("/lobby", json.dump(), false); } void RoomJson::ClearPlayers() { room.members.clear(); } -std::future RoomJson::GetRoomList(std::function func) { - auto DeSerialize = [func](const std::string& reply) -> AnnounceMultiplayerRoom::RoomList { - nlohmann::json json = nlohmann::json::parse(reply); - AnnounceMultiplayerRoom::RoomList room_list = - json.at("rooms").get(); - func(); - return room_list; - }; - return GetJson(DeSerialize, endpoint_url, true); +AnnounceMultiplayerRoom::RoomList RoomJson::GetRoomList() { + auto reply = client.GetJson("/lobby", true).returned_data; + if (reply.empty()) { + return {}; + } + return nlohmann::json::parse(reply).at("rooms").get(); } void RoomJson::Delete() { nlohmann::json json; json["id"] = room.UID; - DeleteJson(endpoint_url, json.dump()); + Common::DetachedTasks::AddTask( + [host{this->host}, username{this->username}, token{this->token}, content{json.dump()}]() { + // create a new client here because the this->client might be destroyed. + Client{host, username, token}.DeleteJson("/lobby", content, false); + }); } } // namespace WebService diff --git a/src/web_service/announce_room_json.h b/src/web_service/announce_room_json.h index 85550e838..735605213 100644 --- a/src/web_service/announce_room_json.h +++ b/src/web_service/announce_room_json.h @@ -5,9 +5,9 @@ #pragma once #include -#include #include #include "common/announce_multiplayer_room.h" +#include "web_service/web_backend.h" namespace WebService { @@ -17,8 +17,8 @@ namespace WebService { */ class RoomJson : public AnnounceMultiplayerRoom::Backend { public: - RoomJson(const std::string& endpoint_url, const std::string& username, const std::string& token) - : endpoint_url(endpoint_url), username(username), token(token) {} + RoomJson(const std::string& host, const std::string& username, const std::string& token) + : client(host, username, token), host(host), username(username), token(token) {} ~RoomJson() = default; void SetRoomInformation(const std::string& uid, const std::string& name, const u16 port, const u32 max_player, const u32 net_version, const bool has_password, @@ -27,14 +27,15 @@ public: void AddPlayer(const std::string& nickname, const AnnounceMultiplayerRoom::MacAddress& mac_address, const u64 game_id, const std::string& game_name) override; - std::future Announce() override; + Common::WebResult Announce() override; void ClearPlayers() override; - std::future GetRoomList(std::function func) override; + AnnounceMultiplayerRoom::RoomList GetRoomList() override; void Delete() override; private: AnnounceMultiplayerRoom::Room room; - std::string endpoint_url; + Client client; + std::string host; std::string username; std::string token; }; diff --git a/src/web_service/telemetry_json.cpp b/src/web_service/telemetry_json.cpp index 0d7ff1c21..a0b7f9c4e 100644 --- a/src/web_service/telemetry_json.cpp +++ b/src/web_service/telemetry_json.cpp @@ -2,7 +2,9 @@ // Licensed under GPLv2 or any later version // Refer to the license.txt file included. +#include #include "common/assert.h" +#include "common/detached_tasks.h" #include "web_service/telemetry_json.h" #include "web_service/web_backend.h" @@ -81,8 +83,12 @@ void TelemetryJson::Complete() { SerializeSection(Telemetry::FieldType::UserConfig, "UserConfig"); SerializeSection(Telemetry::FieldType::UserSystem, "UserSystem"); + auto content = TopSection().dump(); // Send the telemetry async but don't handle the errors since they were written to the log - future = PostJson(endpoint_url, TopSection().dump(), true); + Common::DetachedTasks::AddTask( + [host{this->host}, username{this->username}, token{this->token}, content]() { + Client{host, username, token}.PostJson("/telemetry", content, true); + }); } } // namespace WebService diff --git a/src/web_service/telemetry_json.h b/src/web_service/telemetry_json.h index 27f71b6ba..4335ade59 100644 --- a/src/web_service/telemetry_json.h +++ b/src/web_service/telemetry_json.h @@ -5,7 +5,6 @@ #pragma once #include -#include #include #include "common/announce_multiplayer_room.h" #include "common/telemetry.h" @@ -19,9 +18,8 @@ namespace WebService { */ class TelemetryJson : public Telemetry::VisitorInterface { public: - TelemetryJson(const std::string& endpoint_url, const std::string& username, - const std::string& token) - : endpoint_url(endpoint_url), username(username), token(token) {} + TelemetryJson(const std::string& host, const std::string& username, const std::string& token) + : host(host), username(username), token(token) {} ~TelemetryJson() = default; void Visit(const Telemetry::Field& field) override; @@ -53,10 +51,9 @@ private: nlohmann::json output; std::array sections; - std::string endpoint_url; + std::string host; std::string username; std::string token; - std::future future; }; } // namespace WebService diff --git a/src/web_service/verify_login.cpp b/src/web_service/verify_login.cpp index f2e9615e8..02e1b74f3 100644 --- a/src/web_service/verify_login.cpp +++ b/src/web_service/verify_login.cpp @@ -8,26 +8,20 @@ namespace WebService { -std::future VerifyLogin(std::string& username, std::string& token, - const std::string& endpoint_url, std::function func) { - auto get_func = [func, username](const std::string& reply) -> bool { - func(); +bool VerifyLogin(const std::string& host, const std::string& username, const std::string& token) { + Client client(host, username, token); + auto reply = client.GetJson("/profile", false).returned_data; + if (reply.empty()) { + return false; + } + nlohmann::json json = nlohmann::json::parse(reply); + const auto iter = json.find("username"); - if (reply.empty()) { - return false; - } + if (iter == json.end()) { + return username.empty(); + } - nlohmann::json json = nlohmann::json::parse(reply); - const auto iter = json.find("username"); - - if (iter == json.end()) { - return username.empty(); - } - - return username == *iter; - }; - UpdateCoreJWT(true, username, token); - return GetJson(get_func, endpoint_url, false); + return username == *iter; } } // namespace WebService diff --git a/src/web_service/verify_login.h b/src/web_service/verify_login.h index 303f5dbbc..93eb4036f 100644 --- a/src/web_service/verify_login.h +++ b/src/web_service/verify_login.h @@ -12,13 +12,11 @@ namespace WebService { /** * Checks if username and token is valid + * @param host the web API URL * @param username Citra username to use for authentication. * @param token Citra token to use for authentication. - * @param endpoint_url URL of the services.citra-emu.org endpoint. - * @param func A function that gets exectued when the verification is finished - * @returns Future with bool indicating whether the verification succeeded + * @returns a bool indicating whether the verification succeeded */ -std::future VerifyLogin(std::string& username, std::string& token, - const std::string& endpoint_url, std::function func); +bool VerifyLogin(const std::string& host, const std::string& username, const std::string& token); } // namespace WebService diff --git a/src/web_service/web_backend.cpp b/src/web_service/web_backend.cpp index 3659fe4ff..e935d8b15 100644 --- a/src/web_service/web_backend.cpp +++ b/src/web_service/web_backend.cpp @@ -20,334 +20,128 @@ constexpr int HTTPS_PORT = 443; constexpr int TIMEOUT_SECONDS = 30; -std::string UpdateCoreJWT(bool force_new_token, const std::string& username, - const std::string& token) { - static std::string jwt; - if (jwt.empty() || force_new_token) { - if (!username.empty() && !token.empty()) { - std::future future = - PostJson(Settings::values.web_api_url + "/jwt/internal", username, token); - jwt = future.get().returned_data; - } - } - return jwt; -} +Client::JWTCache Client::jwt_cache{}; -std::unique_ptr GetClientFor(const LUrlParser::clParseURL& parsedUrl) { - namespace hl = httplib; - - int port; - - std::unique_ptr cli; - - if (parsedUrl.m_Scheme == "http") { - if (!parsedUrl.GetPort(&port)) { - port = HTTP_PORT; - } - return std::make_unique(parsedUrl.m_Host.c_str(), port, TIMEOUT_SECONDS); - } else if (parsedUrl.m_Scheme == "https") { - if (!parsedUrl.GetPort(&port)) { - port = HTTPS_PORT; - } - return std::make_unique(parsedUrl.m_Host.c_str(), port, TIMEOUT_SECONDS); - } else { - LOG_ERROR(WebService, "Bad URL scheme {}", parsedUrl.m_Scheme); - return nullptr; +Client::Client(const std::string& host, const std::string& username, const std::string& token) + : host(host), username(username), token(token) { + if (username == jwt_cache.username && token == jwt_cache.token) { + jwt = jwt_cache.jwt; } } -static Common::WebResult PostJsonAsyncFn(const std::string& url, - const LUrlParser::clParseURL& parsed_url, - const httplib::Headers& params, const std::string& data, - bool is_jwt_requested) { - static bool is_first_attempt = true; - - namespace hl = httplib; - std::unique_ptr cli = GetClientFor(parsed_url); - +Common::WebResult Client::GenericJson(const std::string& method, const std::string& path, + const std::string& data, const std::string& jwt, + const std::string& username, const std::string& token) { if (cli == nullptr) { - return Common::WebResult{Common::WebResult::Code::InvalidURL, "URL is invalid"}; + auto parsedUrl = LUrlParser::clParseURL::ParseURL(host); + int port; + if (parsedUrl.m_Scheme == "http") { + if (!parsedUrl.GetPort(&port)) { + port = HTTP_PORT; + } + cli = + std::make_unique(parsedUrl.m_Host.c_str(), port, TIMEOUT_SECONDS); + } else if (parsedUrl.m_Scheme == "https") { + if (!parsedUrl.GetPort(&port)) { + port = HTTPS_PORT; + } + cli = std::make_unique(parsedUrl.m_Host.c_str(), port, + TIMEOUT_SECONDS); + } else { + LOG_ERROR(WebService, "Bad URL scheme {}", parsedUrl.m_Scheme); + return Common::WebResult{Common::WebResult::Code::InvalidURL, "Bad URL scheme"}; + } + } + if (cli == nullptr) { + LOG_ERROR(WebService, "Invalid URL {}", host + path); + return Common::WebResult{Common::WebResult::Code::InvalidURL, "Invalid URL"}; } - hl::Request request; - request.method = "POST"; - request.path = "/" + parsed_url.m_Path; + httplib::Headers params; + if (!jwt.empty()) { + params = { + {std::string("Authorization"), fmt::format("Bearer {}", jwt)}, + }; + } else if (!username.empty()) { + params = { + {std::string("x-username"), username}, + {std::string("x-token"), token}, + }; + } + + params.emplace(std::string("api-version"), std::string(API_VERSION)); + if (method != "GET") { + params.emplace(std::string("Content-Type"), std::string("application/json")); + }; + + httplib::Request request; + request.method = method; + request.path = path; request.headers = params; request.body = data; - hl::Response response; + httplib::Response response; if (!cli->send(request, response)) { - LOG_ERROR(WebService, "POST to {} returned null", url); + LOG_ERROR(WebService, "{} to {} returned null", method, host + path); return Common::WebResult{Common::WebResult::Code::LibError, "Null response"}; } if (response.status >= 400) { - LOG_ERROR(WebService, "POST to {} returned error status code: {}", url, response.status); - if (response.status == 401 && !is_jwt_requested && is_first_attempt) { - LOG_WARNING(WebService, "Requesting new JWT"); - UpdateCoreJWT(true, Settings::values.citra_username, Settings::values.citra_token); - is_first_attempt = false; - PostJsonAsyncFn(url, parsed_url, params, data, is_jwt_requested); - is_first_attempt = true; - } + LOG_ERROR(WebService, "{} to {} returned error status code: {}", method, host + path, + response.status); return Common::WebResult{Common::WebResult::Code::HttpError, std::to_string(response.status)}; } auto content_type = response.headers.find("content-type"); - if (content_type == response.headers.end() || - (content_type->second.find("application/json") == std::string::npos && - content_type->second.find("text/html; charset=utf-8") == std::string::npos)) { - LOG_ERROR(WebService, "POST to {} returned wrong content: {}", url, content_type->second); + if (content_type == response.headers.end()) { + LOG_ERROR(WebService, "{} to {} returned no content", method, host + path); return Common::WebResult{Common::WebResult::Code::WrongContent, ""}; } + if (content_type->second.find("application/json") == std::string::npos && + content_type->second.find("text/html; charset=utf-8") == std::string::npos) { + LOG_ERROR(WebService, "{} to {} returned wrong content: {}", method, host + path, + content_type->second); + return Common::WebResult{Common::WebResult::Code::WrongContent, "Wrong content"}; + } return Common::WebResult{Common::WebResult::Code::Success, "", response.body}; } -std::future PostJson(const std::string& url, const std::string& data, - bool allow_anonymous) { - - using lup = LUrlParser::clParseURL; - namespace hl = httplib; - - lup parsedUrl = lup::ParseURL(url); - - if (url.empty() || !parsedUrl.IsValid()) { - LOG_ERROR(WebService, "URL is invalid"); - return std::async(std::launch::deferred, [] { - return Common::WebResult{Common::WebResult::Code::InvalidURL, "URL is invalid"}; - }); +void Client::UpdateJWT() { + if (!username.empty() && !token.empty()) { + auto result = GenericJson("POST", "/jwt/internal", "", "", username, token); + if (result.result_code != Common::WebResult::Code::Success) { + LOG_ERROR(WebService, "UpdateJWT failed"); + } else { + jwt_cache.username = username; + jwt_cache.token = token; + jwt_cache.jwt = jwt = result.returned_data; + } } - - const std::string jwt = - UpdateCoreJWT(false, Settings::values.citra_username, Settings::values.citra_token); - - const bool are_credentials_provided{!jwt.empty()}; - if (!allow_anonymous && !are_credentials_provided) { - LOG_ERROR(WebService, "Credentials must be provided for authenticated requests"); - return std::async(std::launch::deferred, [] { - return Common::WebResult{Common::WebResult::Code::CredentialsMissing, - "Credentials needed"}; - }); - } - - // Built request header - hl::Headers params; - if (are_credentials_provided) { - // Authenticated request if credentials are provided - params = {{std::string("Authorization"), fmt::format("Bearer {}", jwt)}, - {std::string("api-version"), std::string(API_VERSION)}, - {std::string("Content-Type"), std::string("application/json")}}; - } else { - // Otherwise, anonymous request - params = {{std::string("api-version"), std::string(API_VERSION)}, - {std::string("Content-Type"), std::string("application/json")}}; - } - - // Post JSON asynchronously - return std::async(std::launch::async, PostJsonAsyncFn, url, parsedUrl, params, data, false); } -std::future PostJson(const std::string& url, const std::string& username, - const std::string& token) { - using lup = LUrlParser::clParseURL; - namespace hl = httplib; - - lup parsedUrl = lup::ParseURL(url); - - if (url.empty() || !parsedUrl.IsValid()) { - LOG_ERROR(WebService, "URL is invalid"); - return std::async(std::launch::deferred, [] { - return Common::WebResult{Common::WebResult::Code::InvalidURL, ""}; - }); +Common::WebResult Client::GenericJson(const std::string& method, const std::string& path, + const std::string& data, bool allow_anonymous) { + if (jwt.empty()) { + UpdateJWT(); } - const bool are_credentials_provided{!token.empty() && !username.empty()}; - if (!are_credentials_provided) { + if (jwt.empty() && !allow_anonymous) { LOG_ERROR(WebService, "Credentials must be provided for authenticated requests"); - return std::async(std::launch::deferred, [] { - return Common::WebResult{Common::WebResult::Code::CredentialsMissing, ""}; - }); + return Common::WebResult{Common::WebResult::Code::CredentialsMissing, "Credentials needed"}; } - // Built request header - hl::Headers params; - if (are_credentials_provided) { - // Authenticated request if credentials are provided - params = {{std::string("x-username"), username}, - {std::string("x-token"), token}, - {std::string("api-version"), std::string(API_VERSION)}, - {std::string("Content-Type"), std::string("application/json")}}; - } else { - // Otherwise, anonymous request - params = {{std::string("api-version"), std::string(API_VERSION)}, - {std::string("Content-Type"), std::string("application/json")}}; + auto result = GenericJson(method, path, data, jwt); + if (result.result_string == "401") { + // Try again with new JWT + UpdateJWT(); + result = GenericJson(method, path, data, jwt); } - // Post JSON asynchronously - return std::async(std::launch::async, PostJsonAsyncFn, url, parsedUrl, params, "", true); -} - -template -std::future GetJson(std::function func, const std::string& url, - bool allow_anonymous) { - static bool is_first_attempt = true; - - using lup = LUrlParser::clParseURL; - namespace hl = httplib; - - lup parsedUrl = lup::ParseURL(url); - - if (url.empty() || !parsedUrl.IsValid()) { - LOG_ERROR(WebService, "URL is invalid"); - return std::async(std::launch::deferred, [func{std::move(func)}]() { return func(""); }); - } - - const std::string jwt = - UpdateCoreJWT(false, Settings::values.citra_username, Settings::values.citra_token); - - const bool are_credentials_provided{!jwt.empty()}; - if (!allow_anonymous && !are_credentials_provided) { - LOG_ERROR(WebService, "Credentials must be provided for authenticated requests"); - return std::async(std::launch::deferred, [func{std::move(func)}]() { return func(""); }); - } - - // Built request header - hl::Headers params; - if (are_credentials_provided) { - params = {{std::string("Authorization"), fmt::format("Bearer {}", jwt)}, - {std::string("api-version"), std::string(API_VERSION)}}; - } else { - // Otherwise, anonymous request - params = {{std::string("api-version"), std::string(API_VERSION)}}; - } - - // Get JSON asynchronously - return std::async(std::launch::async, [func, url, parsedUrl, params, allow_anonymous] { - std::unique_ptr cli = GetClientFor(parsedUrl); - - if (cli == nullptr) { - return func(""); - } - - hl::Request request; - request.method = "GET"; - request.path = "/" + parsedUrl.m_Path; - request.headers = params; - - hl::Response response; - - if (!cli->send(request, response)) { - LOG_ERROR(WebService, "GET to {} returned null", url); - return func(""); - } - - if (response.status >= 400) { - LOG_ERROR(WebService, "GET to {} returned error status code: {}", url, response.status); - if (response.status == 401 && is_first_attempt) { - LOG_WARNING(WebService, "Requesting new JWT"); - UpdateCoreJWT(true, Settings::values.citra_username, Settings::values.citra_token); - is_first_attempt = false; - GetJson(func, url, allow_anonymous); - is_first_attempt = true; - } - return func(""); - } - - auto content_type = response.headers.find("content-type"); - - if (content_type == response.headers.end() || - content_type->second.find("application/json") == std::string::npos) { - LOG_ERROR(WebService, "GET to {} returned wrong content: {}", url, - content_type->second); - return func(""); - } - - return func(response.body); - }); -} - -template std::future GetJson(std::function func, - const std::string& url, bool allow_anonymous); -template std::future GetJson( - std::function func, - const std::string& url, bool allow_anonymous); - -void DeleteJson(const std::string& url, const std::string& data) { - static bool is_first_attempt = true; - - using lup = LUrlParser::clParseURL; - namespace hl = httplib; - - lup parsedUrl = lup::ParseURL(url); - - if (url.empty() || !parsedUrl.IsValid()) { - LOG_ERROR(WebService, "URL is invalid"); - return; - } - - const std::string jwt = - UpdateCoreJWT(false, Settings::values.citra_username, Settings::values.citra_token); - - const bool are_credentials_provided{!jwt.empty()}; - if (!are_credentials_provided) { - LOG_ERROR(WebService, "Credentials must be provided for authenticated requests"); - return; - } - - // Built request header - hl::Headers params = {{std::string("Authorization"), fmt::format("Bearer {}", jwt)}, - {std::string("api-version"), std::string(API_VERSION)}, - {std::string("Content-Type"), std::string("application/json")}}; - - // Delete JSON asynchronously - std::async(std::launch::async, [url, parsedUrl, params, data] { - std::unique_ptr cli = GetClientFor(parsedUrl); - - if (cli == nullptr) { - return; - } - - hl::Request request; - request.method = "DELETE"; - request.path = "/" + parsedUrl.m_Path; - request.headers = params; - request.body = data; - - hl::Response response; - - if (!cli->send(request, response)) { - LOG_ERROR(WebService, "DELETE to {} returned null", url); - return; - } - - if (response.status >= 400) { - LOG_ERROR(WebService, "DELETE to {} returned error status code: {}", url, - response.status); - if (response.status == 401 && is_first_attempt) { - LOG_WARNING(WebService, "Requesting new JWT"); - UpdateCoreJWT(true, Settings::values.citra_username, Settings::values.citra_token); - is_first_attempt = false; - DeleteJson(url, data); - is_first_attempt = true; - } - return; - } - - auto content_type = response.headers.find("content-type"); - - if (content_type == response.headers.end() || - content_type->second.find("application/json") == std::string::npos) { - LOG_ERROR(WebService, "DELETE to {} returned wrong content: {}", url, - content_type->second); - return; - } - - return; - }); + return result; } } // namespace WebService diff --git a/src/web_service/web_backend.h b/src/web_service/web_backend.h index d00095c91..be115f96e 100644 --- a/src/web_service/web_backend.h +++ b/src/web_service/web_backend.h @@ -12,72 +12,80 @@ #include "common/announce_multiplayer_room.h" #include "common/common_types.h" -namespace LUrlParser { -class clParseURL; +namespace httplib { +class Client; } namespace WebService { -/** - * Requests a new JWT if necessary - * @param force_new_token If true, force to request a new token from the server. - * @param username Citra username to use for authentication. - * @param token Citra token to use for authentication. - * @return string with the current JWT toke - */ -std::string UpdateCoreJWT(bool force_new_token, const std::string& username, - const std::string& token); +class Client { +public: + Client(const std::string& host, const std::string& username, const std::string& token); -/** - * Posts JSON to a api.citra-emu.org. - * @param url URL of the api.citra-emu.org endpoint to post data to. - * @param parsed_url Parsed URL used for the POST request. - * @param params Headers sent for the POST request. - * @param data String of JSON data to use for the body of the POST request. - * @param data If true, a JWT is requested in the function - * @return future with the returned value of the POST - */ -static Common::WebResult PostJsonAsyncFn(const std::string& url, - const LUrlParser::clParseURL& parsed_url, - const httplib::Headers& params, const std::string& data, - bool is_jwt_requested); + /** + * Posts JSON to the specified path. + * @param path the URL segment after the host address. + * @param data String of JSON data to use for the body of the POST request. + * @param allow_anonymous If true, allow anonymous unauthenticated requests. + * @return the result of the request. + */ + Common::WebResult PostJson(const std::string& path, const std::string& data, + bool allow_anonymous) { + return GenericJson("POST", path, data, allow_anonymous); + } -/** - * Posts JSON to api.citra-emu.org. - * @param url URL of the api.citra-emu.org endpoint to post data to. - * @param data String of JSON data to use for the body of the POST request. - * @param allow_anonymous If true, allow anonymous unauthenticated requests. - * @return future with the returned value of the POST - */ -std::future PostJson(const std::string& url, const std::string& data, - bool allow_anonymous); + /** + * Gets JSON from the specified path. + * @param path the URL segment after the host address. + * @param allow_anonymous If true, allow anonymous unauthenticated requests. + * @return the result of the request. + */ + Common::WebResult GetJson(const std::string& path, bool allow_anonymous) { + return GenericJson("GET", path, "", allow_anonymous); + } -/** - * Posts JSON to api.citra-emu.org. - * @param url URL of the api.citra-emu.org endpoint to post data to. - * @param username Citra username to use for authentication. - * @param token Citra token to use for authentication. - * @return future with the error or result of the POST - */ -std::future PostJson(const std::string& url, const std::string& username, - const std::string& token); + /** + * Deletes JSON to the specified path. + * @param path the URL segment after the host address. + * @param data String of JSON data to use for the body of the DELETE request. + * @param allow_anonymous If true, allow anonymous unauthenticated requests. + * @return the result of the request. + */ + Common::WebResult DeleteJson(const std::string& path, const std::string& data, + bool allow_anonymous) { + return GenericJson("DELETE", path, data, allow_anonymous); + } -/** - * Gets JSON from api.citra-emu.org. - * @param func A function that gets exectued when the json as a string is received - * @param url URL of the api.citra-emu.org endpoint to post data to. - * @param allow_anonymous If true, allow anonymous unauthenticated requests. - * @return future that holds the return value T of the func - */ -template -std::future GetJson(std::function func, const std::string& url, - bool allow_anonymous); +private: + /// A generic function handles POST, GET and DELETE request together + Common::WebResult GenericJson(const std::string& method, const std::string& path, + const std::string& data, bool allow_anonymous); -/** - * Delete JSON to api.citra-emu.org. - * @param url URL of the api.citra-emu.org endpoint to post data to. - * @param data String of JSON data to use for the body of the DELETE request. - */ -void DeleteJson(const std::string& url, const std::string& data); + /** + * A generic function with explicit authentication method specified + * JWT is used if the jwt parameter is not empty + * username + token is used if jwt is empty but username and token are not empty + * anonymous if all of jwt, username and token are empty + */ + Common::WebResult GenericJson(const std::string& method, const std::string& path, + const std::string& data, const std::string& jwt = "", + const std::string& username = "", const std::string& token = ""); + + // Retrieve a new JWT from given username and token + void UpdateJWT(); + + std::string host; + std::string username; + std::string token; + std::string jwt; + std::unique_ptr cli; + + struct JWTCache { + std::string username; + std::string token; + std::string jwt; + }; + static JWTCache jwt_cache; +}; } // namespace WebService From f3d59556efd7e4930eb2633277ebe9f969d5cd0a Mon Sep 17 00:00:00 2001 From: Weiyi Wang Date: Mon, 17 Sep 2018 14:28:58 -0400 Subject: [PATCH 2/2] web_backend: protect jwt cache with a mutex --- src/web_service/web_backend.cpp | 2 ++ src/web_service/web_backend.h | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/web_service/web_backend.cpp b/src/web_service/web_backend.cpp index e935d8b15..3c5d278e8 100644 --- a/src/web_service/web_backend.cpp +++ b/src/web_service/web_backend.cpp @@ -24,6 +24,7 @@ Client::JWTCache Client::jwt_cache{}; Client::Client(const std::string& host, const std::string& username, const std::string& token) : host(host), username(username), token(token) { + std::lock_guard lock(jwt_cache.mutex); if (username == jwt_cache.username && token == jwt_cache.token) { jwt = jwt_cache.jwt; } @@ -116,6 +117,7 @@ void Client::UpdateJWT() { if (result.result_code != Common::WebResult::Code::Success) { LOG_ERROR(WebService, "UpdateJWT failed"); } else { + std::lock_guard lock(jwt_cache.mutex); jwt_cache.username = username; jwt_cache.token = token; jwt_cache.jwt = jwt = result.returned_data; diff --git a/src/web_service/web_backend.h b/src/web_service/web_backend.h index be115f96e..955b91c7a 100644 --- a/src/web_service/web_backend.h +++ b/src/web_service/web_backend.h @@ -5,7 +5,7 @@ #pragma once #include -#include +#include #include #include #include @@ -81,6 +81,7 @@ private: std::unique_ptr cli; struct JWTCache { + std::mutex mutex; std::string username; std::string token; std::string jwt;