diff --git a/AssetManager.API/Middleware/RateLimitMiddleware.cs b/AssetManager.API/Middleware/RateLimitMiddleware.cs new file mode 100644 index 0000000..e33e7f2 --- /dev/null +++ b/AssetManager.API/Middleware/RateLimitMiddleware.cs @@ -0,0 +1,68 @@ +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.Logging; +using System; +using System.Threading.Tasks; + +namespace AssetManager.API.Middleware; + +/// +/// 请求限流中间件:限制每个用户的请求频率,保护第三方API配额 +/// +public class RateLimitMiddleware +{ + private readonly RequestDelegate _next; + private readonly IMemoryCache _cache; + private readonly ILogger _logger; + + // 限流配置:每个用户每分钟最多请求60次(平均1秒1次,足够正常使用) + private const int Limit = 60; + private const int WindowSeconds = 60; + private const string CacheKeyPrefix = "RateLimit_"; + + public RateLimitMiddleware(RequestDelegate next, IMemoryCache cache, ILogger logger) + { + _next = next; + _cache = cache; + _logger = logger; + } + + public async Task InvokeAsync(HttpContext context) + { + // 获取用户ID(未登录则用IP地址) + var userId = context.User.FindFirst(System.Security.Claims.ClaimTypes.NameIdentifier)?.Value + ?? context.Connection.RemoteIpAddress?.ToString() + ?? "anonymous"; + + var cacheKey = $"{CacheKeyPrefix}{userId}"; + + // 获取当前窗口的请求计数 + var requestCount = _cache.Get(cacheKey); + + if (requestCount >= Limit) + { + _logger.LogWarning("用户 {UserId} 请求超过限流限制,当前计数: {Count}", userId, requestCount); + context.Response.StatusCode = StatusCodes.Status429TooManyRequests; + context.Response.ContentType = "application/json"; + await context.Response.WriteAsync(System.Text.Json.JsonSerializer.Serialize(new + { + code = 429, + message = "请求过于频繁,请稍后再试", + data = (object)null + })); + return; + } + + // 计数+1,首次访问设置过期时间 + if (requestCount == 0) + { + _cache.Set(cacheKey, 1, TimeSpan.FromSeconds(WindowSeconds)); + } + else + { + _cache.Set(cacheKey, requestCount + 1, _cache.GetEntry(cacheKey).AbsoluteExpirationRelativeToNow ?? TimeSpan.FromSeconds(WindowSeconds)); + } + + await _next(context); + } +} diff --git a/AssetManager.API/Program.cs b/AssetManager.API/Program.cs index b564026..946ac26 100644 --- a/AssetManager.API/Program.cs +++ b/AssetManager.API/Program.cs @@ -118,6 +118,9 @@ app.Services.InitializeDatabase(); // 全局异常处理中间件(必须放在最前面) app.UseMiddleware(); +// 请求限流中间件:限制用户请求频率,保护第三方API配额 +app.UseMiddleware(); + if (app.Environment.IsDevelopment()) { app.UseSwagger();