dramaling-vocab-learning/backend/DramaLing.Api/Services/AI/Generation/GenerationStateManager.cs

116 lines
4.4 KiB
C#

using DramaLing.Api.Data;
using Microsoft.EntityFrameworkCore;
namespace DramaLing.Api.Services.AI.Generation;
/// <summary>
/// Persists the lifecycle state of image generation requests to the database.
/// State covers overall, Gemini and Replicate stage statuses, timestamps, costs and error messages.
/// </summary>
/// <remarks>
/// Each method resolves its own <see cref="DramaLingDbContext"/> from a freshly created DI scope
/// and saves its changes before returning. When the request id is not found, the call logs a
/// warning and returns without modifying anything.
/// </remarks>
public class GenerationStateManager : IGenerationStateManager
{
    private const string StatusProcessing = "processing";
    private const string StatusCompleted = "completed";
    private const string StatusFailed = "failed";

    // NOTE(review): placeholder values carried over from the original implementation.
    // The interface does not supply the real description or cost, so they remain fixed until it does.
    private const string PlaceholderGeminiDescription = "Gemini generated description";
    private const decimal PlaceholderGeminiCost = 0.002m;

    // Used only when no Gemini start time was recorded; matches the previously hard-coded value.
    private const int FallbackGeminiProcessingTimeMs = 30000;

    private readonly IServiceProvider _serviceProvider;
    private readonly ILogger<GenerationStateManager> _logger;

    /// <summary>Creates the state manager.</summary>
    /// <param name="serviceProvider">Root provider used to create a scope (and DbContext) per operation.</param>
    /// <param name="logger">Logger for failures and missing requests.</param>
    /// <exception cref="ArgumentNullException">Either argument is null.</exception>
    public GenerationStateManager(
        IServiceProvider serviceProvider,
        ILogger<GenerationStateManager> logger)
    {
        _serviceProvider = serviceProvider ?? throw new ArgumentNullException(nameof(serviceProvider));
        _logger = logger ?? throw new ArgumentNullException(nameof(logger));
    }

    /// <summary>
    /// Overwrites the overall, Gemini and Replicate statuses of a request.
    /// </summary>
    /// <remarks>
    /// When a stage status is "processing" and that stage has no start time yet,
    /// the start time is stamped with the current UTC time.
    /// </remarks>
    /// <param name="requestId">Id of the generation request.</param>
    /// <param name="overallStatus">New overall status.</param>
    /// <param name="geminiStatus">New Gemini stage status.</param>
    /// <param name="replicateStatus">New Replicate stage status.</param>
    public async Task UpdateRequestStatusAsync(Guid requestId, string overallStatus, string geminiStatus, string replicateStatus)
    {
        using var scope = _serviceProvider.CreateScope();
        var dbContext = scope.ServiceProvider.GetRequiredService<DramaLingDbContext>();
        var request = await dbContext.ImageGenerationRequests.FindAsync(requestId);
        if (request == null)
        {
            LogRequestNotFound(requestId, nameof(UpdateRequestStatusAsync));
            return;
        }

        var now = DateTime.UtcNow;
        request.OverallStatus = overallStatus;
        request.GeminiStatus = geminiStatus;
        request.ReplicateStatus = replicateStatus;

        // Only record the first transition into "processing"; later calls keep the original start.
        if (geminiStatus == StatusProcessing && request.GeminiStartedAt == null)
        {
            request.GeminiStartedAt = now;
        }

        if (replicateStatus == StatusProcessing && request.ReplicateStartedAt == null)
        {
            request.ReplicateStartedAt = now;
        }

        await dbContext.SaveChangesAsync();
    }

    /// <summary>
    /// Marks the Gemini stage as completed and stores the prompt to send to Replicate.
    /// </summary>
    /// <remarks>
    /// The processing time is measured from the recorded Gemini start time when one exists.
    /// The description and cost are still fixed placeholder values.
    /// </remarks>
    /// <param name="requestId">Id of the generation request.</param>
    /// <param name="optimizedPrompt">Prompt produced by Gemini, stored as the final Replicate prompt.</param>
    public async Task UpdateGeminiResultAsync(Guid requestId, string optimizedPrompt)
    {
        using var scope = _serviceProvider.CreateScope();
        var dbContext = scope.ServiceProvider.GetRequiredService<DramaLingDbContext>();
        var request = await dbContext.ImageGenerationRequests.FindAsync(requestId);
        if (request == null)
        {
            LogRequestNotFound(requestId, nameof(UpdateGeminiResultAsync));
            return;
        }

        var now = DateTime.UtcNow;
        request.GeminiStatus = StatusCompleted;
        request.GeminiCompletedAt = now;
        request.GeneratedDescription = PlaceholderGeminiDescription;
        request.FinalReplicatePrompt = optimizedPrompt;
        request.GeminiCost = PlaceholderGeminiCost;

        // Clamp guards against clock skew (negative) and int overflow on absurdly long runs.
        request.GeminiProcessingTimeMs = request.GeminiStartedAt is { } startedAt
            ? (int)Math.Clamp((now - startedAt).TotalMilliseconds, 0d, int.MaxValue)
            : FallbackGeminiProcessingTimeMs;

        await dbContext.SaveChangesAsync();
    }

