summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/network/core/address.cpp48
-rw-r--r--src/network/core/os_abstraction.h16
2 files changed, 53 insertions, 11 deletions
diff --git a/src/network/core/address.cpp b/src/network/core/address.cpp
index fc19439e0..e53566c0b 100644
--- a/src/network/core/address.cpp
+++ b/src/network/core/address.cpp
@@ -14,6 +14,8 @@
#include "../../safeguards.h"
+static const int DEFAULT_CONNECT_TIMEOUT_SECONDS = 3; ///< Allow connect() three seconds to connect.
+
/**
* Get the hostname; in case it wasn't given the
* IPv4 dotted representation is given.
@@ -322,23 +324,47 @@ static SOCKET ConnectLoopProc(addrinfo *runp)
if (!SetNoDelay(sock)) DEBUG(net, 1, "[%s] setting TCP_NODELAY failed", type);
+ if (!SetNonBlocking(sock)) DEBUG(net, 0, "[%s] setting non-blocking mode failed", type);
+
int err = connect(sock, runp->ai_addr, (int)runp->ai_addrlen);
-#ifdef __EMSCRIPTEN__
- /* Emscripten is asynchronous, and as such a connect() is still in
- * progress by the time the call returns. */
- if (err != 0 && errno != EINPROGRESS)
-#else
- if (err != 0)
-#endif
- {
- DEBUG(net, 1, "[%s] could not connect %s socket: %s", type, family, NetworkGetLastErrorString());
+ if (err != 0 && NetworkGetLastError() != EINPROGRESS) {
+ DEBUG(net, 1, "[%s] could not connect to %s over %s: %s", type, address, family, NetworkGetLastErrorString());
closesocket(sock);
return INVALID_SOCKET;
}
- /* Connection succeeded */
- if (!SetNonBlocking(sock)) DEBUG(net, 0, "[%s] setting non-blocking mode failed", type);
+ fd_set write_fd;
+ struct timeval tv;
+
+ FD_ZERO(&write_fd);
+ FD_SET(sock, &write_fd);
+
+ /* Wait for connect() to either connect, timeout or fail. */
+ tv.tv_usec = 0;
+ tv.tv_sec = DEFAULT_CONNECT_TIMEOUT_SECONDS;
+ int n = select(FD_SETSIZE, NULL, &write_fd, NULL, &tv);
+ if (n < 0) {
+ DEBUG(net, 1, "[%s] could not connect to %s: %s", type, address, NetworkGetLastErrorString());
+ closesocket(sock);
+ return INVALID_SOCKET;
+ }
+
+ /* If no fd is selected, the timeout has been reached. */
+ if (n == 0) {
+ DEBUG(net, 1, "[%s] timed out while connecting to %s", type, address);
+ closesocket(sock);
+ return INVALID_SOCKET;
+ }
+
+ /* Retrieve last error, if any, on the socket. */
+ err = GetSocketError(sock);
+ if (err != 0) {
+ DEBUG(net, 1, "[%s] could not connect to %s: %s", type, address, NetworkGetErrorString(err));
+ closesocket(sock);
+ return INVALID_SOCKET;
+ }
+ /* Connection succeeded. */
DEBUG(net, 1, "[%s] connected to %s", type, address);
return sock;
diff --git a/src/network/core/os_abstraction.h b/src/network/core/os_abstraction.h
index 7af3fd163..9bd0e321f 100644
--- a/src/network/core/os_abstraction.h
+++ b/src/network/core/os_abstraction.h
@@ -33,6 +33,8 @@
#define EWOULDBLOCK WSAEWOULDBLOCK
#undef ECONNRESET
#define ECONNRESET WSAECONNRESET
+#undef EINPROGRESS
+#define EINPROGRESS WSAEWOULDBLOCK
const char *NetworkGetErrorString(int error);
@@ -230,6 +232,20 @@ static inline bool SetNoDelay(SOCKET d)
#endif
}
+/**
+ * Get the error from a socket, if any.
+ * @param d The socket to get the error from.
+ * @return The errno on the socket.
+ */
+static inline int GetSocketError(SOCKET d)
+{
+ int err;
+ socklen_t len = sizeof(err);
+ getsockopt(d, SOL_SOCKET, SO_ERROR, (char *)&err, &len);
+
+ return err;
+}
+
/* Make sure these structures have the size we expect them to be */
static_assert(sizeof(in_addr) == 4); ///< IPv4 addresses should be 4 bytes.
static_assert(sizeof(in6_addr) == 16); ///< IPv6 addresses should be 16 bytes.