diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/network/core/address.cpp | 91 | ||||
-rw-r--r-- | src/network/core/address.h | 2 | ||||
-rw-r--r-- | src/network/core/tcp.h | 28 | ||||
-rw-r--r-- | src/network/core/tcp_connect.cpp | 301 | ||||
-rw-r--r-- | src/network/core/tcp_http.cpp | 4 | ||||
-rw-r--r-- | src/network/core/tcp_http.h | 10 | ||||
-rw-r--r-- | src/network/network.cpp | 41 | ||||
-rw-r--r-- | src/network/network_internal.h | 1 |
8 files changed, 333 insertions, 145 deletions
diff --git a/src/network/core/address.cpp b/src/network/core/address.cpp index d40a72420..205b5df22 100644 --- a/src/network/core/address.cpp +++ b/src/network/core/address.cpp @@ -14,8 +14,6 @@ #include "../../safeguards.h" -static const int DEFAULT_CONNECT_TIMEOUT_SECONDS = 3; ///< Allow connect() three seconds to connect. - /** * Get the hostname; in case it wasn't given the * IPv4 dotted representation is given. @@ -307,82 +305,6 @@ SOCKET NetworkAddress::Resolve(int family, int socktype, int flags, SocketList * } /** - * Helper function to resolve a connected socket. - * @param runp information about the socket to try not - * @return the opened socket or INVALID_SOCKET - */ -static SOCKET ConnectLoopProc(addrinfo *runp) -{ - const char *type = NetworkAddress::SocketTypeAsString(runp->ai_socktype); - const char *family = NetworkAddress::AddressFamilyAsString(runp->ai_family); - std::string address = NetworkAddress(runp->ai_addr, (int)runp->ai_addrlen).GetAddressAsString(); - - SOCKET sock = socket(runp->ai_family, runp->ai_socktype, runp->ai_protocol); - if (sock == INVALID_SOCKET) { - DEBUG(net, 1, "[%s] could not create %s socket: %s", type, family, NetworkError::GetLast().AsString()); - return INVALID_SOCKET; - } - - if (!SetNoDelay(sock)) DEBUG(net, 1, "[%s] setting TCP_NODELAY failed", type); - - if (!SetNonBlocking(sock)) DEBUG(net, 0, "[%s] setting non-blocking mode failed", type); - - int err = connect(sock, runp->ai_addr, (int)runp->ai_addrlen); - if (err != 0 && !NetworkError::GetLast().IsConnectInProgress()) { - DEBUG(net, 1, "[%s] could not connect to %s over %s: %s", type, address.c_str(), family, NetworkError::GetLast().AsString()); - closesocket(sock); - return INVALID_SOCKET; - } - - fd_set write_fd; - struct timeval tv; - - FD_ZERO(&write_fd); - FD_SET(sock, &write_fd); - - /* Wait for connect() to either connect, timeout or fail. */ - tv.tv_usec = 0; - tv.tv_sec = DEFAULT_CONNECT_TIMEOUT_SECONDS; - int n = select(FD_SETSIZE, NULL, &write_fd, NULL, &tv); - if (n < 0) { - DEBUG(net, 1, "[%s] could not connect to %s: %s", type, address.c_str(), NetworkError::GetLast().AsString()); - closesocket(sock); - return INVALID_SOCKET; - } - - /* If no fd is selected, the timeout has been reached. */ - if (n == 0) { - DEBUG(net, 1, "[%s] timed out while connecting to %s", type, address.c_str()); - closesocket(sock); - return INVALID_SOCKET; - } - - /* Retrieve last error, if any, on the socket. */ - NetworkError socket_error = GetSocketError(sock); - if (socket_error.HasError()) { - DEBUG(net, 1, "[%s] could not connect to %s: %s", type, address.c_str(), socket_error.AsString()); - closesocket(sock); - return INVALID_SOCKET; - } - - /* Connection succeeded. */ - DEBUG(net, 1, "[%s] connected to %s", type, address.c_str()); - - return sock; -} - -/** - * Connect to the given address. - * @return the connected socket or INVALID_SOCKET. - */ -SOCKET NetworkAddress::Connect() -{ - DEBUG(net, 1, "Connecting to %s", this->GetAddressAsString().c_str()); - - return this->Resolve(AF_UNSPEC, SOCK_STREAM, AI_ADDRCONFIG, nullptr, ConnectLoopProc); -} - -/** * Helper function to resolve a listening. * @param runp information about the socket to try not * @return the opened socket or INVALID_SOCKET @@ -486,3 +408,16 @@ void NetworkAddress::Listen(int socktype, SocketList *sockets) default: return "unsupported"; } } + +/** + * Get the peer name of a socket in string format. + * @param sock The socket to get the peer name of. + * @return The string representation of the peer name. + */ +/* static */ const std::string NetworkAddress::GetPeerName(SOCKET sock) +{ + sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + getpeername(sock, (sockaddr *)&addr, &addr_len); + return NetworkAddress(addr, addr_len).GetAddressAsString(); +} diff --git a/src/network/core/address.h b/src/network/core/address.h index b4364c95a..c3f95b492 100644 --- a/src/network/core/address.h +++ b/src/network/core/address.h @@ -171,11 +171,11 @@ public: return this->CompareTo(address) < 0; } - SOCKET Connect(); void Listen(int socktype, SocketList *sockets); static const char *SocketTypeAsString(int socktype); static const char *AddressFamilyAsString(int family); + static const std::string GetPeerName(SOCKET sock); }; #endif /* NETWORK_CORE_ADDRESS_H */ diff --git a/src/network/core/tcp.h b/src/network/core/tcp.h index 5acf9d12e..b90ce0232 100644 --- a/src/network/core/tcp.h +++ b/src/network/core/tcp.h @@ -15,6 +15,7 @@ #include "address.h" #include "packet.h" +#include <chrono> #include <atomic> /** The states of sending the packets. */ @@ -63,23 +64,28 @@ public: */ class TCPConnecter { private: - std::atomic<bool> connected;///< Whether we succeeded in making the connection - std::atomic<bool> aborted; ///< Whether we bailed out (i.e. connection making failed) - bool killed; ///< Whether we got killed - SOCKET sock; ///< The socket we're connecting with + addrinfo *ai = nullptr; ///< getaddrinfo() allocated linked-list of resolved addresses. + std::vector<addrinfo *> addresses; ///< Addresses we can connect to. + size_t current_address = 0; ///< Current index in addresses we are trying. - void Connect(); + std::vector<SOCKET> sockets; ///< Pending connect() attempts. + std::chrono::steady_clock::time_point last_attempt; ///< Time we last tried to connect. - static void ThreadEntry(TCPConnecter *param); + std::atomic<bool> is_resolved = false; ///< Whether resolving is done. -protected: - /** Address we're connecting to */ - NetworkAddress address; + void Resolve(); + void OnResolved(addrinfo *ai); + bool TryNextAddress(); + void Connect(addrinfo *address); + bool CheckActivity(); + + static void ResolveThunk(TCPConnecter *connecter); public: + std::string connection_string; ///< Current address we are connecting to (before resolving). + TCPConnecter(const std::string &connection_string, uint16 default_port); - /** Silence the warnings */ - virtual ~TCPConnecter() {} + virtual ~TCPConnecter(); /** * Callback when the connection succeeded. diff --git a/src/network/core/tcp_connect.cpp b/src/network/core/tcp_connect.cpp index 81c4d8c26..cca9f09b7 100644 --- a/src/network/core/tcp_connect.cpp +++ b/src/network/core/tcp_connect.cpp @@ -15,6 +15,8 @@ #include "tcp.h" #include "../network_internal.h" +#include <deque> + #include "../../safeguards.h" /** List of connections that are currently being created */ @@ -24,38 +26,271 @@ static std::vector<TCPConnecter *> _tcp_connecters; * Create a new connecter for the given address * @param connection_string the address to connect to */ -TCPConnecter::TCPConnecter(const std::string &connection_string, uint16 default_port) : - connected(false), - aborted(false), - killed(false), - sock(INVALID_SOCKET) +TCPConnecter::TCPConnecter(const std::string &connection_string, uint16 default_port) { - this->address = ParseConnectionString(connection_string, default_port); + this->connection_string = NormalizeConnectionString(connection_string, default_port); _tcp_connecters.push_back(this); - if (!StartNewThread(nullptr, "ottd:tcp", &TCPConnecter::ThreadEntry, this)) { - this->Connect(); + + if (!StartNewThread(nullptr, "ottd:resolve", &TCPConnecter::ResolveThunk, this)) { + this->Resolve(); + } +} + +TCPConnecter::~TCPConnecter() +{ + for (const auto &socket : this->sockets) { + close(socket); } + + freeaddrinfo(this->ai); } -/** The actual connection function */ -void TCPConnecter::Connect() +/** + * Start a connection to the indicated address. + * @param address The address to connection to. + */ +void TCPConnecter::Connect(addrinfo *address) { - this->sock = this->address.Connect(); - if (this->sock == INVALID_SOCKET) { - this->aborted = true; - } else { - this->connected = true; + SOCKET sock = socket(address->ai_family, address->ai_socktype, address->ai_protocol); + if (sock == INVALID_SOCKET) { + DEBUG(net, 0, "Could not create %s %s socket: %s", NetworkAddress::SocketTypeAsString(address->ai_socktype), NetworkAddress::AddressFamilyAsString(address->ai_family), NetworkError::GetLast().AsString()); + return; + } + + if (!SetNoDelay(sock)) DEBUG(net, 1, "Setting TCP_NODELAY failed"); + if (!SetNonBlocking(sock)) DEBUG(net, 0, "Setting non-blocking mode failed"); + + NetworkAddress network_address = NetworkAddress(address->ai_addr, (int)address->ai_addrlen); + DEBUG(net, 4, "Attempting to connect to %s", network_address.GetAddressAsString().c_str()); + + int err = connect(sock, address->ai_addr, (int)address->ai_addrlen); + if (err != 0 && !NetworkError::GetLast().IsConnectInProgress()) { + closesocket(sock); + + DEBUG(net, 1, "Could not connect to %s: %s", network_address.GetAddressAsString().c_str(), NetworkError::GetLast().AsString()); + return; + } + + this->sockets.push_back(sock); +} + +/** + * Start the connect() for the next address in the list. + * @return True iff a new connect() is attempted. + */ +bool TCPConnecter::TryNextAddress() +{ + if (this->current_address >= this->addresses.size()) return false; + + this->last_attempt = std::chrono::steady_clock::now(); + this->Connect(this->addresses[this->current_address++]); + + return true; +} + +void TCPConnecter::OnResolved(addrinfo *ai) +{ + std::deque<addrinfo *> addresses_ipv4, addresses_ipv6; + + /* Apply "Happy Eyeballs" if it is likely IPv6 is functional. */ + + /* Detect if IPv6 is likely to succeed or not. */ + bool seen_ipv6 = false; + bool resort = true; + for (addrinfo *runp = ai; runp != nullptr; runp = runp->ai_next) { + if (runp->ai_family == AF_INET6) { + seen_ipv6 = true; + } else if (!seen_ipv6) { + /* We see an IPv4 before an IPv6; this most likely means there is + * no IPv6 available on the system, so keep the order of this + * list. */ + resort = false; + break; + } + } + + /* Convert the addrinfo into NetworkAddresses. */ + for (addrinfo *runp = ai; runp != nullptr; runp = runp->ai_next) { + if (resort) { + if (runp->ai_family == AF_INET6) { + addresses_ipv6.emplace_back(runp); + } else { + addresses_ipv4.emplace_back(runp); + } + } else { + this->addresses.emplace_back(runp); + } + } + + /* If we want to resort, make the list like IPv6 / IPv4 / IPv6 / IPv4 / .. + * for how ever many (round-robin) DNS entries we have. */ + if (resort) { + while (!addresses_ipv4.empty() || !addresses_ipv6.empty()) { + if (!addresses_ipv6.empty()) { + this->addresses.push_back(addresses_ipv6.front()); + addresses_ipv6.pop_front(); + } + if (!addresses_ipv4.empty()) { + this->addresses.push_back(addresses_ipv4.front()); + addresses_ipv4.pop_front(); + } + } + } + + if (_debug_net_level >= 5) { + DEBUG(net, 5, "%s resolved in:", this->connection_string.c_str()); + for (const auto &address : this->addresses) { + DEBUG(net, 5, "- %s", NetworkAddress(address->ai_addr, (int)address->ai_addrlen).GetAddressAsString().c_str()); + } + } + + this->current_address = 0; +} + +void TCPConnecter::Resolve() +{ + /* Port is already guaranteed part of the connection_string. */ + NetworkAddress address = ParseConnectionString(this->connection_string, 0); + + addrinfo hints; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_UNSPEC; + hints.ai_flags = AI_ADDRCONFIG; + hints.ai_socktype = SOCK_STREAM; + + char port_name[6]; + seprintf(port_name, lastof(port_name), "%u", address.GetPort()); + + static bool getaddrinfo_timeout_error_shown = false; + auto start = std::chrono::steady_clock::now(); + + addrinfo *ai; + int e = getaddrinfo(address.GetHostname(), port_name, &hints, &ai); + + auto end = std::chrono::steady_clock::now(); + auto duration = std::chrono::duration_cast<std::chrono::seconds>(end - start); + if (!getaddrinfo_timeout_error_shown && duration >= std::chrono::seconds(5)) { + DEBUG(net, 0, "getaddrinfo() for address \"%s\" took %i seconds", this->connection_string.c_str(), (int)duration.count()); + DEBUG(net, 0, " This is likely an issue in the DNS name resolver's configuration causing it to time out"); + getaddrinfo_timeout_error_shown = true; } + + if (e != 0) { + DEBUG(misc, 0, "Failed to resolve DNS for %s", this->connection_string.c_str()); + this->OnFailure(); + return; + } + + this->ai = ai; + this->OnResolved(ai); + this->is_resolved = true; +} + +/* static */ void TCPConnecter::ResolveThunk(TCPConnecter *connecter) +{ + connecter->Resolve(); } /** - * Entry point for the new threads. - * @param param the TCPConnecter instance to call Connect on. + * Check if there was activity for this connecter. + * @return True iff the TCPConnecter is done and can be cleaned up. */ -/* static */ void TCPConnecter::ThreadEntry(TCPConnecter *param) +bool TCPConnecter::CheckActivity() { - param->Connect(); + if (!this->is_resolved.load()) return false; + + /* If there are no attempts pending, connect to the next. */ + if (this->sockets.empty()) { + if (!this->TryNextAddress()) { + /* There were no more addresses to try, so we failed. */ + this->OnFailure(); + return true; + } + return false; + } + + fd_set write_fd; + FD_ZERO(&write_fd); + for (const auto &socket : this->sockets) { + FD_SET(socket, &write_fd); + } + + timeval tv; + tv.tv_usec = 0; + tv.tv_sec = 0; + int n = select(FD_SETSIZE, NULL, &write_fd, NULL, &tv); + /* select() failed; hopefully next try it doesn't. */ + if (n < 0) { + /* select() normally never fails; so hopefully it works next try! */ + DEBUG(net, 1, "select() failed with %s", NetworkError::GetLast().AsString()); + return false; + } + + /* No socket updates. */ + if (n == 0) { + /* Wait 250ms between attempting another address. */ + if (std::chrono::steady_clock::now() < this->last_attempt + std::chrono::milliseconds(250)) return false; + + /* Try the next address in the list. */ + if (this->TryNextAddress()) return false; + + /* Wait up to 3 seconds since the last connection we started. */ + if (std::chrono::steady_clock::now() < this->last_attempt + std::chrono::milliseconds(3000)) return false; + + /* More than 3 seconds no socket reported activity, and there are no + * more address to try. Timeout the attempt. */ + DEBUG(net, 0, "Timeout while connecting to %s", this->connection_string.c_str()); + + for (const auto &socket : this->sockets) { + closesocket(socket); + } + this->OnFailure(); + return true; + } + + /* Check for errors on any of the sockets. */ + for (auto it = this->sockets.begin(); it != this->sockets.end(); /* nothing */) { + NetworkError socket_error = GetSocketError(*it); + if (socket_error.HasError()) { + DEBUG(net, 1, "Could not connect to %s: %s", NetworkAddress::GetPeerName(*it).c_str(), socket_error.AsString()); + closesocket(*it); + it = this->sockets.erase(it); + } else { + it++; + } + } + + /* In case all sockets had an error, queue a new one. */ + if (this->sockets.empty()) { + if (!this->TryNextAddress()) { + /* There were no more addresses to try, so we failed. */ + this->OnFailure(); + return true; + } + return false; + } + + /* At least one socket is connected. The first one that does is the one + * we will be using, and we close all other sockets. */ + SOCKET connected_socket = INVALID_SOCKET; + for (auto it = this->sockets.begin(); it != this->sockets.end(); /* nothing */) { + if (connected_socket == INVALID_SOCKET && FD_ISSET(*it, &write_fd)) { + connected_socket = *it; + } else { + closesocket(*it); + } + it = this->sockets.erase(it); + } + assert(connected_socket != INVALID_SOCKET); + + DEBUG(net, 1, "Connected to %s", this->connection_string.c_str()); + if (_debug_net_level >= 5) { + DEBUG(net, 5, "- using %s", NetworkAddress::GetPeerName(connected_socket).c_str()); + } + + this->OnConnect(connected_socket); + return true; } /** @@ -68,32 +303,22 @@ void TCPConnecter::Connect() { for (auto iter = _tcp_connecters.begin(); iter < _tcp_connecters.end(); /* nothing */) { TCPConnecter *cur = *iter; - const bool connected = cur->connected.load(); - const bool aborted = cur->aborted.load(); - if ((connected || aborted) && cur->killed) { - iter = _tcp_connecters.erase(iter); - if (cur->sock != INVALID_SOCKET) closesocket(cur->sock); - delete cur; - continue; - } - if (connected) { - iter = _tcp_connecters.erase(iter); - cur->OnConnect(cur->sock); - delete cur; - continue; - } - if (aborted) { + + if (cur->CheckActivity()) { iter = _tcp_connecters.erase(iter); - cur->OnFailure(); delete cur; - continue; + } else { + iter++; } - iter++; } } /** Kill all connection attempts. */ /* static */ void TCPConnecter::KillAll() { - for (TCPConnecter *conn : _tcp_connecters) conn->killed = true; + for (auto iter = _tcp_connecters.begin(); iter < _tcp_connecters.end(); /* nothing */) { + TCPConnecter *cur = *iter; + iter = _tcp_connecters.erase(iter); + delete cur; + } } diff --git a/src/network/core/tcp_http.cpp b/src/network/core/tcp_http.cpp index 3b7579fce..ccad120ae 100644 --- a/src/network/core/tcp_http.cpp +++ b/src/network/core/tcp_http.cpp @@ -203,9 +203,11 @@ int NetworkHTTPSocketHandler::HandleHeader() *url = '\0'; + std::string hostname = std::string(hname); + /* Restore the URL. */ *url = '/'; - new NetworkHTTPContentConnecter(hname, callback, url, data, depth); + new NetworkHTTPContentConnecter(hostname, callback, url, data, depth); return 0; } diff --git a/src/network/core/tcp_http.h b/src/network/core/tcp_http.h index cc9a3adac..d7be0c327 100644 --- a/src/network/core/tcp_http.h +++ b/src/network/core/tcp_http.h @@ -73,6 +73,7 @@ public: /** Connect with a HTTP server and do ONE query. */ class NetworkHTTPContentConnecter : TCPConnecter { + std::string hostname; ///< Hostname we are connecting to. HTTPCallback *callback; ///< Callback to tell that we received some data (or won't). const char *url; ///< The URL we want to get at the server. const char *data; ///< The data to send @@ -81,14 +82,15 @@ class NetworkHTTPContentConnecter : TCPConnecter { public: /** * Start the connecting. - * @param connection_string The address to connect to. + * @param hostname The hostname to connect to. * @param callback The callback for HTTP retrieval. * @param url The url at the server. * @param data The data to send. * @param depth The depth (redirect recursion) of the queries. */ - NetworkHTTPContentConnecter(const std::string &connection_string, HTTPCallback *callback, const char *url, const char *data = nullptr, int depth = 0) : - TCPConnecter(connection_string, 80), + NetworkHTTPContentConnecter(const std::string &hostname, HTTPCallback *callback, const char *url, const char *data = nullptr, int depth = 0) : + TCPConnecter(hostname, 80), + hostname(hostname), callback(callback), url(stredup(url)), data(data), @@ -110,7 +112,7 @@ public: void OnConnect(SOCKET s) override { - new NetworkHTTPSocketHandler(s, this->callback, this->address.GetHostname(), this->url, this->data, this->depth); + new NetworkHTTPSocketHandler(s, this->callback, this->hostname.c_str(), this->url, this->data, this->depth); /* We've relinquished control of data now. */ this->data = nullptr; } diff --git a/src/network/network.cpp b/src/network/network.cpp index 83a93a22e..bbf6bb5de 100644 --- a/src/network/network.cpp +++ b/src/network/network.cpp @@ -502,6 +502,19 @@ std::string_view ParseFullConnectionString(const std::string &connection_string, } /** + * Normalize a connection string. That is, ensure there is a port in the string. + * @param connection_string The connection string to normalize. + * @param default_port The port to use if none is given. + * @return The normalized connection string. + */ +std::string NormalizeConnectionString(const std::string &connection_string, uint16 default_port) +{ + uint16 port = default_port; + std::string_view ip = ParseFullConnectionString(connection_string, port); + return std::string(ip) + ":" + std::to_string(port); +} + +/** * Convert a string containing either "hostname" or "hostname:ip" to a * NetworkAddress. * @@ -1131,23 +1144,27 @@ static void NetworkGenerateServerId() seprintf(_settings_client.network.network_id, lastof(_settings_client.network.network_id), "%s", hex_output); } -void NetworkStartDebugLog(const std::string &connection_string) -{ - extern SOCKET _debug_socket; // Comes from debug.c +class TCPNetworkDebugConnecter : TCPConnecter { +public: + TCPNetworkDebugConnecter(const std::string &connection_string) : TCPConnecter(connection_string, NETWORK_DEFAULT_DEBUGLOG_PORT) {} - NetworkAddress address = ParseConnectionString(connection_string, NETWORK_DEFAULT_DEBUGLOG_PORT); + void OnFailure() override + { + DEBUG(net, 0, "Failed to open connection to %s for redirecting DEBUG()", this->connection_string.c_str()); + } - DEBUG(net, 0, "Redirecting DEBUG() to %s", address.GetAddressAsString().c_str()); + void OnConnect(SOCKET s) override + { + DEBUG(net, 0, "Redirecting DEBUG() to %s", this->connection_string.c_str()); - SOCKET s = address.Connect(); - if (s == INVALID_SOCKET) { - DEBUG(net, 0, "Failed to open socket for redirection DEBUG()"); - return; + extern SOCKET _debug_socket; + _debug_socket = s; } +}; - _debug_socket = s; - - DEBUG(net, 0, "DEBUG() is now redirected"); +void NetworkStartDebugLog(const std::string &connection_string) +{ + new TCPNetworkDebugConnecter(connection_string); } /** This tries to launch the network for a given OS */ diff --git a/src/network/network_internal.h b/src/network/network_internal.h index 2a2024bac..88af965cf 100644 --- a/src/network/network_internal.h +++ b/src/network/network_internal.h @@ -120,5 +120,6 @@ bool NetworkFindName(char *new_name, const char *last); const char *GenerateCompanyPasswordHash(const char *password, const char *password_server_id, uint32 password_game_seed); NetworkAddress ParseConnectionString(const std::string &connection_string, uint16 default_port); +std::string NormalizeConnectionString(const std::string &connection_string, uint16 default_port); #endif /* NETWORK_INTERNAL_H */ |