using System.Net;
using System.Text.RegularExpressions;
using bitforum.Services;

namespace bitforum.Middleware
{
    /// <summary>
    /// Middleware that answers 403 Forbidden to any request whose remote IP address is not
    /// matched by the allow-list stored in the <c>admin_white_ip_list</c> setup config value.
    /// </summary>
    /// <remarks>
    /// The allow-list holds one entry per line (CR and/or LF separated, surrounding whitespace
    /// ignored). Each entry may be:
    /// <list type="bullet">
    /// <item>an exact address, e.g. <c>203.0.113.7</c> or <c>2001:db8::1</c>;</item>
    /// <item>a wildcard pattern where <c>*</c> matches any run of characters, e.g. <c>192.168.*</c>;</item>
    /// <item>an inclusive range <c>start-end</c>, e.g. <c>10.0.0.250-10.0.1.5</c>.</item>
    /// </list>
    /// When the config value is missing or contains no non-blank lines, every request is allowed.
    /// A request with no known remote address is always rejected.
    /// </remarks>
    public class IPFilter
    {
        private const string AllowListConfigKey = "admin_white_ip_list";

        private static readonly char[] LineSeparators = new[] { '\r', '\n' };

        private readonly RequestDelegate _next;
        private readonly ISetupService _setupService;

        /// <param name="next">The next middleware in the pipeline.</param>
        /// <param name="setupService">Source of the allow-list configuration value.</param>
        public IPFilter(RequestDelegate next, ISetupService setupService)
        {
            _next = next;
            _setupService = setupService;
        }

        /// <summary>
        /// Forwards the request to the next middleware when the caller's IP is allowed;
        /// otherwise writes a 403 response and short-circuits the pipeline.
        /// </summary>
        /// <param name="context">The current HTTP context.</param>
        public async Task Invoke(HttpContext context)
        {
            var remoteIP = context.Connection.RemoteIpAddress;

            // Return 403 Forbidden when the IP is not allowed.
            if (!IsIpAllowed(remoteIP))
            {
                context.Response.StatusCode = StatusCodes.Status403Forbidden;
                await context.Response.WriteAsync("Access Denied: Your IP is not allowed to access this resource.");
                return;
            }

            await _next(context);
        }

        /// <summary>
        /// Decides whether <paramref name="remoteIP"/> is permitted by the configured allow-list.
        /// </summary>
        /// <returns>
        /// <c>false</c> for a null address; <c>true</c> when the allow-list is empty;
        /// otherwise whether any entry matches.
        /// </returns>
        private bool IsIpAllowed(IPAddress? remoteIP)
        {
            if (remoteIP is null)
            {
                return false;
            }

            var allowedEntries = ReadAllowList();

            // With no configured entries, every request is allowed by default.
            if (allowedEntries.Length == 0)
            {
                return true;
            }

            // On a dual-mode (IPv6) socket, IPv4 clients may be reported as IPv4-mapped IPv6
            // addresses (::ffff:a.b.c.d). Test the plain IPv4 form so IPv4 entries match, and keep
            // the raw form so entries written in mapped notation keep working.
            IPAddress[] candidates = remoteIP.IsIPv4MappedToIPv6
                ? new[] { remoteIP.MapToIPv4(), remoteIP }
                : new[] { remoteIP };

            return allowedEntries.Any(entry => candidates.Any(candidate => MatchesEntry(candidate, entry)));
        }

        /// <summary>
        /// Reads the allow-list config value and splits it into trimmed, non-blank entries.
        /// </summary>
        private string[] ReadAllowList()
        {
            var raw = _setupService.GetConfig(AllowListConfigKey);
            if (string.IsNullOrWhiteSpace(raw))
            {
                return Array.Empty<string>();
            }

            // TrimEntries + RemoveEmptyEntries also drops lines that contain only whitespace.
            return raw.Split(LineSeparators, StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries);
        }

        /// <summary>
        /// Tests one allow-list entry (wildcard, range, or exact address) against an address.
        /// </summary>
        private static bool MatchesEntry(IPAddress candidate, string entry)
        {
            if (entry.Contains('*'))
            {
                // Escape everything, then turn the escaped "*" back into "match anything".
                var pattern = "^" + Regex.Escape(entry).Replace("\\*", ".*") + "$";
                return Regex.IsMatch(candidate.ToString(), pattern);
            }

            if (entry.Contains('-'))
            {
                var parts = entry.Split('-');
                return parts.Length == 2
                    && IPAddress.TryParse(parts[0].Trim(), out var startIp)
                    && IPAddress.TryParse(parts[1].Trim(), out var endIp)
                    && IsInRange(candidate, startIp, endIp);
            }

            return IPAddress.TryParse(entry, out var allowedAddress) && candidate.Equals(allowedAddress);
        }

        /// <summary>
        /// Returns whether <paramref name="address"/> lies in the inclusive range
        /// [<paramref name="startIP"/>, <paramref name="endIP"/>].
        /// </summary>
        /// <returns><c>false</c> when the three addresses are not all the same byte length (family).</returns>
        private static bool IsInRange(IPAddress address, IPAddress startIP, IPAddress endIP)
        {
            var ipBytes = address.GetAddressBytes();
            var startBytes = startIP.GetAddressBytes();
            var endBytes = endIP.GetAddressBytes();

            // IPv4 (4 bytes) and IPv6 (16 bytes) addresses are never in the same range.
            if (ipBytes.Length != startBytes.Length || ipBytes.Length != endBytes.Length)
            {
                return false;
            }

            // GetAddressBytes returns network (big-endian) order, so comparing whole byte
            // sequences lexicographically orders addresses numerically. Checking each byte against
            // the bounds independently would wrongly reject ranges that cross an octet boundary,
            // e.g. 10.0.1.3 in 10.0.0.250-10.0.1.5.
            ReadOnlySpan<byte> ip = ipBytes;
            return ip.SequenceCompareTo(startBytes) >= 0 && ip.SequenceCompareTo(endBytes) <= 0;
        }
    }
}