summaryrefslogtreecommitdiff
path: root/src/network
diff options
context:
space:
mode:
authorrubidium42 <rubidium@openttd.org>2021-04-30 15:38:22 +0200
committerrubidium42 <rubidium42@users.noreply.github.com>2021-05-01 19:36:22 +0200
commit22720332eb9922e20148c7aae1127f7304f6f7d3 (patch)
tree0efbcfac24474a10f88f956c1f4721cdd7af97e0 /src/network
parent0eb17a70af86d11e49d9560088900c9d65cb07c1 (diff)
downloadopenttd-22720332eb9922e20148c7aae1127f7304f6f7d3.tar.xz
Codechange: encapsulate network error handling
Diffstat (limited to 'src/network')
-rw-r--r--src/network/core/CMakeLists.txt1
-rw-r--r--src/network/core/address.cpp24
-rw-r--r--src/network/core/core.cpp18
-rw-r--r--src/network/core/os_abstraction.cpp125
-rw-r--r--src/network/core/os_abstraction.h52
-rw-r--r--src/network/core/tcp.cpp22
-rw-r--r--src/network/core/tcp_http.cpp8
-rw-r--r--src/network/core/tcp_listen.h4
-rw-r--r--src/network/core/udp.cpp4
9 files changed, 179 insertions, 79 deletions
diff --git a/src/network/core/CMakeLists.txt b/src/network/core/CMakeLists.txt
index 37cc3e195..bf713be99 100644
--- a/src/network/core/CMakeLists.txt
+++ b/src/network/core/CMakeLists.txt
@@ -8,6 +8,7 @@ add_files(
game_info.h
host.cpp
host.h
+ os_abstraction.cpp
os_abstraction.h
packet.cpp
packet.h
diff --git a/src/network/core/address.cpp b/src/network/core/address.cpp
index e91751c33..d3e373ef7 100644
--- a/src/network/core/address.cpp
+++ b/src/network/core/address.cpp
@@ -317,7 +317,7 @@ static SOCKET ConnectLoopProc(addrinfo *runp)
SOCKET sock = socket(runp->ai_family, runp->ai_socktype, runp->ai_protocol);
if (sock == INVALID_SOCKET) {
- DEBUG(net, 1, "[%s] could not create %s socket: %s", type, family, NetworkGetLastErrorString());
+ DEBUG(net, 1, "[%s] could not create %s socket: %s", type, family, NetworkError::GetLast().AsString());
return INVALID_SOCKET;
}
@@ -326,8 +326,8 @@ static SOCKET ConnectLoopProc(addrinfo *runp)
if (!SetNonBlocking(sock)) DEBUG(net, 0, "[%s] setting non-blocking mode failed", type);
int err = connect(sock, runp->ai_addr, (int)runp->ai_addrlen);
- if (err != 0 && NetworkGetLastError() != EINPROGRESS) {
- DEBUG(net, 1, "[%s] could not connect to %s over %s: %s", type, address.c_str(), family, NetworkGetLastErrorString());
+ if (err != 0 && !NetworkError::GetLast().IsConnectInProgress()) {
+ DEBUG(net, 1, "[%s] could not connect to %s over %s: %s", type, address.c_str(), family, NetworkError::GetLast().AsString());
closesocket(sock);
return INVALID_SOCKET;
}
@@ -343,7 +343,7 @@ static SOCKET ConnectLoopProc(addrinfo *runp)
tv.tv_sec = DEFAULT_CONNECT_TIMEOUT_SECONDS;
int n = select(FD_SETSIZE, NULL, &write_fd, NULL, &tv);
if (n < 0) {
- DEBUG(net, 1, "[%s] could not connect to %s: %s", type, address.c_str(), NetworkGetLastErrorString());
+ DEBUG(net, 1, "[%s] could not connect to %s: %s", type, address.c_str(), NetworkError::GetLast().AsString());
closesocket(sock);
return INVALID_SOCKET;
}
@@ -356,9 +356,9 @@ static SOCKET ConnectLoopProc(addrinfo *runp)
}
/* Retrieve last error, if any, on the socket. */
- err = GetSocketError(sock);
- if (err != 0) {
- DEBUG(net, 1, "[%s] could not connect to %s: %s", type, address.c_str(), NetworkGetErrorString(err));
+ NetworkError socket_error = GetSocketError(sock);
+ if (socket_error.HasError()) {
+ DEBUG(net, 1, "[%s] could not connect to %s: %s", type, address.c_str(), socket_error.AsString());
closesocket(sock);
return INVALID_SOCKET;
}
@@ -393,7 +393,7 @@ static SOCKET ListenLoopProc(addrinfo *runp)
SOCKET sock = socket(runp->ai_family, runp->ai_socktype, runp->ai_protocol);
if (sock == INVALID_SOCKET) {
- DEBUG(net, 0, "[%s] could not create %s socket on port %s: %s", type, family, address.c_str(), NetworkGetLastErrorString());
+ DEBUG(net, 0, "[%s] could not create %s socket on port %s: %s", type, family, address.c_str(), NetworkError::GetLast().AsString());
return INVALID_SOCKET;
}
@@ -404,24 +404,24 @@ static SOCKET ListenLoopProc(addrinfo *runp)
int on = 1;
/* The (const char*) cast is needed for windows!! */
if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, (const char*)&on, sizeof(on)) == -1) {
- DEBUG(net, 3, "[%s] could not set reusable %s sockets for port %s: %s", type, family, address.c_str(), NetworkGetLastErrorString());
+ DEBUG(net, 3, "[%s] could not set reusable %s sockets for port %s: %s", type, family, address.c_str(), NetworkError::GetLast().AsString());
}
#ifndef __OS2__
if (runp->ai_family == AF_INET6 &&
setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, (const char*)&on, sizeof(on)) == -1) {
- DEBUG(net, 3, "[%s] could not disable IPv4 over IPv6 on port %s: %s", type, address.c_str(), NetworkGetLastErrorString());
+ DEBUG(net, 3, "[%s] could not disable IPv4 over IPv6 on port %s: %s", type, address.c_str(), NetworkError::GetLast().AsString());
}
#endif
if (bind(sock, runp->ai_addr, (int)runp->ai_addrlen) != 0) {
- DEBUG(net, 1, "[%s] could not bind on %s port %s: %s", type, family, address.c_str(), NetworkGetLastErrorString());
+ DEBUG(net, 1, "[%s] could not bind on %s port %s: %s", type, family, address.c_str(), NetworkError::GetLast().AsString());
closesocket(sock);
return INVALID_SOCKET;
}
if (runp->ai_socktype != SOCK_DGRAM && listen(sock, 1) != 0) {
- DEBUG(net, 1, "[%s] could not listen at %s port %s: %s", type, family, address.c_str(), NetworkGetLastErrorString());
+ DEBUG(net, 1, "[%s] could not listen at %s port %s: %s", type, family, address.c_str(), NetworkError::GetLast().AsString());
closesocket(sock);
return INVALID_SOCKET;
}
diff --git a/src/network/core/core.cpp b/src/network/core/core.cpp
index 5c12cb224..563deae96 100644
--- a/src/network/core/core.cpp
+++ b/src/network/core/core.cpp
@@ -13,7 +13,6 @@
#include "../../debug.h"
#include "os_abstraction.h"
#include "packet.h"
-#include "../../string_func.h"
#include "../../safeguards.h"
@@ -48,20 +47,3 @@ void NetworkCoreShutdown()
WSACleanup();
#endif
}
-
-#if defined(_WIN32)
-/**
- * Return the string representation of the given error from the OS's network functions.
- * @param error The error number (from \c NetworkGetLastError()).
- * @return The error message, potentially an empty string but never \c nullptr.
- */
-const char *NetworkGetErrorString(int error)
-{
- static char buffer[512];
- if (FormatMessageA(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, NULL, error,
- MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), buffer, sizeof(buffer), NULL) == 0) {
- seprintf(buffer, lastof(buffer), "Unknown error %d", error);
- }
- return buffer;
-}
-#endif /* defined(_WIN32) */
diff --git a/src/network/core/os_abstraction.cpp b/src/network/core/os_abstraction.cpp
new file mode 100644
index 000000000..75f2224eb
--- /dev/null
+++ b/src/network/core/os_abstraction.cpp
@@ -0,0 +1,125 @@
+/*
+ * This file is part of OpenTTD.
+ * OpenTTD is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, version 2.
+ * OpenTTD is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+ * See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with OpenTTD. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+/**
+ * @file os_abstraction.cpp OS specific implementations of functions of the OS abstraction layer for network stuff.
+ *
+ * The general idea is to have simple abstracting functions for things that
+ * require different implementations for different environments.
+ * In here the functions, and their documentation, are defined only once
+ * and the implementation contains the #ifdefs to change the implementation.
+ * Since Windows is usually different that is usually the first case, after
+ * that the behaviour is usually Unix/BSD-like with occasional variation.
+ */
+
+#include "stdafx.h"
+#include "os_abstraction.h"
+#include "../../string_func.h"
+#include <mutex>
+
+#include "../../safeguards.h"
+
+/**
+ * Construct the network error with the given error code.
+ * @param error The error code.
+ */
+NetworkError::NetworkError(int error) : error(error)
+{
+}
+
+/**
+ * Check whether this error describes that the operation would block.
+ * @return True iff the operation would block.
+ */
+bool NetworkError::WouldBlock() const
+{
+#if defined(_WIN32)
+ return this->error == WSAEWOULDBLOCK;
+#else
+ /* Usually EWOULDBLOCK and EAGAIN are the same, but sometimes they are not
+ * and the POSIX.1 specification states that either should be checked. */
+ return this->error == EWOULDBLOCK || this->error == EAGAIN;
+#endif
+}
+
+/**
+ * Check whether this error describes a connection reset.
+ * @return True iff the connection is reset.
+ */
+bool NetworkError::IsConnectionReset() const
+{
+#if defined(_WIN32)
+ return this->error == WSAECONNRESET;
+#else
+ return this->error == ECONNRESET;
+#endif
+}
+
+/**
+ * Check whether this error describes a connect is in progress.
+ * @return True iff the connect is already in progress.
+ */
+bool NetworkError::IsConnectInProgress() const
+{
+#if defined(_WIN32)
+ return this->error == WSAEWOULDBLOCK;
+#else
+ return this->error == EINPROGRESS;
+#endif
+}
+
+/**
+ * Get the string representation of the error message.
+ * @return The string representation that will get overwritten by next calls.
+ */
+const char *NetworkError::AsString() const
+{
+ if (this->message.empty()) {
+#if defined(_WIN32)
+ char buffer[512];
+ if (FormatMessageA(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, NULL, this->error,
+ MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), buffer, sizeof(buffer), NULL) == 0) {
+ seprintf(buffer, lastof(buffer), "Unknown error %d", this->error);
+ }
+ this->message.assign(buffer);
+#else
+ /* Make strerror thread safe by locking access to it. There is a thread safe strerror_r, however
+ * the non-POSIX variant is available due to defining _GNU_SOURCE meaning it is not portable.
+ * The problem with the non-POSIX variant is that it does not necessarily fill the buffer with
+ * the error message but can also return a pointer to a static bit of memory, whereas the POSIX
+ * variant always fills the buffer. This makes the behaviour too erratic to work with. */
+ static std::mutex mutex;
+ std::lock_guard<std::mutex> guard(mutex);
+ this->message.assign(strerror(this->error));
+#endif
+ }
+ return this->message.c_str();
+}
+
+/**
+ * Check whether an error was actually set.
+ * @return True iff an error was set.
+ */
+bool NetworkError::HasError() const
+{
+ return this->error != 0;
+}
+
+/**
+ * Get the last network error.
+ * @return The network error.
+ */
+/* static */ NetworkError NetworkError::GetLast()
+{
+#if defined(_WIN32)
+ return NetworkError(WSAGetLastError());
+#elif defined(__OS2__)
+ return NetworkError(sock_errno());
+#else
+ return NetworkError(errno);
+#endif
+}
diff --git a/src/network/core/os_abstraction.h b/src/network/core/os_abstraction.h
index 9bd0e321f..e444bc78b 100644
--- a/src/network/core/os_abstraction.h
+++ b/src/network/core/os_abstraction.h
@@ -14,6 +14,26 @@
#ifndef NETWORK_CORE_OS_ABSTRACTION_H
#define NETWORK_CORE_OS_ABSTRACTION_H
+/**
+ * Abstraction of a network error where all implementation details of the
+ * error codes are encapsulated in this class and the abstraction layer.
+ */
+class NetworkError {
+private:
+ int error; ///< The underlying error number from errno or WSAGetLastError.
+ mutable std::string message; ///< The string representation of the error (set on first call to #AsString).
+public:
+ NetworkError(int error);
+
+ bool HasError() const;
+ bool WouldBlock() const;
+ bool IsConnectionReset() const;
+ bool IsConnectInProgress() const;
+ const char *AsString() const;
+
+ static NetworkError GetLast();
+};
+
/* Include standard stuff per OS */
/* Windows stuff */
@@ -23,21 +43,6 @@
#include <ws2tcpip.h>
#include <windows.h>
-/**
- * Get the last error code from any of the OS's network functions.
- * What it returns and when it is reset, is implementation defined.
- * @return The last error code.
- */
-#define NetworkGetLastError() WSAGetLastError()
-#undef EWOULDBLOCK
-#define EWOULDBLOCK WSAEWOULDBLOCK
-#undef ECONNRESET
-#define ECONNRESET WSAECONNRESET
-#undef EINPROGRESS
-#define EINPROGRESS WSAEWOULDBLOCK
-
-const char *NetworkGetErrorString(int error);
-
/* Windows has some different names for some types */
typedef unsigned long in_addr_t;
@@ -63,8 +68,6 @@ typedef unsigned long in_addr_t;
# define INVALID_SOCKET -1
# define ioctlsocket ioctl
# define closesocket close
-# define NetworkGetLastError() (errno)
-# define NetworkGetErrorString(error) (strerror(error))
/* Need this for FIONREAD on solaris */
# define BSD_COMP
@@ -114,8 +117,6 @@ typedef unsigned long in_addr_t;
# define INVALID_SOCKET -1
# define ioctlsocket ioctl
# define closesocket close
-# define NetworkGetLastError() (sock_errno())
-# define NetworkGetErrorString(error) (strerror(error))
/* Includes needed for OS/2 systems */
# include <types.h>
@@ -188,15 +189,6 @@ static inline socklen_t FixAddrLenForEmscripten(struct sockaddr_storage &address
#endif
/**
- * Return the string representation of the last error from the OS's network functions.
- * @return The error message, potentially an empty string but never \c nullptr.
- */
-static inline const char *NetworkGetLastErrorString()
-{
- return NetworkGetErrorString(NetworkGetLastError());
-}
-
-/**
* Try to set the socket into non-blocking mode.
* @param d The socket to set the non-blocking more for.
* @return True if setting the non-blocking mode succeeded, otherwise false.
@@ -237,13 +229,13 @@ static inline bool SetNoDelay(SOCKET d)
* @param d The socket to get the error from.
* @return The errno on the socket.
*/
-static inline int GetSocketError(SOCKET d)
+static inline NetworkError GetSocketError(SOCKET d)
{
int err;
socklen_t len = sizeof(err);
getsockopt(d, SOL_SOCKET, SO_ERROR, (char *)&err, &len);
- return err;
+ return NetworkError(err);
}
/* Make sure these structures have the size we expect them to be */
diff --git a/src/network/core/tcp.cpp b/src/network/core/tcp.cpp
index f23b202c8..842e1a89b 100644
--- a/src/network/core/tcp.cpp
+++ b/src/network/core/tcp.cpp
@@ -86,11 +86,11 @@ SendPacketsState NetworkTCPSocketHandler::SendPackets(bool closing_down)
while ((p = this->packet_queue) != nullptr) {
res = p->TransferOut<int>(send, this->sock, 0);
if (res == -1) {
- int err = NetworkGetLastError();
- if (err != EWOULDBLOCK) {
+ NetworkError err = NetworkError::GetLast();
+ if (!err.WouldBlock()) {
/* Something went wrong.. close client! */
if (!closing_down) {
- DEBUG(net, 0, "send failed with error %s", NetworkGetErrorString(err));
+ DEBUG(net, 0, "send failed with error %s", err.AsString());
this->CloseConnection();
}
return SPS_CLOSED;
@@ -136,10 +136,10 @@ Packet *NetworkTCPSocketHandler::ReceivePacket()
while (p->RemainingBytesToTransfer() != 0) {
res = p->TransferIn<int>(recv, this->sock, 0);
if (res == -1) {
- int err = NetworkGetLastError();
- if (err != EWOULDBLOCK) {
- /* Something went wrong... (ECONNRESET is connection reset by peer) */
- if (err != ECONNRESET) DEBUG(net, 0, "recv failed with error %s", NetworkGetErrorString(err));
+ NetworkError err = NetworkError::GetLast();
+ if (!err.WouldBlock()) {
+ /* Something went wrong... */
+ if (!err.IsConnectionReset()) DEBUG(net, 0, "recv failed with error %s", err.AsString());
this->CloseConnection();
return nullptr;
}
@@ -164,10 +164,10 @@ Packet *NetworkTCPSocketHandler::ReceivePacket()
while (p->RemainingBytesToTransfer() != 0) {
res = p->TransferIn<int>(recv, this->sock, 0);
if (res == -1) {
- int err = NetworkGetLastError();
- if (err != EWOULDBLOCK) {
- /* Something went wrong... (ECONNRESET is connection reset by peer) */
- if (err != ECONNRESET) DEBUG(net, 0, "recv failed with error %s", NetworkGetErrorString(err));
+ NetworkError err = NetworkError::GetLast();
+ if (!err.WouldBlock()) {
+ /* Something went wrong... */
+ if (!err.IsConnectionReset()) DEBUG(net, 0, "recv failed with error %s", err.AsString());
this->CloseConnection();
return nullptr;
}
diff --git a/src/network/core/tcp_http.cpp b/src/network/core/tcp_http.cpp
index e0c269faf..4f29df191 100644
--- a/src/network/core/tcp_http.cpp
+++ b/src/network/core/tcp_http.cpp
@@ -225,10 +225,10 @@ int NetworkHTTPSocketHandler::Receive()
for (;;) {
ssize_t res = recv(this->sock, (char *)this->recv_buffer + this->recv_pos, lengthof(this->recv_buffer) - this->recv_pos, 0);
if (res == -1) {
- int err = NetworkGetLastError();
- if (err != EWOULDBLOCK) {
- /* Something went wrong... (ECONNRESET is connection reset by peer) */
- if (err != ECONNRESET) DEBUG(net, 0, "recv failed with error %s", NetworkGetErrorString(err));
+ NetworkError err = NetworkError::GetLast();
+ if (!err.WouldBlock()) {
+ /* Something went wrong... */
+ if (!err.IsConnectionReset()) DEBUG(net, 0, "recv failed with error %s", err.AsString());
return -1;
}
/* Connection would block, so stop for now */
diff --git a/src/network/core/tcp_listen.h b/src/network/core/tcp_listen.h
index 2ceea20aa..e23ecae70 100644
--- a/src/network/core/tcp_listen.h
+++ b/src/network/core/tcp_listen.h
@@ -64,7 +64,7 @@ public:
DEBUG(net, 1, "[%s] Banned ip tried to join (%s), refused", Tsocket::GetName(), entry.c_str());
if (p.TransferOut<int>(send, s, 0) < 0) {
- DEBUG(net, 0, "send failed with error %s", NetworkGetLastErrorString());
+ DEBUG(net, 0, "send failed with error %s", NetworkError::GetLast().AsString());
}
closesocket(s);
break;
@@ -81,7 +81,7 @@ public:
p.PrepareToSend();
if (p.TransferOut<int>(send, s, 0) < 0) {
- DEBUG(net, 0, "send failed with error %s", NetworkGetLastErrorString());
+ DEBUG(net, 0, "send failed with error %s", NetworkError::GetLast().AsString());
}
closesocket(s);
diff --git a/src/network/core/udp.cpp b/src/network/core/udp.cpp
index ffc86d825..e7b99a53e 100644
--- a/src/network/core/udp.cpp
+++ b/src/network/core/udp.cpp
@@ -95,7 +95,7 @@ void NetworkUDPSocketHandler::SendPacket(Packet *p, NetworkAddress *recv, bool a
/* Enable broadcast */
unsigned long val = 1;
if (setsockopt(s.second, SOL_SOCKET, SO_BROADCAST, (char *) &val, sizeof(val)) < 0) {
- DEBUG(net, 1, "[udp] setting broadcast failed with: %s", NetworkGetLastErrorString());
+ DEBUG(net, 1, "[udp] setting broadcast failed with: %s", NetworkError::GetLast().AsString());
}
}
@@ -104,7 +104,7 @@ void NetworkUDPSocketHandler::SendPacket(Packet *p, NetworkAddress *recv, bool a
DEBUG(net, 7, "[udp] sendto(%s)", send.GetAddressAsString().c_str());
/* Check for any errors, but ignore it otherwise */
- if (res == -1) DEBUG(net, 1, "[udp] sendto(%s) failed with: %s", send.GetAddressAsString().c_str(), NetworkGetLastErrorString());
+ if (res == -1) DEBUG(net, 1, "[udp] sendto(%s) failed with: %s", send.GetAddressAsString().c_str(), NetworkError::GetLast().AsString());
if (!all) break;
}