// Source: dramaling-vocab-learning/backend/DramaLing.Api/Middleware/SecurityMiddleware.cs
// (original file: 319 lines, 10 KiB, C#)

using System.Text.RegularExpressions;
using System.Collections.Concurrent;
using System.Text.Json;
namespace DramaLing.Api.Middleware;
/// <summary>
/// Security middleware providing per-client rate limiting, a Content-Length size limit,
/// a pattern-based scan of POST/PUT request bodies, and a set of defensive response headers.
/// </summary>
/// <remarks>
/// Rejected requests receive a JSON error body (429 for rate limiting, 400 for other
/// violations) and the pipeline is short-circuited. Rate-limit state is held in a static,
/// in-process dictionary, so limits are per process rather than shared across servers.
/// </remarks>
public class SecurityMiddleware
{
    private readonly RequestDelegate _next;
    private readonly ILogger<SecurityMiddleware> _logger;
    private readonly SecurityOptions _options;

    // Simple in-memory rate limiter shared by every instance of this middleware in the process.
    // Each ClientRateLimit is only touched under lock(entry) because List<T> is not thread-safe
    // and concurrent requests from the same client share one entry.
    private static readonly ConcurrentDictionary<string, ClientRateLimit> _rateLimits = new();

    // UTC ticks of the last sweep that evicted idle clients from _rateLimits.
    private static long _lastSweepTicks = DateTime.UtcNow.Ticks;

    // Serializer settings for the JSON error responses, cached instead of rebuilt per response.
    private static readonly JsonSerializerOptions ErrorJsonOptions = new()
    {
        PropertyNamingPolicy = JsonNamingPolicy.CamelCase
    };

    // Malicious-pattern detection applied to POST/PUT request bodies.
    // NOTE(review): the SQL-keyword pattern also matches ordinary English words ("select",
    // "delete", "drop", "union"), and "data:" matches text such as "metadata:". For a vocabulary
    // app this may reject legitimate user input; confirm these patterns are intended.
    private static readonly Regex[] SuspiciousPatterns = new[]
    {
        new Regex(@"<script\b[^<]*(?:(?!<\/script>)<[^<]*)*<\/script>", RegexOptions.IgnoreCase | RegexOptions.Compiled),
        new Regex(@"(\bUNION\b|\bSELECT\b|\bINSERT\b|\bDELETE\b|\bDROP\b)", RegexOptions.IgnoreCase | RegexOptions.Compiled),
        new Regex(@"(javascript:|data:|vbscript:)", RegexOptions.IgnoreCase | RegexOptions.Compiled),
        new Regex(@"(\.\./|\.\.\\)", RegexOptions.Compiled), // path traversal
        new Regex(@"(eval\(|exec\(|system\()", RegexOptions.IgnoreCase | RegexOptions.Compiled)
    };

    /// <summary>
    /// Creates the middleware.
    /// </summary>
    /// <param name="next">The next delegate in the pipeline.</param>
    /// <param name="logger">Logger for security events.</param>
    /// <param name="options">Optional settings; a default <see cref="SecurityOptions"/> is used when null.</param>
    /// <exception cref="ArgumentNullException"><paramref name="next"/> or <paramref name="logger"/> is null.</exception>
    public SecurityMiddleware(RequestDelegate next, ILogger<SecurityMiddleware> logger, SecurityOptions? options = null)
    {
        _next = next ?? throw new ArgumentNullException(nameof(next));
        _logger = logger ?? throw new ArgumentNullException(nameof(logger));
        _options = options ?? new SecurityOptions();
    }

    /// <summary>
    /// Runs the security checks for the request and, if all pass, invokes the next middleware.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    /// <remarks>
    /// Exceptions thrown by later middleware are logged and rethrown unchanged.
    /// </remarks>
    public async Task InvokeAsync(HttpContext context)
    {
        var clientId = GetClientIdentifier(context);
        var requestId = context.TraceIdentifier;

        // Open the scope first so warnings from the checks and the error log below carry it too.
        using var scope = _logger.BeginScope(new Dictionary<string, object>
        {
            ["RequestId"] = requestId,
            ["ClientId"] = clientId,
            ["Method"] = context.Request.Method,
            ["Path"] = context.Request.Path,
            ["UserAgent"] = context.Request.Headers.UserAgent.ToString()
        });

        try
        {
            // Added before any check so that rejection responses carry the headers as well.
            AddSecurityHeaders(context);

            // 1. Rate limiting.
            if (!await CheckRateLimitAsync(clientId, requestId))
            {
                await RespondWithRateLimitExceeded(context);
                return;
            }

            // 2. Request size: a cheap Content-Length check, done before the body is buffered and scanned.
            if (!ValidateRequestSize(context))
            {
                await RespondWithSecurityViolation(context, "請求大小超過限制");
                return;
            }

            // 3. Input safety scan of the body.
            if (!await ValidateInputSafetyAsync(context, requestId))
            {
                await RespondWithSecurityViolation(context, "惡意輸入檢測");
                return;
            }

            await _next(context);
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Security middleware error for request {RequestId}", requestId);
            throw; // let outer middleware handle the exception
        }
    }

