summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/network/core/address.cpp11
-rw-r--r--src/network/core/address.h1
-rw-r--r--src/network/core/tcp.h14
-rw-r--r--src/network/core/tcp_connect.cpp75
-rw-r--r--src/network/network.cpp59
-rw-r--r--src/network/network_internal.h1
6 files changed, 133 insertions, 28 deletions
diff --git a/src/network/core/address.cpp b/src/network/core/address.cpp
index 113dae686..4c090c14a 100644
--- a/src/network/core/address.cpp
+++ b/src/network/core/address.cpp
@@ -414,17 +414,22 @@ void NetworkAddress::Listen(int socktype, SocketList *sockets)
}
/**
- * Convert a string containing either "hostname" or "hostname:ip" to a
- * ServerAddress, where the string can be postfixed with "#company" to
+ * Convert a string containing either "hostname", "hostname:port" or invite code
+ * to a ServerAddress, where the string can be postfixed with "#company" to
* indicate the requested company.
*
* @param connection_string The string to parse.
* @param default_port The default port to set port to if not in connection_string.
- * @param company Pointer to the company variable to set iff indicted.
+ * @param company Pointer to the company variable to set iff indicated.
* @return A valid ServerAddress of the parsed information.
*/
/* static */ ServerAddress ServerAddress::Parse(const std::string &connection_string, uint16 default_port, CompanyID *company_id)
{
+ if (StrStartsWith(connection_string, "+")) {
+ std::string_view invite_code = ParseCompanyFromConnectionString(connection_string, company_id);
+ return ServerAddress(SERVER_ADDRESS_INVITE_CODE, std::string(invite_code));
+ }
+
uint16 port = default_port;
std::string_view ip = ParseFullConnectionString(connection_string, port, company_id);
return ServerAddress(SERVER_ADDRESS_DIRECT, std::string(ip) + ":" + std::to_string(port));
diff --git a/src/network/core/address.h b/src/network/core/address.h
index b22bcac0e..9e09632d3 100644
--- a/src/network/core/address.h
+++ b/src/network/core/address.h
@@ -185,6 +185,7 @@ public:
*/
enum ServerAddressType {
SERVER_ADDRESS_DIRECT, ///< Server-address is based on an hostname:port.
+ SERVER_ADDRESS_INVITE_CODE, ///< Server-address is based on an invite code.
};
/**
diff --git a/src/network/core/tcp.h b/src/network/core/tcp.h
index 379ef8b92..bbd0bc2a9 100644
--- a/src/network/core/tcp.h
+++ b/src/network/core/tcp.h
@@ -82,10 +82,12 @@ private:
RESOLVING, ///< The hostname is being resolved (threaded).
FAILURE, ///< Resolving failed.
CONNECTING, ///< We are currently connecting.
+ CONNECTED, ///< The connection is established.
};
std::thread resolve_thread; ///< Thread used during resolving.
std::atomic<Status> status = Status::INIT; ///< The current status of the connecter.
+ std::atomic<bool> killed = false; ///< Whether this connecter is marked as killed.
addrinfo *ai = nullptr; ///< getaddrinfo() allocated linked-list of resolved addresses.
std::vector<addrinfo *> addresses; ///< Addresses we can connect to.
@@ -101,7 +103,7 @@ private:
void OnResolved(addrinfo *ai);
bool TryNextAddress();
void Connect(addrinfo *address);
- bool CheckActivity();
+ virtual bool CheckActivity();
/* We do not want any other derived classes from this class being able to
* access these private members, but it is okay for TCPServerConnecter. */
@@ -125,15 +127,25 @@ public:
*/
virtual void OnFailure() {}
+ void Kill();
+
static void CheckCallbacks();
static void KillAll();
};
class TCPServerConnecter : public TCPConnecter {
+private:
+ SOCKET socket = INVALID_SOCKET; ///< The socket when a connection is established.
+
+ bool CheckActivity() override;
+
public:
ServerAddress server_address; ///< Address we are connecting to.
TCPServerConnecter(const std::string &connection_string, uint16 default_port);
+
+ void SetConnected(SOCKET sock);
+ void SetFailure();
};
#endif /* NETWORK_CORE_TCP_H */
diff --git a/src/network/core/tcp_connect.cpp b/src/network/core/tcp_connect.cpp
index bb96e33f9..d9b6bb781 100644
--- a/src/network/core/tcp_connect.cpp
+++ b/src/network/core/tcp_connect.cpp
@@ -46,6 +46,12 @@ TCPServerConnecter::TCPServerConnecter(const std::string &connection_string, uin
this->connection_string = this->server_address.connection_string;
break;
+ case SERVER_ADDRESS_INVITE_CODE:
+ this->status = Status::CONNECTING;
+
+ // TODO -- The next commit will add this functionality.
+ break;
+
default:
NOT_REACHED();
}
@@ -69,6 +75,16 @@ TCPConnecter::~TCPConnecter()
}
/**
+ * Kill this connecter.
+ * It will abort as soon as it can and not call any of the callbacks.
+ */
+void TCPConnecter::Kill()
+{
+ /* Delay the removing of the socket till the next CheckActivity(). */
+ this->killed = true;
+}
+
+/**
* Start a connection to the indicated address.
* @param address The address to connection to.
*/
@@ -239,7 +255,9 @@ void TCPConnecter::Resolve()
*/
bool TCPConnecter::CheckActivity()
{
- switch (this->status.load()) {
+ if (this->killed) return true;
+
+ switch (this->status) {
case Status::INIT:
/* Start the thread delayed, so the vtable is loaded. This allows classes
* to overload functions used by Resolve() (in case threading is disabled). */
@@ -266,6 +284,7 @@ bool TCPConnecter::CheckActivity()
return true;
case Status::CONNECTING:
+ case Status::CONNECTED:
break;
}
@@ -364,10 +383,64 @@ bool TCPConnecter::CheckActivity()
}
this->OnConnect(connected_socket);
+ this->status = Status::CONNECTED;
return true;
}
/**
+ * Check if there was activity for this connecter.
+ * @return True iff the TCPConnecter is done and can be cleaned up.
+ */
+bool TCPServerConnecter::CheckActivity()
+{
+ if (this->killed) return true;
+
+ switch (this->server_address.type) {
+ case SERVER_ADDRESS_DIRECT:
+ return TCPConnecter::CheckActivity();
+
+ case SERVER_ADDRESS_INVITE_CODE:
+ /* Check if a result has come in. */
+ switch (this->status) {
+ case Status::FAILURE:
+ this->OnFailure();
+ return true;
+
+ case Status::CONNECTED:
+ this->OnConnect(this->socket);
+ return true;
+
+ default:
+ break;
+ }
+
+ return false;
+
+ default:
+ NOT_REACHED();
+ }
+}
+
+/**
+ * The connection was successfully established.
+ * This socket is fully setup and ready to send/recv game protocol packets.
+ * @param sock The socket of the established connection.
+ */
+void TCPServerConnecter::SetConnected(SOCKET sock)
+{
+ this->socket = sock;
+ this->status = Status::CONNECTED;
+}
+
+/**
+ * The connection couldn't be established.
+ */
+void TCPServerConnecter::SetFailure()
+{
+ this->status = Status::FAILURE;
+}
+
+/**
* Check whether we need to call the callback, i.e. whether we
* have connected or aborted and call the appropriate callback
* for that. It's done this way to ease on the locking that
diff --git a/src/network/network.cpp b/src/network/network.cpp
index 4e48ce351..3a33e5096 100644
--- a/src/network/network.cpp
+++ b/src/network/network.cpp
@@ -453,6 +453,41 @@ static void CheckPauseOnJoin()
}
/**
+ * Parse the company part ("#company" postfix) of a connecting string.
+ * @param connection_string The string with the connection data.
+ * @param company_id The company ID to set, if available.
+ * @return A std::string_view into the connection string without the company part.
+ */
+std::string_view ParseCompanyFromConnectionString(const std::string &connection_string, CompanyID *company_id)
+{
+ std::string_view ip = connection_string;
+ if (company_id == nullptr) return ip;
+
+ size_t offset = ip.find_last_of('#');
+ if (offset != std::string::npos) {
+ std::string_view company_string = ip.substr(offset + 1);
+ ip = ip.substr(0, offset);
+
+ uint8 company_value;
+ auto [_, err] = std::from_chars(company_string.data(), company_string.data() + company_string.size(), company_value);
+ if (err == std::errc()) {
+ if (company_value != COMPANY_NEW_COMPANY && company_value != COMPANY_SPECTATOR) {
+ if (company_value > MAX_COMPANIES || company_value == 0) {
+ *company_id = COMPANY_SPECTATOR;
+ } else {
+ /* "#1" means the first company, which has index 0. */
+ *company_id = (CompanyID)(company_value - 1);
+ }
+ } else {
+ *company_id = (CompanyID)company_value;
+ }
+ }
+ }
+
+ return ip;
+}
+
+/**
* Converts a string to ip/port/company
* Format: IP:port#company
*
@@ -469,29 +504,7 @@ static void CheckPauseOnJoin()
*/
std::string_view ParseFullConnectionString(const std::string &connection_string, uint16 &port, CompanyID *company_id)
{
- std::string_view ip = connection_string;
- if (company_id != nullptr) {
- size_t offset = ip.find_last_of('#');
- if (offset != std::string::npos) {
- std::string_view company_string = ip.substr(offset + 1);
- ip = ip.substr(0, offset);
-
- uint8 company_value;
- auto [_, err] = std::from_chars(company_string.data(), company_string.data() + company_string.size(), company_value);
- if (err == std::errc()) {
- if (company_value != COMPANY_NEW_COMPANY && company_value != COMPANY_SPECTATOR) {
- if (company_value > MAX_COMPANIES || company_value == 0) {
- *company_id = COMPANY_SPECTATOR;
- } else {
- /* "#1" means the first company, which has index 0. */
- *company_id = (CompanyID)(company_value - 1);
- }
- } else {
- *company_id = (CompanyID)company_value;
- }
- }
- }
- }
+ std::string_view ip = ParseCompanyFromConnectionString(connection_string, company_id);
size_t port_offset = ip.find_last_of(':');
size_t ipv6_close = ip.find_last_of(']');
diff --git a/src/network/network_internal.h b/src/network/network_internal.h
index c1e6aa7b9..95226286c 100644
--- a/src/network/network_internal.h
+++ b/src/network/network_internal.h
@@ -122,6 +122,7 @@ StringID GetNetworkErrorMsg(NetworkErrorCode err);
bool NetworkMakeClientNameUnique(std::string &new_name);
std::string GenerateCompanyPasswordHash(const std::string &password, const std::string &password_server_id, uint32 password_game_seed);
+std::string_view ParseCompanyFromConnectionString(const std::string &connection_string, CompanyID *company_id);
NetworkAddress ParseConnectionString(const std::string &connection_string, uint16 default_port);
std::string NormalizeConnectionString(const std::string &connection_string, uint16 default_port);