summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/network/core/address.cpp31
-rw-r--r--src/network/core/address.h11
2 files changed, 28 insertions, 14 deletions
diff --git a/src/network/core/address.cpp b/src/network/core/address.cpp
index 0db0705d9..5d889c0ad 100644
--- a/src/network/core/address.cpp
+++ b/src/network/core/address.cpp
@@ -87,7 +87,7 @@ const sockaddr_storage *NetworkAddress::GetAddress()
* bothered to implement the specifications and allow '0' as value
* that means "don't care whether it is SOCK_STREAM or SOCK_DGRAM".
*/
- this->Resolve(this->address.ss_family, SOCK_STREAM, AI_ADDRCONFIG, ResolveLoopProc);
+ this->Resolve(this->address.ss_family, SOCK_STREAM, AI_ADDRCONFIG, NULL, ResolveLoopProc);
}
return &this->address;
}
@@ -146,7 +146,7 @@ bool NetworkAddress::IsInNetmask(char *netmask)
return true;
}
-SOCKET NetworkAddress::Resolve(int family, int socktype, int flags, LoopProc func)
+SOCKET NetworkAddress::Resolve(int family, int socktype, int flags, SocketList *sockets, LoopProc func)
{
struct addrinfo *ai;
struct addrinfo hints;
@@ -159,6 +159,9 @@ SOCKET NetworkAddress::Resolve(int family, int socktype, int flags, LoopProc fun
char port_name[6];
seprintf(port_name, lastof(port_name), "%u", this->GetPort());
+ /* Setting both hostname to NULL and port to 0 is not allowed.
+ * As port 0 means bind to any port, the other must mean that
+ * we want to bind to 'all' IPs. */
if (this->address_length == 0 && StrEmpty(this->hostname)) {
strecpy(this->hostname, this->address.ss_family == AF_INET ? "0.0.0.0" : "::", lastof(this->hostname));
}
@@ -174,10 +177,16 @@ SOCKET NetworkAddress::Resolve(int family, int socktype, int flags, LoopProc fun
sock = func(runp);
if (sock == INVALID_SOCKET) continue;
- this->address_length = runp->ai_addrlen;
- assert(sizeof(this->address) >= runp->ai_addrlen);
- memcpy(&this->address, runp->ai_addr, runp->ai_addrlen);
- break;
+ if (sockets == NULL) {
+ this->address_length = runp->ai_addrlen;
+ assert(sizeof(this->address) >= runp->ai_addrlen);
+ memcpy(&this->address, runp->ai_addr, runp->ai_addrlen);
+ break;
+ }
+
+ NetworkAddress addr(runp->ai_addr, runp->ai_addrlen);
+ (*sockets)[addr] = sock;
+ sock = INVALID_SOCKET;
}
freeaddrinfo (ai);
@@ -215,7 +224,7 @@ SOCKET NetworkAddress::Connect()
{
DEBUG(net, 1, "Connecting to %s", this->GetAddressAsString());
- return this->Resolve(0, SOCK_STREAM, AI_ADDRCONFIG, ConnectLoopProc);
+ return this->Resolve(0, SOCK_STREAM, AI_ADDRCONFIG, NULL, ConnectLoopProc);
}
/**
@@ -231,7 +240,9 @@ static SOCKET ListenLoopProc(addrinfo *runp)
return INVALID_SOCKET;
}
- if (!SetNoDelay(sock)) DEBUG(net, 1, "Setting TCP_NODELAY failed");
+ if (runp->ai_socktype == SOCK_STREAM && !SetNoDelay(sock)) {
+ DEBUG(net, 1, "Setting TCP_NODELAY failed");
+ }
int on = 1;
/* The (const char*) cast is needed for windows!! */
@@ -262,9 +273,9 @@ static SOCKET ListenLoopProc(addrinfo *runp)
return sock;
}
-SOCKET NetworkAddress::Listen(int family, int socktype)
+SOCKET NetworkAddress::Listen(int family, int socktype, SocketList *sockets)
{
- return this->Resolve(family, socktype, AI_ADDRCONFIG | AI_PASSIVE, ListenLoopProc);
+ return this->Resolve(family, socktype, AI_ADDRCONFIG | AI_PASSIVE, sockets, ListenLoopProc);
}
#endif /* ENABLE_NETWORK */
diff --git a/src/network/core/address.h b/src/network/core/address.h
index d19e7b376..4d263b3c0 100644
--- a/src/network/core/address.h
+++ b/src/network/core/address.h
@@ -10,10 +10,11 @@
#include "os_abstraction.h"
#include "config.h"
#include "../../string_func.h"
-#include "../../core/smallvec_type.hpp"
+#include "../../core/smallmap_type.hpp"
class NetworkAddress;
typedef SmallVector<NetworkAddress, 4> NetworkAddressList;
+typedef SmallMap<NetworkAddress, SOCKET, 4> SocketList;
/**
* Wrapper for (un)resolved network addresses; there's no reason to transform
@@ -38,10 +39,11 @@ private:
* @param family the type of 'protocol' (IPv4, IPv6)
* @param socktype the type of socket (TCP, UDP, etc)
* @param flags the flags to send to getaddrinfo
+ * @param sockets the list of sockets to add the sockets to
* @param func the inner working while looping over the address info
* @return the resolved socket or INVALID_SOCKET.
*/
- SOCKET Resolve(int family, int socktype, int flags, LoopProc func);
+ SOCKET Resolve(int family, int socktype, int flags, SocketList *sockets, LoopProc func);
public:
/**
* Create a network address based on a resolved IP and port
@@ -217,9 +219,10 @@ public:
* Make the given socket listen.
* @param family the type of 'protocol' (IPv4, IPv6)
* @param socktype the type of socket (TCP, UDP, etc)
- * @return the listening socket or INVALID_SOCKET.
+ * @param sockets the list of sockets to add the sockets to
+ * @return the socket (if sockets != NULL)
*/
- SOCKET Listen(int family, int socktype);
+ SOCKET Listen(int family, int socktype, SocketList *sockets = NULL);
};
#endif /* ENABLE_NETWORK */