using DramaLing.Api.Data;
using Microsoft.EntityFrameworkCore;

namespace DramaLing.Api.Services.AI.Generation;

/// <summary>
/// Persists status transitions of image generation requests.
/// Every method creates its own dependency-injection scope, resolves a database
/// context from it, loads the request by id, mutates it and saves the change.
/// When no request with the given id exists, the method logs a warning and returns
/// without writing anything.
/// </summary>
// NOTE(review): every GetRequiredService() call below is shown without a generic type
// argument, which does not compile as written. It looks like the type argument (the
// DbContext type that exposes ImageGenerationRequests) was lost from this text —
// restore it before building. The same may apply to the ILogger field/parameter
// (possibly ILogger<GenerationStateManager>) — TODO confirm.
public class GenerationStateManager : IGenerationStateManager
{
    private readonly IServiceProvider _serviceProvider;
    private readonly ILogger _logger;

    /// <summary>
    /// Creates the state manager.
    /// </summary>
    /// <param name="serviceProvider">Provider used to create a scope per operation.</param>
    /// <param name="logger">Logger used for warnings and failure reports.</param>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="serviceProvider"/> or <paramref name="logger"/> is null.
    /// </exception>
    public GenerationStateManager(
        IServiceProvider serviceProvider,
        ILogger logger)
    {
        _serviceProvider = serviceProvider ?? throw new ArgumentNullException(nameof(serviceProvider));
        _logger = logger ?? throw new ArgumentNullException(nameof(logger));
    }

    /// <summary>
    /// Overwrites the overall, Gemini and Replicate status of a request. The first time
    /// a stage is reported as "processing", its start timestamp is set to the current UTC time.
    /// </summary>
    /// <param name="requestId">Id of the request to update.</param>
    /// <param name="overallStatus">New overall status value.</param>
    /// <param name="geminiStatus">New Gemini stage status value.</param>
    /// <param name="replicateStatus">New Replicate stage status value.</param>
    public async Task UpdateRequestStatusAsync(Guid requestId, string overallStatus, string geminiStatus, string replicateStatus)
    {
        using var scope = _serviceProvider.CreateScope();
        var dbContext = scope.ServiceProvider.GetRequiredService();

        var request = await dbContext.ImageGenerationRequests.FindAsync(requestId);
        if (request == null)
        {
            _logger.LogWarning("Generation request {RequestId} not found; skipping {Operation}", requestId, nameof(UpdateRequestStatusAsync));
            return;
        }

        request.OverallStatus = overallStatus;
        request.GeminiStatus = geminiStatus;
        request.ReplicateStatus = replicateStatus;

        // Start timestamps are written once; later "processing" updates keep the first value.
        if (geminiStatus == "processing" && request.GeminiStartedAt == null)
        {
            request.GeminiStartedAt = DateTime.UtcNow;
        }

        if (replicateStatus == "processing" && request.ReplicateStartedAt == null)
        {
            request.ReplicateStartedAt = DateTime.UtcNow;
        }

        await dbContext.SaveChangesAsync();
    }

    /// <summary>
    /// Marks the Gemini stage of a request as completed and stores the prompt that will
    /// be sent to Replicate.
    /// </summary>
    /// <param name="requestId">Id of the request to update.</param>
    /// <param name="optimizedPrompt">Prompt stored as the final Replicate prompt.</param>
    public async Task UpdateGeminiResultAsync(Guid requestId, string optimizedPrompt)
    {
        using var scope = _serviceProvider.CreateScope();
        var dbContext = scope.ServiceProvider.GetRequiredService();

        var request = await dbContext.ImageGenerationRequests.FindAsync(requestId);
        if (request == null)
        {
            _logger.LogWarning("Generation request {RequestId} not found; skipping {Operation}", requestId, nameof(UpdateGeminiResultAsync));
            return;
        }

        request.GeminiStatus = "completed";
        request.GeminiCompletedAt = DateTime.UtcNow;
        request.FinalReplicatePrompt = optimizedPrompt;

        // NOTE(review): the three values below are hard-coded constants, not data taken
        // from an actual Gemini call — they look like placeholders; confirm with the caller.
        request.GeneratedDescription = "Gemini generated description";
        request.GeminiCost = 0.002m;
        request.GeminiProcessingTimeMs = 30000;

        await dbContext.SaveChangesAsync();
    }

