dramaling-vocab-learning/backend/DramaLing.Api/Services/AI/AIProviderManager.cs

260 lines
9.4 KiB
C#

using DramaLing.Api.Models.DTOs;
namespace DramaLing.Api.Services.AI;
/// <summary>
/// AI provider manager implementation. Holds the registered <see cref="IAIProvider"/>
/// instances, filters them by availability and health, selects one according to a
/// <see cref="ProviderSelectionStrategy"/>, and falls back to the remaining providers
/// when a sentence analysis call fails.
/// </summary>
public class AIProviderManager : IAIProviderManager
{
    // Snapshot of the injected providers. Taken once in the constructor so the source
    // sequence is never re-enumerated (a lazy source could otherwise be re-evaluated on
    // every call and yield different instances or counts).
    private readonly IReadOnlyList<IAIProvider> _providers;
    private readonly ILogger<AIProviderManager> _logger;

    /// <summary>
    /// Creates the manager and logs the number and names of the registered providers.
    /// </summary>
    /// <param name="providers">The providers to manage; enumerated exactly once.</param>
    /// <param name="logger">Logger used for selection, health and fallback diagnostics.</param>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="providers"/> or <paramref name="logger"/> is <c>null</c>.
    /// </exception>
    public AIProviderManager(IEnumerable<IAIProvider> providers, ILogger<AIProviderManager> logger)
    {
        _providers = (providers ?? throw new ArgumentNullException(nameof(providers))).ToList();
        _logger = logger ?? throw new ArgumentNullException(nameof(logger));

        _logger.LogInformation("AIProviderManager initialized with {ProviderCount} providers: {ProviderNames}",
            _providers.Count, string.Join(", ", _providers.Select(p => p.ProviderName)));
    }

    /// <summary>
    /// Selects one provider among those currently available and healthy.
    /// </summary>
    /// <param name="strategy">
    /// How to choose among the available providers. Unrecognized values select the first
    /// available provider.
    /// </param>
    /// <returns>The selected provider.</returns>
    /// <exception cref="InvalidOperationException">No provider is available and healthy.</exception>
    public async Task<IAIProvider> GetBestProviderAsync(ProviderSelectionStrategy strategy = ProviderSelectionStrategy.Performance)
    {
        // Materialize once: the list is inspected for emptiness and then handed to a selector.
        var availableProviders = (await GetAvailableProvidersAsync()).ToList();
        if (availableProviders.Count == 0)
        {
            throw new InvalidOperationException("No AI providers are available");
        }

        var selectedProvider = strategy switch
        {
            ProviderSelectionStrategy.Performance => await SelectByPerformanceAsync(availableProviders),
            ProviderSelectionStrategy.Cost => SelectByCost(availableProviders),
            ProviderSelectionStrategy.Reliability => await SelectByReliabilityAsync(availableProviders),
            ProviderSelectionStrategy.LoadBalance => SelectByLoadBalance(availableProviders),
            ProviderSelectionStrategy.Primary => SelectPrimary(availableProviders),
            _ => availableProviders[0]
        };

        _logger.LogDebug("Selected AI provider: {ProviderName} using strategy: {Strategy}",
            selectedProvider.ProviderName, strategy);
        return selectedProvider;
    }

    /// <summary>
    /// Returns the providers whose <c>IsAvailable</c> flag is set and whose health check
    /// reports healthy, in registration order. Providers are checked one after another;
    /// a provider whose check throws is logged and left out.
    /// </summary>
    /// <returns>The available, healthy providers (possibly empty).</returns>
    public async Task<IEnumerable<IAIProvider>> GetAvailableProvidersAsync()
    {
        var availableProviders = new List<IAIProvider>(_providers.Count);

        foreach (var provider in _providers)
        {
            try
            {
                if (!provider.IsAvailable)
                {
                    _logger.LogWarning("Provider {ProviderName} is not available", provider.ProviderName);
                    continue;
                }

                var healthStatus = await provider.CheckHealthAsync();
                if (healthStatus.IsHealthy)
                {
                    availableProviders.Add(provider);
                }
                else
                {
                    _logger.LogWarning("Provider {ProviderName} is not healthy: {Error}",
                        provider.ProviderName, healthStatus.ErrorMessage);
                }
            }
            catch (Exception ex)
            {
                // One misbehaving provider must not prevent the others from being considered.
                _logger.LogError(ex, "Error checking provider {ProviderName} availability", provider.ProviderName);
            }
        }

        return availableProviders;
    }