    #region Rate limiting

    /// <summary>
    /// Sliding-window rate limit: allows the request if the client made fewer than
    /// <c>MaxRequestsPerWindow</c> requests within <c>RateLimitWindow</c>, and records it.
    /// </summary>
    /// <returns>True if the request is allowed. Also true on an internal error (fail open).</returns>
    private Task<bool> CheckRateLimitAsync(string clientId, string requestId)
    {
        try
        {
            var now = DateTime.UtcNow;
            var window = _options.RateLimitWindow;

            SweepIdleClients(now);

            var clientLimit = _rateLimits.GetOrAdd(clientId, _ => new ClientRateLimit());

            bool allowed;
            // Prune, check and record atomically so the shared list is never mutated
            // concurrently and parallel requests cannot overshoot the limit.
            lock (clientLimit)
            {
                clientLimit.Requests.RemoveAll(r => now - r > window);
                allowed = clientLimit.Requests.Count < _options.MaxRequestsPerWindow;
                if (allowed)
                {
                    clientLimit.Requests.Add(now);
                }
            }

            if (!allowed)
            {
                _logger.LogWarning("Rate limit exceeded for client {ClientId}, request {RequestId}",
                    clientId, requestId);
            }

            return Task.FromResult(allowed);
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Error checking rate limit for client {ClientId}", clientId);
            return Task.FromResult(true); // fail open on error to avoid a service outage
        }
    }

    /// <summary>
    /// Evicts clients with no requests left inside the window. Runs at most once per window
    /// so that the static dictionary does not grow without bound as new clients appear.
    /// </summary>
    private void SweepIdleClients(DateTime now)
    {
        var window = _options.RateLimitWindow;
        var last = Interlocked.Read(ref _lastSweepTicks);
        if (now.Ticks - last < window.Ticks)
        {
            return;
        }

        // Only the thread that wins this exchange performs the sweep.
        if (Interlocked.CompareExchange(ref _lastSweepTicks, now.Ticks, last) != last)
        {
            return;
        }

        foreach (var entry in _rateLimits)
        {
            bool idle;
            lock (entry.Value)
            {
                entry.Value.Requests.RemoveAll(r => now - r > window);
                idle = entry.Value.Requests.Count == 0;
            }

            if (idle)
            {
                // Removes only if the key still maps to this same instance. A request that fetched
                // this entry just before removal is recorded on the detached instance, so under
                // that race a request can go uncounted; acceptable for a best-effort limiter.
                _rateLimits.TryRemove(entry);
            }
        }
    }

    #endregion

    #region Input validation

    /// <summary>
    /// Scans POST/PUT bodies, rejecting input longer than <c>MaxInputLength</c> characters
    /// or matching any of the suspicious patterns.
    /// </summary>
    /// <returns>False if the body is rejected. True otherwise, including on an internal error (fail open).</returns>
    private async Task<bool> ValidateInputSafetyAsync(HttpContext context, string requestId)
    {
        try
        {
            // Only methods that typically carry input are inspected.
            // NOTE(review): PATCH bodies are not scanned; confirm that is intended.
            if (context.Request.Method != "POST" && context.Request.Method != "PUT")
            {
                return true;
            }

            var body = await ReadRequestBodyAsync(context);
            if (string.IsNullOrEmpty(body))
            {
                return true;
            }

            // Reject over-long input before spending time running regexes over it.
            if (body.Length > _options.MaxInputLength)
            {
                _logger.LogWarning("Input too long in request {RequestId}: {Length} characters",
                    requestId, body.Length);
                return false;
            }

            foreach (var pattern in SuspiciousPatterns)
            {
                if (pattern.IsMatch(body))
                {
                    _logger.LogWarning("Suspicious pattern detected in request {RequestId}: {Pattern}",
                        requestId, pattern.ToString());
                    return false;
                }
            }

            return true;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Error validating input safety for request {RequestId}", requestId);
            return true; // fail open on error
        }
    }

    /// <summary>
    /// Reads the whole request body as text. Buffering is enabled and the stream is rewound,
    /// so downstream handlers can still read the body.
    /// </summary>
    /// <returns>The body text, or an empty string if it could not be read.</returns>
    private async Task<string> ReadRequestBodyAsync(HttpContext context)
    {
        var request = context.Request;
        try
        {
            request.EnableBuffering();
            using var reader = new StreamReader(request.Body, leaveOpen: true);
            return await reader.ReadToEndAsync();
        }
        catch (Exception ex)
        {
            // Best effort: an unreadable body is treated as empty, but the failure is now visible.
            _logger.LogWarning(ex, "Failed to read request body for inspection on {Path}", request.Path);
            return string.Empty;
        }
        finally
        {
            // Rewind even after a partial or failed read so downstream sees the full body.
            if (request.Body.CanSeek)
            {
                request.Body.Position = 0;
            }
        }
    }

    #endregion

    #region Request size