    /// <summary>
    /// Marks a request and its Replicate stage as completed, links the generated image,
    /// and records the total processing time and total cost (Gemini cost plus Replicate
    /// cost, each treated as 0 when unset).
    /// </summary>
    /// <param name="requestId">Id of the request to update.</param>
    /// <param name="imageId">Id stored as the generated image of the request.</param>
    /// <param name="totalProcessingTimeMs">
    /// Total processing time in milliseconds. Values outside the <see cref="int"/> range
    /// are clamped to that range instead of silently wrapping around.
    /// </param>
    public async Task CompleteRequestAsync(Guid requestId, Guid imageId, long totalProcessingTimeMs)
    {
        using var scope = _serviceProvider.CreateScope();
        var dbContext = scope.ServiceProvider.GetRequiredService();

        var request = await dbContext.ImageGenerationRequests.FindAsync(requestId);
        if (request == null)
        {
            _logger.LogWarning("Generation request {RequestId} not found; skipping {Operation}", requestId, nameof(CompleteRequestAsync));
            return;
        }

        // One timestamp for both fields so they cannot differ by a few ticks.
        var completedAt = DateTime.UtcNow;

        request.OverallStatus = "completed";
        request.ReplicateStatus = "completed";
        request.GeneratedImageId = imageId;
        request.CompletedAt = completedAt;
        request.ReplicateCompletedAt = completedAt;

        // A plain (int) cast of an out-of-range long wraps in an unchecked context and
        // would store a garbage (possibly negative) duration; clamp first.
        request.TotalProcessingTimeMs = (int)Math.Clamp(totalProcessingTimeMs, (long)int.MinValue, (long)int.MaxValue);
        request.TotalCost = (request.GeminiCost ?? 0) + (request.ReplicateCost ?? 0);

        await dbContext.SaveChangesAsync();
    }

    /// <summary>
    /// Marks a request as failed, records the error on the stage that failed and logs
    /// the failure as an error after the change has been saved.
    /// </summary>
    /// <param name="requestId">Id of the request to update.</param>
    /// <param name="stage">
    /// Stage that failed. "gemini" and "replicate" are matched case-insensitively; any
    /// other value (including null) stores the error message on both stages without
    /// changing either stage status.
    /// </param>
    /// <param name="errorMessage">Error text to store; may be null.</param>
    public async Task MarkRequestAsFailedAsync(Guid requestId, string stage, string? errorMessage)
    {
        using var scope = _serviceProvider.CreateScope();
        var dbContext = scope.ServiceProvider.GetRequiredService();

        var request = await dbContext.ImageGenerationRequests.FindAsync(requestId);
        if (request == null)
        {
            _logger.LogWarning("Generation request {RequestId} not found; skipping {Operation}", requestId, nameof(MarkRequestAsFailedAsync));
            return;
        }

        // One timestamp for the stage and the request as a whole.
        var failedAt = DateTime.UtcNow;

        request.OverallStatus = "failed";

        // Ordinal, case-insensitive comparison: culture-sensitive lowercasing can turn
        // "GEMINI" into "gemını" (tr-TR) and miss the branch, and calling a method on a
        // null stage would throw before the failure is persisted.
        if (string.Equals(stage, "gemini", StringComparison.OrdinalIgnoreCase))
        {
            request.GeminiStatus = "failed";
            request.GeminiErrorMessage = errorMessage;
            request.GeminiCompletedAt = failedAt;
        }
        else if (string.Equals(stage, "replicate", StringComparison.OrdinalIgnoreCase))
        {
            request.ReplicateStatus = "failed";
            request.ReplicateErrorMessage = errorMessage;
            request.ReplicateCompletedAt = failedAt;
        }
        else
        {
            // Unknown stage: keep the stage statuses, but make the error visible on both.
            request.GeminiErrorMessage = errorMessage;
            request.ReplicateErrorMessage = errorMessage;
        }

        request.CompletedAt = failedAt;

        await dbContext.SaveChangesAsync();

        _logger.LogError("Generation request {RequestId} marked as failed at stage {Stage}: {Error}", requestId, stage, errorMessage);
    }
}