diff options
-rw-r--r-- | src/network/core/address.cpp | 56 | ||||
-rw-r--r-- | src/network/core/address.h | 8 | ||||
-rw-r--r-- | src/network/network.cpp | 29 |
3 files changed, 68 insertions, 25 deletions
diff --git a/src/network/core/address.cpp b/src/network/core/address.cpp index 316a8b98a..a8ba5c076 100644 --- a/src/network/core/address.cpp +++ b/src/network/core/address.cpp @@ -89,6 +89,60 @@ const sockaddr_storage *NetworkAddress::GetAddress() return &this->address; } +bool NetworkAddress::IsInNetmask(char *netmask) +{ + /* Resolve it if we didn't do it already */ + if (!this->IsResolved()) this->GetAddress(); + + int cidr = this->address.ss_family == AF_INET ? 32 : 128; + + NetworkAddress mask_address; + + /* Check for CIDR separator */ + char *chr_cidr = strchr(netmask, '/'); + if (chr_cidr != NULL) { + int tmp_cidr = atoi(chr_cidr + 1); + + /* Invalid CIDR, treat as single host */ + if (tmp_cidr > 0 || tmp_cidr < cidr) cidr = tmp_cidr; + + /* Remove and then replace the / so that NetworkAddress works on the IP portion */ + *chr_cidr = '\0'; + mask_address = NetworkAddress(netmask, 0, this->address.ss_family); + *chr_cidr = '/'; + } else { + mask_address = NetworkAddress(netmask, 0, this->address.ss_family); + } + + if (mask_address.GetAddressLength() == 0) return false; + + uint32 *ip; + uint32 *mask; + switch (this->address.ss_family) { + case AF_INET: + ip = &((struct sockaddr_in*)&this->address)->sin_addr.s_addr; + mask = &((struct sockaddr_in*)&mask_address.address)->sin_addr.s_addr; + break; + + case AF_INET6: + ip = ((struct sockaddr_in6*)&this->address)->sin6_addr.s6_addr32; + mask = ((struct sockaddr_in6*)&mask_address.address)->sin6_addr.s6_addr32; + break; + + default: + NOT_REACHED(); + } + + while (cidr > 0) { + uint32 msk = cidr >= 32 ? -1 : htonl(-(1 << (32 - cidr))); + if ((*mask & msk) != (*ip & msk)) return false; + + cidr -= 32; + } + + return true; +} + SOCKET NetworkAddress::Resolve(int family, int socktype, int flags, LoopProc func) { struct addrinfo *ai; @@ -104,7 +158,7 @@ SOCKET NetworkAddress::Resolve(int family, int socktype, int flags, LoopProc fun int e = getaddrinfo(this->GetHostname(), port_name, &hints, &ai); if (e != 0) { - DEBUG(net, 0, "getaddrinfo failed: %s", FS2OTTD(gai_strerror(e))); + DEBUG(net, 0, "getaddrinfo(%s, %s) failed: %s", this->GetHostname(), port_name, FS2OTTD(gai_strerror(e))); return INVALID_SOCKET; } diff --git a/src/network/core/address.h b/src/network/core/address.h index 2b5f7444f..f972a0397 100644 --- a/src/network/core/address.h +++ b/src/network/core/address.h @@ -159,6 +159,14 @@ public: } /** + * Checks whether this IP address is contained by the given netmask. + * @param netmask the netmask in CIDR notation to test against. + * @note netmask without /n assumes all bits need to match. + * @return true if this IP is within the netmask. + */ + bool IsInNetmask(char *netmask); + + /** * Compare the address of this class with the address of another. * @param address the other address. */ diff --git a/src/network/network.cpp b/src/network/network.cpp index 09e4e7a40..9dee049f5 100644 --- a/src/network/network.cpp +++ b/src/network/network.cpp @@ -470,7 +470,6 @@ void NetworkCloseClient(NetworkClientSocket *cs) /* For the server, to accept new clients */ static void NetworkAcceptClients() { - struct sockaddr_in sin; NetworkClientSocket *cs; bool banned; @@ -478,6 +477,7 @@ static void NetworkAcceptClients() assert(_listensocket != INVALID_SOCKET); for (;;) { + struct sockaddr_storage sin; memset(&sin, 0, sizeof(sin)); socklen_t sin_len = sizeof(sin); SOCKET s = accept(_listensocket, (struct sockaddr*)&sin, &sin_len); @@ -485,34 +485,15 @@ static void NetworkAcceptClients() SetNonBlocking(s); // XXX error handling? - DEBUG(net, 1, "Client connected from %s on frame %d", inet_ntoa(sin.sin_addr), _frame_counter); + NetworkAddress address(sin, sin_len); + DEBUG(net, 1, "Client connected from %s on frame %d", address.GetHostname(), _frame_counter); SetNoDelay(s); // XXX error handling? /* Check if the client is banned */ banned = false; for (char **iter = _network_ban_list.Begin(); iter != _network_ban_list.End(); iter++) { - /* Check for CIDR separator */ - char *chr_cidr = strchr(*iter, '/'); - if (chr_cidr != NULL) { - int cidr = atoi(chr_cidr + 1); - - /* Invalid CIDR, treat as single host */ - if (cidr <= 0 || cidr > 32) cidr = 32; - - /* Remove and then replace the / so that inet_addr() works on the IP portion */ - *chr_cidr = '\0'; - uint32 ban_ip = inet_addr(*iter); - *chr_cidr = '/'; - - /* Convert CIDR to mask in network format */ - uint32 mask = htonl(-(1 << (32 - cidr))); - if ((sin.sin_addr.s_addr & mask) == (ban_ip & mask)) banned = true; - } else { - /* No CIDR used, so just perform a simple IP test */ - if (sin.sin_addr.s_addr == inet_addr(*iter)) banned = true; - } - + banned = address.IsInNetmask(*iter); if (banned) { Packet p(PACKET_SERVER_BANNED); p.PrepareToSend(); @@ -545,7 +526,7 @@ static void NetworkAcceptClients() * the client stays inactive */ cs->status = STATUS_INACTIVE; - cs->GetInfo()->client_ip = sin.sin_addr.s_addr; // Save the IP of the client + cs->GetInfo()->client_ip = ((sockaddr_in*)&sin)->sin_addr.s_addr; // Save the IP of the client } } |