    /// <summary>
    /// Looks up a provider by name (case-insensitive, ordinal) and returns it only if it
    /// is available and its health check reports healthy.
    /// </summary>
    /// <param name="providerName">The provider name to look for.</param>
    /// <returns>
    /// The matching provider, or <c>null</c> when there is no match, the provider is not
    /// available or not healthy, or its health check throws.
    /// </returns>
    public async Task<IAIProvider?> GetProviderByNameAsync(string providerName)
    {
        var provider = _providers.FirstOrDefault(
            p => p.ProviderName.Equals(providerName, StringComparison.OrdinalIgnoreCase));

        if (provider is null || !provider.IsAvailable)
        {
            return null;
        }

        try
        {
            var healthStatus = await provider.CheckHealthAsync();
            return healthStatus.IsHealthy ? provider : null;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Error checking provider {ProviderName} health", providerName);
            return null;
        }
    }

    /// <summary>
    /// Runs a health check and collects stats for every registered provider concurrently
    /// and aggregates the results into a report.
    /// </summary>
    /// <returns>
    /// A report with one entry per registered provider, the total provider count, the
    /// number of healthy providers, and the UTC time at which the check was started.
    /// </returns>
    public async Task<ProviderHealthReport> CheckAllProvidersHealthAsync()
    {
        var report = new ProviderHealthReport
        {
            CheckedAt = DateTime.UtcNow,
            TotalProviders = _providers.Count
        };

        var healthInfos = await Task.WhenAll(_providers.Select(p => BuildHealthInfoAsync(p)));

        report.ProviderHealthInfos = healthInfos.ToList();
        report.HealthyProviders = report.ProviderHealthInfos.Count(p => p.IsHealthy);
        return report;
    }

