From 1baec41542780cf4fc898df7d2fc9925d823fb44 Mon Sep 17 00:00:00 2001 From: Patric Stout Date: Sun, 11 Jul 2021 12:08:03 +0200 Subject: Change: groundwork to allow ServerAddress to use invite codes Normally TCPConnecter will do a DNS resolving of the connection_string and connect to it. But for SERVER_ADDRESS_INVITE_CODE this is different: the Game Coordinator does the "resolving". This means we need to allow TCPConnecter to not setup a connection and allow it to be told when a connection has been setup by an external (to TCPConnecter) part of the code. We do this by telling the (active) socket for the connection. This means the rest of the code doesn't need to know the TCPConnecter is not doing a simple resolve+connect. The rest of the code only cares the connection is established; not how it was established. --- src/network/core/address.cpp | 11 ++++-- src/network/core/address.h | 1 + src/network/core/tcp.h | 14 +++++++- src/network/core/tcp_connect.cpp | 75 +++++++++++++++++++++++++++++++++++++++- src/network/network.cpp | 59 +++++++++++++++++++------------ src/network/network_internal.h | 1 + 6 files changed, 133 insertions(+), 28 deletions(-) (limited to 'src') diff --git a/src/network/core/address.cpp b/src/network/core/address.cpp index 113dae686..4c090c14a 100644 --- a/src/network/core/address.cpp +++ b/src/network/core/address.cpp @@ -414,17 +414,22 @@ void NetworkAddress::Listen(int socktype, SocketList *sockets) } /** - * Convert a string containing either "hostname" or "hostname:ip" to a - * ServerAddress, where the string can be postfixed with "#company" to + * Convert a string containing either "hostname", "hostname:port" or invite code + * to a ServerAddress, where the string can be postfixed with "#company" to * indicate the requested company. * * @param connection_string The string to parse. * @param default_port The default port to set port to if not in connection_string. - * @param company Pointer to the company variable to set iff indicted. + * @param company Pointer to the company variable to set iff indicated. * @return A valid ServerAddress of the parsed information. */ /* static */ ServerAddress ServerAddress::Parse(const std::string &connection_string, uint16 default_port, CompanyID *company_id) { + if (StrStartsWith(connection_string, "+")) { + std::string_view invite_code = ParseCompanyFromConnectionString(connection_string, company_id); + return ServerAddress(SERVER_ADDRESS_INVITE_CODE, std::string(invite_code)); + } + uint16 port = default_port; std::string_view ip = ParseFullConnectionString(connection_string, port, company_id); return ServerAddress(SERVER_ADDRESS_DIRECT, std::string(ip) + ":" + std::to_string(port)); diff --git a/src/network/core/address.h b/src/network/core/address.h index b22bcac0e..9e09632d3 100644 --- a/src/network/core/address.h +++ b/src/network/core/address.h @@ -185,6 +185,7 @@ public: */ enum ServerAddressType { SERVER_ADDRESS_DIRECT, ///< Server-address is based on an hostname:port. + SERVER_ADDRESS_INVITE_CODE, ///< Server-address is based on an invite code. }; /** diff --git a/src/network/core/tcp.h b/src/network/core/tcp.h index 379ef8b92..bbd0bc2a9 100644 --- a/src/network/core/tcp.h +++ b/src/network/core/tcp.h @@ -82,10 +82,12 @@ private: RESOLVING, ///< The hostname is being resolved (threaded). FAILURE, ///< Resolving failed. CONNECTING, ///< We are currently connecting. + CONNECTED, ///< The connection is established. }; std::thread resolve_thread; ///< Thread used during resolving. std::atomic status = Status::INIT; ///< The current status of the connecter. + std::atomic killed = false; ///< Whether this connecter is marked as killed. addrinfo *ai = nullptr; ///< getaddrinfo() allocated linked-list of resolved addresses. std::vector addresses; ///< Addresses we can connect to. @@ -101,7 +103,7 @@ private: void OnResolved(addrinfo *ai); bool TryNextAddress(); void Connect(addrinfo *address); - bool CheckActivity(); + virtual bool CheckActivity(); /* We do not want any other derived classes from this class being able to * access these private members, but it is okay for TCPServerConnecter. */ @@ -125,15 +127,25 @@ public: */ virtual void OnFailure() {} + void Kill(); + static void CheckCallbacks(); static void KillAll(); }; class TCPServerConnecter : public TCPConnecter { +private: + SOCKET socket = INVALID_SOCKET; ///< The socket when a connection is established. + + bool CheckActivity() override; + public: ServerAddress server_address; ///< Address we are connecting to. TCPServerConnecter(const std::string &connection_string, uint16 default_port); + + void SetConnected(SOCKET sock); + void SetFailure(); }; #endif /* NETWORK_CORE_TCP_H */ diff --git a/src/network/core/tcp_connect.cpp b/src/network/core/tcp_connect.cpp index bb96e33f9..d9b6bb781 100644 --- a/src/network/core/tcp_connect.cpp +++ b/src/network/core/tcp_connect.cpp @@ -46,6 +46,12 @@ TCPServerConnecter::TCPServerConnecter(const std::string &connection_string, uin this->connection_string = this->server_address.connection_string; break; + case SERVER_ADDRESS_INVITE_CODE: + this->status = Status::CONNECTING; + + // TODO -- The next commit will add this functionality. + break; + default: NOT_REACHED(); } @@ -68,6 +74,16 @@ TCPConnecter::~TCPConnecter() if (this->ai != nullptr) freeaddrinfo(this->ai); } +/** + * Kill this connecter. + * It will abort as soon as it can and not call any of the callbacks. + */ +void TCPConnecter::Kill() +{ + /* Delay the removing of the socket till the next CheckActivity(). */ + this->killed = true; +} + /** * Start a connection to the indicated address. * @param address The address to connection to. @@ -239,7 +255,9 @@ void TCPConnecter::Resolve() */ bool TCPConnecter::CheckActivity() { - switch (this->status.load()) { + if (this->killed) return true; + + switch (this->status) { case Status::INIT: /* Start the thread delayed, so the vtable is loaded. This allows classes * to overload functions used by Resolve() (in case threading is disabled). */ @@ -266,6 +284,7 @@ bool TCPConnecter::CheckActivity() return true; case Status::CONNECTING: + case Status::CONNECTED: break; } @@ -364,9 +383,63 @@ bool TCPConnecter::CheckActivity() } this->OnConnect(connected_socket); + this->status = Status::CONNECTED; return true; } +/** + * Check if there was activity for this connecter. + * @return True iff the TCPConnecter is done and can be cleaned up. + */ +bool TCPServerConnecter::CheckActivity() +{ + if (this->killed) return true; + + switch (this->server_address.type) { + case SERVER_ADDRESS_DIRECT: + return TCPConnecter::CheckActivity(); + + case SERVER_ADDRESS_INVITE_CODE: + /* Check if a result has come in. */ + switch (this->status) { + case Status::FAILURE: + this->OnFailure(); + return true; + + case Status::CONNECTED: + this->OnConnect(this->socket); + return true; + + default: + break; + } + + return false; + + default: + NOT_REACHED(); + } +} + +/** + * The connection was successfully established. + * This socket is fully setup and ready to send/recv game protocol packets. + * @param sock The socket of the established connection. + */ +void TCPServerConnecter::SetConnected(SOCKET sock) +{ + this->socket = sock; + this->status = Status::CONNECTED; +} + +/** + * The connection couldn't be established. + */ +void TCPServerConnecter::SetFailure() +{ + this->status = Status::FAILURE; +} + /** * Check whether we need to call the callback, i.e. whether we * have connected or aborted and call the appropriate callback diff --git a/src/network/network.cpp b/src/network/network.cpp index 4e48ce351..3a33e5096 100644 --- a/src/network/network.cpp +++ b/src/network/network.cpp @@ -452,6 +452,41 @@ static void CheckPauseOnJoin() CheckPauseHelper(NetworkHasJoiningClient(), PM_PAUSED_JOIN); } +/** + * Parse the company part ("#company" postfix) of a connecting string. + * @param connection_string The string with the connection data. + * @param company_id The company ID to set, if available. + * @return A std::string_view into the connection string without the company part. + */ +std::string_view ParseCompanyFromConnectionString(const std::string &connection_string, CompanyID *company_id) +{ + std::string_view ip = connection_string; + if (company_id == nullptr) return ip; + + size_t offset = ip.find_last_of('#'); + if (offset != std::string::npos) { + std::string_view company_string = ip.substr(offset + 1); + ip = ip.substr(0, offset); + + uint8 company_value; + auto [_, err] = std::from_chars(company_string.data(), company_string.data() + company_string.size(), company_value); + if (err == std::errc()) { + if (company_value != COMPANY_NEW_COMPANY && company_value != COMPANY_SPECTATOR) { + if (company_value > MAX_COMPANIES || company_value == 0) { + *company_id = COMPANY_SPECTATOR; + } else { + /* "#1" means the first company, which has index 0. */ + *company_id = (CompanyID)(company_value - 1); + } + } else { + *company_id = (CompanyID)company_value; + } + } + } + + return ip; +} + /** * Converts a string to ip/port/company * Format: IP:port#company @@ -469,29 +504,7 @@ static void CheckPauseOnJoin() */ std::string_view ParseFullConnectionString(const std::string &connection_string, uint16 &port, CompanyID *company_id) { - std::string_view ip = connection_string; - if (company_id != nullptr) { - size_t offset = ip.find_last_of('#'); - if (offset != std::string::npos) { - std::string_view company_string = ip.substr(offset + 1); - ip = ip.substr(0, offset); - - uint8 company_value; - auto [_, err] = std::from_chars(company_string.data(), company_string.data() + company_string.size(), company_value); - if (err == std::errc()) { - if (company_value != COMPANY_NEW_COMPANY && company_value != COMPANY_SPECTATOR) { - if (company_value > MAX_COMPANIES || company_value == 0) { - *company_id = COMPANY_SPECTATOR; - } else { - /* "#1" means the first company, which has index 0. */ - *company_id = (CompanyID)(company_value - 1); - } - } else { - *company_id = (CompanyID)company_value; - } - } - } - } + std::string_view ip = ParseCompanyFromConnectionString(connection_string, company_id); size_t port_offset = ip.find_last_of(':'); size_t ipv6_close = ip.find_last_of(']'); diff --git a/src/network/network_internal.h b/src/network/network_internal.h index c1e6aa7b9..95226286c 100644 --- a/src/network/network_internal.h +++ b/src/network/network_internal.h @@ -122,6 +122,7 @@ StringID GetNetworkErrorMsg(NetworkErrorCode err); bool NetworkMakeClientNameUnique(std::string &new_name); std::string GenerateCompanyPasswordHash(const std::string &password, const std::string &password_server_id, uint32 password_game_seed); +std::string_view ParseCompanyFromConnectionString(const std::string &connection_string, CompanyID *company_id); NetworkAddress ParseConnectionString(const std::string &connection_string, uint16 default_port); std::string NormalizeConnectionString(const std::string &connection_string, uint16 default_port); -- cgit v1.2.3-70-g09d2