diff options
-rw-r--r-- | src/network/core/address.cpp | 11 | ||||
-rw-r--r-- | src/network/core/address.h | 1 | ||||
-rw-r--r-- | src/network/core/tcp.h | 14 | ||||
-rw-r--r-- | src/network/core/tcp_connect.cpp | 75 | ||||
-rw-r--r-- | src/network/network.cpp | 59 | ||||
-rw-r--r-- | src/network/network_internal.h | 1 |
6 files changed, 133 insertions, 28 deletions
diff --git a/src/network/core/address.cpp b/src/network/core/address.cpp index 113dae686..4c090c14a 100644 --- a/src/network/core/address.cpp +++ b/src/network/core/address.cpp @@ -414,17 +414,22 @@ void NetworkAddress::Listen(int socktype, SocketList *sockets) } /** - * Convert a string containing either "hostname" or "hostname:ip" to a - * ServerAddress, where the string can be postfixed with "#company" to + * Convert a string containing either "hostname", "hostname:port" or invite code + * to a ServerAddress, where the string can be postfixed with "#company" to * indicate the requested company. * * @param connection_string The string to parse. * @param default_port The default port to set port to if not in connection_string. - * @param company Pointer to the company variable to set iff indicted. + * @param company Pointer to the company variable to set iff indicated. * @return A valid ServerAddress of the parsed information. */ /* static */ ServerAddress ServerAddress::Parse(const std::string &connection_string, uint16 default_port, CompanyID *company_id) { + if (StrStartsWith(connection_string, "+")) { + std::string_view invite_code = ParseCompanyFromConnectionString(connection_string, company_id); + return ServerAddress(SERVER_ADDRESS_INVITE_CODE, std::string(invite_code)); + } + uint16 port = default_port; std::string_view ip = ParseFullConnectionString(connection_string, port, company_id); return ServerAddress(SERVER_ADDRESS_DIRECT, std::string(ip) + ":" + std::to_string(port)); diff --git a/src/network/core/address.h b/src/network/core/address.h index b22bcac0e..9e09632d3 100644 --- a/src/network/core/address.h +++ b/src/network/core/address.h @@ -185,6 +185,7 @@ public: */ enum ServerAddressType { SERVER_ADDRESS_DIRECT, ///< Server-address is based on an hostname:port. + SERVER_ADDRESS_INVITE_CODE, ///< Server-address is based on an invite code. }; /** diff --git a/src/network/core/tcp.h b/src/network/core/tcp.h index 379ef8b92..bbd0bc2a9 100644 --- a/src/network/core/tcp.h +++ b/src/network/core/tcp.h @@ -82,10 +82,12 @@ private: RESOLVING, ///< The hostname is being resolved (threaded). FAILURE, ///< Resolving failed. CONNECTING, ///< We are currently connecting. + CONNECTED, ///< The connection is established. }; std::thread resolve_thread; ///< Thread used during resolving. std::atomic<Status> status = Status::INIT; ///< The current status of the connecter. + std::atomic<bool> killed = false; ///< Whether this connecter is marked as killed. addrinfo *ai = nullptr; ///< getaddrinfo() allocated linked-list of resolved addresses. std::vector<addrinfo *> addresses; ///< Addresses we can connect to. @@ -101,7 +103,7 @@ private: void OnResolved(addrinfo *ai); bool TryNextAddress(); void Connect(addrinfo *address); - bool CheckActivity(); + virtual bool CheckActivity(); /* We do not want any other derived classes from this class being able to * access these private members, but it is okay for TCPServerConnecter. */ @@ -125,15 +127,25 @@ public: */ virtual void OnFailure() {} + void Kill(); + static void CheckCallbacks(); static void KillAll(); }; class TCPServerConnecter : public TCPConnecter { +private: + SOCKET socket = INVALID_SOCKET; ///< The socket when a connection is established. + + bool CheckActivity() override; + public: ServerAddress server_address; ///< Address we are connecting to. TCPServerConnecter(const std::string &connection_string, uint16 default_port); + + void SetConnected(SOCKET sock); + void SetFailure(); }; #endif /* NETWORK_CORE_TCP_H */ diff --git a/src/network/core/tcp_connect.cpp b/src/network/core/tcp_connect.cpp index bb96e33f9..d9b6bb781 100644 --- a/src/network/core/tcp_connect.cpp +++ b/src/network/core/tcp_connect.cpp @@ -46,6 +46,12 @@ TCPServerConnecter::TCPServerConnecter(const std::string &connection_string, uin this->connection_string = this->server_address.connection_string; break; + case SERVER_ADDRESS_INVITE_CODE: + this->status = Status::CONNECTING; + + // TODO -- The next commit will add this functionality. + break; + default: NOT_REACHED(); } @@ -69,6 +75,16 @@ TCPConnecter::~TCPConnecter() } /** + * Kill this connecter. + * It will abort as soon as it can and not call any of the callbacks. + */ +void TCPConnecter::Kill() +{ + /* Delay the removing of the socket till the next CheckActivity(). */ + this->killed = true; +} + +/** * Start a connection to the indicated address. * @param address The address to connection to. */ @@ -239,7 +255,9 @@ void TCPConnecter::Resolve() */ bool TCPConnecter::CheckActivity() { - switch (this->status.load()) { + if (this->killed) return true; + + switch (this->status) { case Status::INIT: /* Start the thread delayed, so the vtable is loaded. This allows classes * to overload functions used by Resolve() (in case threading is disabled). */ @@ -266,6 +284,7 @@ bool TCPConnecter::CheckActivity() return true; case Status::CONNECTING: + case Status::CONNECTED: break; } @@ -364,10 +383,64 @@ bool TCPConnecter::CheckActivity() } this->OnConnect(connected_socket); + this->status = Status::CONNECTED; return true; } /** + * Check if there was activity for this connecter. + * @return True iff the TCPConnecter is done and can be cleaned up. + */ +bool TCPServerConnecter::CheckActivity() +{ + if (this->killed) return true; + + switch (this->server_address.type) { + case SERVER_ADDRESS_DIRECT: + return TCPConnecter::CheckActivity(); + + case SERVER_ADDRESS_INVITE_CODE: + /* Check if a result has come in. */ + switch (this->status) { + case Status::FAILURE: + this->OnFailure(); + return true; + + case Status::CONNECTED: + this->OnConnect(this->socket); + return true; + + default: + break; + } + + return false; + + default: + NOT_REACHED(); + } +} + +/** + * The connection was successfully established. + * This socket is fully setup and ready to send/recv game protocol packets. + * @param sock The socket of the established connection. + */ +void TCPServerConnecter::SetConnected(SOCKET sock) +{ + this->socket = sock; + this->status = Status::CONNECTED; +} + +/** + * The connection couldn't be established. + */ +void TCPServerConnecter::SetFailure() +{ + this->status = Status::FAILURE; +} + +/** * Check whether we need to call the callback, i.e. whether we * have connected or aborted and call the appropriate callback * for that. It's done this way to ease on the locking that diff --git a/src/network/network.cpp b/src/network/network.cpp index 4e48ce351..3a33e5096 100644 --- a/src/network/network.cpp +++ b/src/network/network.cpp @@ -453,6 +453,41 @@ static void CheckPauseOnJoin() } /** + * Parse the company part ("#company" postfix) of a connecting string. + * @param connection_string The string with the connection data. + * @param company_id The company ID to set, if available. + * @return A std::string_view into the connection string without the company part. + */ +std::string_view ParseCompanyFromConnectionString(const std::string &connection_string, CompanyID *company_id) +{ + std::string_view ip = connection_string; + if (company_id == nullptr) return ip; + + size_t offset = ip.find_last_of('#'); + if (offset != std::string::npos) { + std::string_view company_string = ip.substr(offset + 1); + ip = ip.substr(0, offset); + + uint8 company_value; + auto [_, err] = std::from_chars(company_string.data(), company_string.data() + company_string.size(), company_value); + if (err == std::errc()) { + if (company_value != COMPANY_NEW_COMPANY && company_value != COMPANY_SPECTATOR) { + if (company_value > MAX_COMPANIES || company_value == 0) { + *company_id = COMPANY_SPECTATOR; + } else { + /* "#1" means the first company, which has index 0. */ + *company_id = (CompanyID)(company_value - 1); + } + } else { + *company_id = (CompanyID)company_value; + } + } + } + + return ip; +} + +/** * Converts a string to ip/port/company * Format: IP:port#company * @@ -469,29 +504,7 @@ static void CheckPauseOnJoin() */ std::string_view ParseFullConnectionString(const std::string &connection_string, uint16 &port, CompanyID *company_id) { - std::string_view ip = connection_string; - if (company_id != nullptr) { - size_t offset = ip.find_last_of('#'); - if (offset != std::string::npos) { - std::string_view company_string = ip.substr(offset + 1); - ip = ip.substr(0, offset); - - uint8 company_value; - auto [_, err] = std::from_chars(company_string.data(), company_string.data() + company_string.size(), company_value); - if (err == std::errc()) { - if (company_value != COMPANY_NEW_COMPANY && company_value != COMPANY_SPECTATOR) { - if (company_value > MAX_COMPANIES || company_value == 0) { - *company_id = COMPANY_SPECTATOR; - } else { - /* "#1" means the first company, which has index 0. */ - *company_id = (CompanyID)(company_value - 1); - } - } else { - *company_id = (CompanyID)company_value; - } - } - } - } + std::string_view ip = ParseCompanyFromConnectionString(connection_string, company_id); size_t port_offset = ip.find_last_of(':'); size_t ipv6_close = ip.find_last_of(']'); diff --git a/src/network/network_internal.h b/src/network/network_internal.h index c1e6aa7b9..95226286c 100644 --- a/src/network/network_internal.h +++ b/src/network/network_internal.h @@ -122,6 +122,7 @@ StringID GetNetworkErrorMsg(NetworkErrorCode err); bool NetworkMakeClientNameUnique(std::string &new_name); std::string GenerateCompanyPasswordHash(const std::string &password, const std::string &password_server_id, uint32 password_game_seed); +std::string_view ParseCompanyFromConnectionString(const std::string &connection_string, CompanyID *company_id); NetworkAddress ParseConnectionString(const std::string &connection_string, uint16 default_port); std::string NormalizeConnectionString(const std::string &connection_string, uint16 default_port); |