    /// <summary>
    /// Analyzes a sentence with the provider chosen by <paramref name="strategy"/>. If that
    /// provider throws, every other available provider is tried in turn.
    /// </summary>
    /// <param name="inputText">The sentence to analyze; passed through to the provider.</param>
    /// <param name="options">Analysis options; passed through to the provider.</param>
    /// <param name="strategy">Strategy used to choose the first provider to try.</param>
    /// <returns>The analysis result from the first provider that succeeds.</returns>
    /// <exception cref="InvalidOperationException">No provider is available and healthy.</exception>
    /// <remarks>
    /// When the selected provider and all fallbacks fail, the exception thrown by the
    /// originally selected provider is rethrown.
    /// </remarks>
    public async Task<SentenceAnalysisData> AnalyzeSentenceAsync(string inputText, AnalysisOptions options,
        ProviderSelectionStrategy strategy = ProviderSelectionStrategy.Performance)
    {
        var provider = await GetBestProviderAsync(strategy);

        try
        {
            var result = await provider.AnalyzeSentenceAsync(inputText, options);
            _logger.LogInformation("Sentence analyzed successfully using provider: {ProviderName}", provider.ProviderName);
            return result;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Error analyzing sentence with provider {ProviderName}, attempting fallback",
                provider.ProviderName);

            // Try the other available providers, skipping the one that just failed.
            var fallbackProviders = (await GetAvailableProvidersAsync())
                .Where(p => p.ProviderName != provider.ProviderName)
                .ToList();

            foreach (var fallbackProvider in fallbackProviders)
            {
                try
                {
                    var fallbackResult = await fallbackProvider.AnalyzeSentenceAsync(inputText, options);
                    _logger.LogWarning("Fallback successful using provider: {ProviderName}", fallbackProvider.ProviderName);
                    return fallbackResult;
                }
                catch (Exception fallbackEx)
                {
                    _logger.LogError(fallbackEx, "Fallback provider {ProviderName} also failed", fallbackProvider.ProviderName);
                }
            }

            // Every provider failed: rethrow the original exception with its stack trace intact.
            throw;
        }
    }

    #region Private helpers

    /// <summary>
    /// Builds the health report entry for one provider. Any exception from the health
    /// check or the stats call yields an unhealthy entry carrying the exception message.
    /// </summary>
    private async Task<ProviderHealthInfo> BuildHealthInfoAsync(IAIProvider provider)
    {
        try
        {
            var healthStatus = await provider.CheckHealthAsync();
            var stats = await provider.GetStatsAsync();

            return new ProviderHealthInfo
            {
                ProviderName = provider.ProviderName,
                IsHealthy = healthStatus.IsHealthy,
                ResponseTimeMs = healthStatus.ResponseTimeMs,
                ErrorMessage = healthStatus.ErrorMessage,
                Stats = stats
            };
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Error checking health for provider {ProviderName}", provider.ProviderName);
            return new ProviderHealthInfo
            {
                ProviderName = provider.ProviderName,
                IsHealthy = false,
                ErrorMessage = ex.Message,
                Stats = new AIProviderStats()
            };
        }
    }

    /// <summary>
    /// Picks the provider with the lowest average response time. Ties keep list order.
    /// </summary>
    private async Task<IAIProvider> SelectByPerformanceAsync(IReadOnlyList<IAIProvider> providers)
    {
        var performanceData = new List<(IAIProvider Provider, int ResponseTime)>(providers.Count);

        foreach (var provider in providers)
        {
            try
            {
                var stats = await provider.GetStatsAsync();
                performanceData.Add((provider, stats.AverageResponseTimeMs));
            }
            catch (Exception ex)
            {
                _logger.LogWarning(ex, "Could not get stats for provider {ProviderName}", provider.ProviderName);
                performanceData.Add((provider, int.MaxValue)); // lowest priority
            }
        }

        return performanceData
            .OrderBy(p => p.ResponseTime)
            .First().Provider;
    }

    /// <summary>
    /// Picks the provider with the lowest <c>CostPerRequest</c>. Ties keep list order.
    /// </summary>
    private IAIProvider SelectByCost(IReadOnlyList<IAIProvider> providers)
    {
        return providers
            .OrderBy(p => p.CostPerRequest)
            .First();
    }

    /// <summary>
    /// Picks the provider with the highest success rate. Ties keep list order.
    /// </summary>
    private async Task<IAIProvider> SelectByReliabilityAsync(IReadOnlyList<IAIProvider> providers)
    {
        var reliabilityData = new List<(IAIProvider Provider, double SuccessRate)>(providers.Count);

        foreach (var provider in providers)
        {
            try
            {
                var stats = await provider.GetStatsAsync();
                reliabilityData.Add((provider, stats.SuccessRate));
            }
            catch (Exception ex)
            {
                _logger.LogWarning(ex, "Could not get stats for provider {ProviderName}", provider.ProviderName);
                reliabilityData.Add((provider, 0.0)); // lowest priority
            }
        }

        return reliabilityData
            .OrderByDescending(p => p.SuccessRate)
            .First().Provider;
    }

    /// <summary>
    /// Picks a provider uniformly at random.
    /// </summary>
    private IAIProvider SelectByLoadBalance(IReadOnlyList<IAIProvider> providers)
    {
        // Random.Shared is thread-safe; a per-instance Random is not, and this method can
        // be reached from concurrent requests.
        return providers[Random.Shared.Next(providers.Count)];
    }

    /// <summary>
    /// Treats the first available provider as the primary one.
    /// </summary>
    private IAIProvider SelectPrimary(IReadOnlyList<IAIProvider> providers)
    {
        return providers[0];
    }

    #endregion
}