refactor: 安全性和代码质量改进

🔴 高优先级修复:

1. JWT 密钥安全 (Program.cs)
   - 移除硬编码默认密钥
   - 启动时强制检查环境变量/配置
   - 密钥长度必须 >= 32 字符

2. 数据库事务 (PortfolioService.cs)
   - CreateTransaction 添加事务保护
   - 交易创建、持仓更新、组合更新原子性保证
   - 异常时自动回滚

3. 异步方法改进 (PortfolioService.cs)
   - 移除 .GetAwaiter().GetResult() 阻塞调用
   - 统一使用 async/await 模式

🟡 中优先级:

4. 接口统一 (IPortfolioService.cs)
   - 移除同步方法,只保留异步版本
   - 简化接口,降低维护成本
This commit is contained in:
OpenClaw Agent 2026-03-25 06:35:42 +00:00
parent 42d3fc91c4
commit 02e199faf2
3 changed files with 125 additions and 137 deletions

View File

@ -56,10 +56,20 @@ builder.Services.AddCors(options =>
builder.Services.AddAuthentication(JwtBearerDefaults.AuthenticationScheme)
.AddJwtBearer(options =>
{
// 优先从环境变量读取JWT配置
// 强制从环境变量或配置文件读取JWT配置不允许硬编码默认值
var jwtSecretKey = Environment.GetEnvironmentVariable("Jwt__SecretKey")
?? builder.Configuration["Jwt:SecretKey"]
?? "your-strong-secret-key-here-2026";
?? builder.Configuration["Jwt:SecretKey"];
if (string.IsNullOrEmpty(jwtSecretKey))
{
throw new InvalidOperationException("JWT SecretKey is required. Please set Jwt__SecretKey environment variable or Jwt:SecretKey in configuration.");
}
if (jwtSecretKey.Length < 32)
{
throw new InvalidOperationException("JWT SecretKey must be at least 32 characters long for security.");
}
var jwtIssuer = Environment.GetEnvironmentVariable("Jwt__Issuer")
?? builder.Configuration["Jwt:Issuer"]
?? "AssetManager";

View File

@ -4,17 +4,12 @@ namespace AssetManager.Services;
public interface IPortfolioService
{
CreatePortfolioResponse CreatePortfolio(CreatePortfolioRequest request, string userId);
Task<CreatePortfolioResponse> CreatePortfolioAsync(CreatePortfolioRequest request, string userId);
Task<bool> UpdatePortfolioAsync(string portfolioId, UpdatePortfolioRequest request, string userId);
List<PortfolioListItem> GetPortfolios(string userId);
Task<List<PortfolioListItem>> GetPortfolioListAsync(string userId);
TotalAssetsResponse GetTotalAssets(string userId);
Task<TotalAssetsResponse> GetTotalAssetsAsync(string userId);
PortfolioDetailResponse GetPortfolioById(string id, string userId);
Task<PortfolioDetailResponse> GetPortfolioDetailAsync(string portfolioId, string userId);
GetTransactionsResponse GetTransactions(string portfolioId, string userId, int limit, int offset);
Task<List<TransactionItem>> GetTransactionsAsync(string portfolioId, GetTransactionsRequest request, string userId);
Task<GetTransactionsResponse> GetTransactionsAsync(string portfolioId, string userId, int limit, int offset);
Task<TransactionItem> CreateTransactionAsync(string portfolioId, CreateTransactionRequest request, string userId);
Task<bool> DeletePortfolioAsync(string portfolioId, string userId);
}

View File

@ -326,11 +326,6 @@ public class PortfolioService : IPortfolioService
return result;
}
public List<PortfolioListItem> GetPortfolios(string userId)
{
return GetPortfolioListAsync(userId).GetAwaiter().GetResult();
}
public async Task<TotalAssetsResponse> GetTotalAssetsAsync(string userId)
{
// 获取用户信息
@ -464,12 +459,6 @@ public class PortfolioService : IPortfolioService
};
}
// 保留同步方法作为兼容
public TotalAssetsResponse GetTotalAssets(string userId)
{
return GetTotalAssetsAsync(userId).GetAwaiter().GetResult();
}
public async Task<PortfolioDetailResponse> GetPortfolioByIdAsync(string id, string userId)
{
var portfolio = _db.Queryable<Portfolio>()
@ -592,13 +581,7 @@ public class PortfolioService : IPortfolioService
};
}
// 保留同步方法作为兼容(内部调用异步)
public PortfolioDetailResponse GetPortfolioById(string id, string userId)
{
return GetPortfolioByIdAsync(id, userId).GetAwaiter().GetResult();
}
public GetTransactionsResponse GetTransactions(string portfolioId, string userId, int limit, int offset)
public async Task<GetTransactionsResponse> GetTransactionsAsync(string portfolioId, string userId, int limit, int offset)
{
// 验证投资组合是否属于该用户
var portfolio = _db.Queryable<Portfolio>()
@ -737,112 +720,128 @@ public class PortfolioService : IPortfolioService
CreatedAt = DateTime.Now
};
_db.Insertable(transaction).ExecuteCommand();
// 更新持仓
var position = _db.Queryable<Position>()
.Where(pos => pos.PortfolioId == request.PortfolioId && pos.StockCode == request.StockCode)
.First();
if (position != null)
// 使用事务包裹所有数据库操作
try
{
if (request.Type == "buy")
_db.BeginTran();
// 1. 插入交易记录
_db.Insertable(transaction).ExecuteCommand();
// 2. 更新持仓
var position = _db.Queryable<Position>()
.Where(pos => pos.PortfolioId == request.PortfolioId && pos.StockCode == request.StockCode)
.First();
if (position != null)
{
// 计算新的平均价格和总成本
var buyAmount = (decimal)request.Amount * (decimal)request.Price;
var newTotalShares = position.Shares + (decimal)request.Amount;
var newTotalCost = position.TotalCost + buyAmount;
position.AvgPrice = newTotalCost / newTotalShares;
position.TotalCost = newTotalCost;
position.Shares = newTotalShares;
position.UpdatedAt = DateTime.Now;
_logger.LogInformation("买入更新持仓: {StockCode}, +{Amount}股@{Price}, 新成本={TotalCost}, 新均价={AvgPrice}",
position.StockCode, request.Amount, request.Price, position.TotalCost, position.AvgPrice);
_db.Updateable(position).ExecuteCommand();
}
else if (request.Type == "sell")
{
// 按比例减少成本
var sellRatio = (decimal)request.Amount / position.Shares;
var costToReduce = position.TotalCost * sellRatio;
position.Shares -= (decimal)request.Amount;
position.TotalCost -= costToReduce;
position.UpdatedAt = DateTime.Now;
_logger.LogInformation("卖出更新持仓: {StockCode}, -{Amount}股@{Price}, 减少成本={CostToReduce}, 剩余成本={TotalCost}",
position.StockCode, request.Amount, request.Price, costToReduce, position.TotalCost);
if (position.Shares <= 0)
{
_db.Deleteable(position).ExecuteCommand();
}
else
if (request.Type == "buy")
{
// 计算新的平均价格和总成本
var buyAmount = (decimal)request.Amount * (decimal)request.Price;
var newTotalShares = position.Shares + (decimal)request.Amount;
var newTotalCost = position.TotalCost + buyAmount;
position.AvgPrice = newTotalCost / newTotalShares;
position.TotalCost = newTotalCost;
position.Shares = newTotalShares;
position.UpdatedAt = DateTime.Now;
_logger.LogInformation("买入更新持仓: {StockCode}, +{Amount}股@{Price}, 新成本={TotalCost}, 新均价={AvgPrice}",
position.StockCode, request.Amount, request.Price, position.TotalCost, position.AvgPrice);
_db.Updateable(position).ExecuteCommand();
}
}
}
else if (request.Type == "buy")
{
// 创建新持仓
position = new Position
{
Id = "pos-" + Guid.NewGuid().ToString().Substring(0, 8),
PortfolioId = request.PortfolioId,
StockCode = request.StockCode,
StockName = request.Remark ?? request.StockCode,
AssetType = string.IsNullOrEmpty(request.AssetType) ? "Stock" : request.AssetType,
Shares = (decimal)request.Amount,
AvgPrice = (decimal)request.Price,
TotalCost = (decimal)(request.Price * request.Amount),
Currency = request.Currency,
CreatedAt = DateTime.Now,
UpdatedAt = DateTime.Now
};
_logger.LogInformation("创建新持仓: {StockCode}, 数量={Shares}, 均价={AvgPrice}, 成本={TotalCost}",
position.StockCode, position.Shares, position.AvgPrice, position.TotalCost);
_db.Insertable(position).ExecuteCommand();
}
// 更新投资组合总价值(使用实时市值而不是成本价)
var Positions = _db.Queryable<Position>()
.Where(pos => pos.PortfolioId == request.PortfolioId)
.ToList();
decimal totalPortfolioValue = 0;
foreach (var pos in Positions)
{
if (pos.StockCode == null)
{
continue;
}
// 获取实时价格(自动路由数据源),失败则降级使用成本价
decimal CurrentPrice = pos.AvgPrice;
try
{
var priceResponse = _marketDataService.GetPriceAsync(pos.StockCode, pos.AssetType ?? "Stock").GetAwaiter().GetResult();
if (priceResponse.Price > 0)
else if (request.Type == "sell")
{
CurrentPrice = priceResponse.Price;
// 按比例减少成本
var sellRatio = (decimal)request.Amount / position.Shares;
var costToReduce = position.TotalCost * sellRatio;
position.Shares -= (decimal)request.Amount;
position.TotalCost -= costToReduce;
position.UpdatedAt = DateTime.Now;
_logger.LogInformation("卖出更新持仓: {StockCode}, -{Amount}股@{Price}, 减少成本={CostToReduce}, 剩余成本={TotalCost}",
position.StockCode, request.Amount, request.Price, costToReduce, position.TotalCost);
if (position.Shares <= 0)
{
_db.Deleteable(position).ExecuteCommand();
}
else
{
_db.Updateable(position).ExecuteCommand();
}
}
}
catch (Exception ex)
else if (request.Type == "buy")
{
_logger.LogWarning(ex, "获取标的 {StockCode} 实时价格失败,使用成本价计算组合总价值", pos.StockCode);
// 创建新持仓
position = new Position
{
Id = "pos-" + Guid.NewGuid().ToString().Substring(0, 8),
PortfolioId = request.PortfolioId,
StockCode = request.StockCode,
StockName = request.Remark ?? request.StockCode,
AssetType = string.IsNullOrEmpty(request.AssetType) ? "Stock" : request.AssetType,
Shares = (decimal)request.Amount,
AvgPrice = (decimal)request.Price,
TotalCost = (decimal)(request.Price * request.Amount),
Currency = request.Currency,
CreatedAt = DateTime.Now,
UpdatedAt = DateTime.Now
};
_logger.LogInformation("创建新持仓: {StockCode}, 数量={Shares}, 均价={AvgPrice}, 成本={TotalCost}",
position.StockCode, position.Shares, position.AvgPrice, position.TotalCost);
_db.Insertable(position).ExecuteCommand();
}
totalPortfolioValue += pos.Shares * CurrentPrice;
}
// 3. 更新投资组合总价值
var Positions = _db.Queryable<Position>()
.Where(pos => pos.PortfolioId == request.PortfolioId)
.ToList();
portfolio.TotalValue = totalPortfolioValue;
portfolio.UpdatedAt = DateTime.Now;
_db.Updateable(portfolio).ExecuteCommand();
decimal totalPortfolioValue = 0;
foreach (var pos in Positions)
{
if (pos.StockCode == null)
{
continue;
}
// 获取实时价格(异步调用),失败则降级使用成本价
decimal CurrentPrice = pos.AvgPrice;
try
{
var priceResponse = await _marketDataService.GetPriceAsync(pos.StockCode, pos.AssetType ?? "Stock");
if (priceResponse.Price > 0)
{
CurrentPrice = priceResponse.Price;
}
}
catch (Exception ex)
{
_logger.LogWarning(ex, "获取标的 {StockCode} 实时价格失败,使用成本价计算组合总价值", pos.StockCode);
}
totalPortfolioValue += pos.Shares * CurrentPrice;
}
portfolio.TotalValue = totalPortfolioValue;
portfolio.UpdatedAt = DateTime.Now;
_db.Updateable(portfolio).ExecuteCommand();
// 提交事务
_db.CommitTran();
}
catch (Exception ex)
{
_db.RollbackTran();
_logger.LogError(ex, "创建交易失败,已回滚: {PortfolioId}, {StockCode}", request.PortfolioId, request.StockCode);
throw;
}
// 删除该交易日期之后的净值历史记录,下次请求收益曲线时会自动重新计算
try
@ -870,22 +869,6 @@ public class PortfolioService : IPortfolioService
// ===== 异步方法实现 =====
public Task<CreatePortfolioResponse> CreatePortfolioAsync(CreatePortfolioRequest request, string userId)
{
return Task.FromResult(CreatePortfolio(request, userId));
}
public Task<PortfolioDetailResponse> GetPortfolioDetailAsync(string portfolioId, string userId)
{
return GetPortfolioByIdAsync(portfolioId, userId);
}
public Task<List<TransactionItem>> GetTransactionsAsync(string portfolioId, GetTransactionsRequest request, string userId)
{
var response = GetTransactions(portfolioId, userId, request.Limit, request.Offset);
return Task.FromResult(response.Items ?? new List<TransactionItem>());
}
public async Task<TransactionItem> CreateTransactionAsync(string portfolioId, CreateTransactionRequest request, string userId)
{
request.PortfolioId = portfolioId;