    /// <summary>
    /// Rejects requests whose declared Content-Length exceeds <c>MaxRequestSize</c>.
    /// </summary>
    /// <remarks>Only the Content-Length header is inspected; bodies sent without it are not checked here.</remarks>
    /// <returns>False if the declared size is over the limit.</returns>
    private bool ValidateRequestSize(HttpContext context)
    {
        var contentLength = context.Request.ContentLength;
        if (contentLength.HasValue && contentLength.Value > _options.MaxRequestSize)
        {
            _logger.LogWarning("Request size {Size} exceeds limit {Limit} for {Path}",
                contentLength.Value, _options.MaxRequestSize, context.Request.Path);
            return false;
        }
        return true;
    }

    #endregion

    #region Security headers

    /// <summary>
    /// Adds defensive response headers, leaving any value already set untouched.
    /// </summary>
    private void AddSecurityHeaders(HttpContext context)
    {
        var response = context.Response;
        if (!response.Headers.ContainsKey("X-Content-Type-Options"))
            response.Headers.Append("X-Content-Type-Options", "nosniff");
        if (!response.Headers.ContainsKey("X-Frame-Options"))
            response.Headers.Append("X-Frame-Options", "DENY");
        // NOTE(review): X-XSS-Protection is deprecated; OWASP recommends "0" or omitting it.
        if (!response.Headers.ContainsKey("X-XSS-Protection"))
            response.Headers.Append("X-XSS-Protection", "1; mode=block");
        if (!response.Headers.ContainsKey("Referrer-Policy"))
            response.Headers.Append("Referrer-Policy", "strict-origin-when-cross-origin");
    }

    #endregion

    #region Helpers

    /// <summary>
    /// Builds the rate-limit key from the remote IP address and a hash of the User-Agent.
    /// </summary>
    /// <remarks>
    /// NOTE(review): because the User-Agent is part of the key, a client can bypass the limit
    /// by rotating that header. Behind a reverse proxy, RemoteIpAddress is presumably the proxy
    /// unless forwarded headers are processed upstream; confirm the hosting setup.
    /// </remarks>
    private string GetClientIdentifier(HttpContext context)
    {
        var ipAddress = context.Connection.RemoteIpAddress?.ToString() ?? "unknown";
        var userAgent = context.Request.Headers.UserAgent.ToString();
        return $"{ipAddress}_{userAgent.GetHashCode()}";
    }

    /// <summary>
    /// Writes a 429 JSON response plus a Retry-After header (whole seconds of the window).
    /// </summary>
    private async Task RespondWithRateLimitExceeded(HttpContext context)
    {
        context.Response.StatusCode = 429;
        context.Response.ContentType = "application/json";

        var retryAfterSeconds = Math.Max(0L, (long)Math.Ceiling(_options.RateLimitWindow.TotalSeconds));
        context.Response.Headers["Retry-After"] =
            retryAfterSeconds.ToString(System.Globalization.CultureInfo.InvariantCulture);

        var response = new
        {
            Success = false,
            Error = new
            {
                Code = "RATE_LIMIT_EXCEEDED",
                Message = "請求過於頻繁,請稍後再試",
                RetryAfter = _options.RateLimitWindow.TotalSeconds
            },
            Timestamp = DateTime.UtcNow
        };
        var json = JsonSerializer.Serialize(response, ErrorJsonOptions);
        await context.Response.WriteAsync(json);
    }

    /// <summary>
    /// Writes a 400 JSON response describing a failed security check.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    /// <param name="reason">Reason text included in the response body.</param>
    private async Task RespondWithSecurityViolation(HttpContext context, string reason)
    {
        context.Response.StatusCode = 400;
        context.Response.ContentType = "application/json";
        var response = new
        {
            Success = false,
            Error = new
            {
                Code = "SECURITY_VIOLATION",
                Message = "安全檢查失敗",
                Reason = reason
            },
            Timestamp = DateTime.UtcNow
        };
        var json = JsonSerializer.Serialize(response, ErrorJsonOptions);
        await context.Response.WriteAsync(json);
    }

    #endregion
}
/// <summary>
/// Configuration options for the <c>SecurityMiddleware</c>.
/// </summary>
public class SecurityOptions
{
    /// <summary>Length of the sliding rate-limit window. Defaults to 60 seconds.</summary>
    public TimeSpan RateLimitWindow { get; set; } = TimeSpan.FromSeconds(60);

    /// <summary>Maximum number of requests a client may make within <see cref="RateLimitWindow"/>. Defaults to 60.</summary>
    public int MaxRequestsPerWindow { get; set; } = 60;

    /// <summary>Maximum accepted request-body length, in characters. Defaults to 10,000.</summary>
    public int MaxInputLength { get; set; } = 10_000;

    /// <summary>Maximum accepted request size in bytes. Defaults to 1 MiB (1,048,576 bytes).</summary>
    public long MaxRequestSize { get; set; } = 1L << 20;
}
/// <summary>
/// Per-client rate-limit state: the timestamps of the client's recorded requests.
/// </summary>
/// <remarks>
/// <see cref="List{T}"/> is not thread-safe; code that shares an instance across
/// concurrent requests must synchronize access itself.
/// </remarks>
public class ClientRateLimit
{
    /// <summary>Timestamps of recorded requests. A fresh, empty list per instance.</summary>
    public List<DateTime> Requests { get; set; } = new List<DateTime>();
}