    /// <summary>
    /// Marks the request and its Replicate stage as completed and links the generated image.
    /// </summary>
    /// <remarks>
    /// Total cost is set to the sum of the Gemini and Replicate costs, with missing costs counted as zero.
    /// </remarks>
    /// <param name="requestId">Id of the generation request.</param>
    /// <param name="imageId">Id of the generated image to link to the request.</param>
    /// <param name="totalProcessingTimeMs">
    /// End-to-end processing time; clamped to the range [0, <see cref="int.MaxValue"/>].
    /// </param>
    public async Task CompleteRequestAsync(Guid requestId, Guid imageId, long totalProcessingTimeMs)
    {
        using var scope = _serviceProvider.CreateScope();
        var dbContext = scope.ServiceProvider.GetRequiredService<DramaLingDbContext>();
        var request = await dbContext.ImageGenerationRequests.FindAsync(requestId);
        if (request == null)
        {
            LogRequestNotFound(requestId, nameof(CompleteRequestAsync));
            return;
        }

        // A single timestamp keeps CompletedAt and ReplicateCompletedAt identical.
        var now = DateTime.UtcNow;
        request.OverallStatus = StatusCompleted;
        request.ReplicateStatus = StatusCompleted;
        request.GeneratedImageId = imageId;
        request.CompletedAt = now;
        request.ReplicateCompletedAt = now;

        // A raw (int) cast would silently wrap for values outside the int range.
        request.TotalProcessingTimeMs = (int)Math.Clamp(totalProcessingTimeMs, 0L, int.MaxValue);
        request.TotalCost = (request.GeminiCost ?? 0) + (request.ReplicateCost ?? 0);

        await dbContext.SaveChangesAsync();
    }

    /// <summary>
    /// Marks the request as failed and records the error against the stage that failed.
    /// </summary>
    /// <remarks>
    /// <paramref name="stage"/> is matched case-insensitively against "gemini" and "replicate".
    /// Any other value (including null) stores the error message on both stages
    /// without changing their statuses.
    /// </remarks>
    /// <param name="requestId">Id of the generation request.</param>
    /// <param name="stage">Name of the failing stage.</param>
    /// <param name="errorMessage">Error description; may be null.</param>
    public async Task MarkRequestAsFailedAsync(Guid requestId, string stage, string? errorMessage)
    {
        using var scope = _serviceProvider.CreateScope();
        var dbContext = scope.ServiceProvider.GetRequiredService<DramaLingDbContext>();
        var request = await dbContext.ImageGenerationRequests.FindAsync(requestId);
        if (request == null)
        {
            LogRequestNotFound(requestId, nameof(MarkRequestAsFailedAsync));
            return;
        }

        var now = DateTime.UtcNow;
        request.OverallStatus = StatusFailed;

        // Invariant casing: culture-sensitive ToLower() mis-maps "I" under e.g. tr-TR ("GEMINI" -> "gemını").
        switch (stage?.ToLowerInvariant())
        {
            case "gemini":
                request.GeminiStatus = StatusFailed;
                request.GeminiErrorMessage = errorMessage;
                request.GeminiCompletedAt = now;
                break;
            case "replicate":
                request.ReplicateStatus = StatusFailed;
                request.ReplicateErrorMessage = errorMessage;
                request.ReplicateCompletedAt = now;
                break;
            default:
                request.GeminiErrorMessage = errorMessage;
                request.ReplicateErrorMessage = errorMessage;
                break;
        }

        request.CompletedAt = now;
        await dbContext.SaveChangesAsync();

        _logger.LogError("Generation request {RequestId} marked as failed at stage {Stage}: {Error}",
            requestId, stage, errorMessage);
    }

    /// <summary>Logs that an update was skipped because no request with the given id exists.</summary>
    private void LogRequestNotFound(Guid requestId, string operation) =>
        _logger.LogWarning("Generation request {RequestId} not found; {Operation} skipped",
            requestId, operation);
}