using DramaLing.Api.Data;
using DramaLing.Api.Models.DTOs;
using DramaLing.Api.Models.Entities;
using DramaLing.Api.Services.AI;
using DramaLing.Api.Services.Storage;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.DependencyInjection;
using System.Diagnostics;
using System.Text.Json;

namespace DramaLing.Api.Services;

/// <summary>
/// Coordinates the two-stage example-image pipeline for a flashcard: Gemini first produces an
/// image description / prompt, then Replicate renders the image, which is downloaded, stored via
/// <see cref="IImageStorageService"/> and linked to the flashcard. Progress is persisted on an
/// <see cref="ImageGenerationRequest"/> row so clients can poll it.
/// </summary>
public class ImageGenerationOrchestrator : IImageGenerationOrchestrator
{
    /// <summary>Replicate model used when the original request does not specify one.</summary>
    private const string DefaultReplicateModel = "ideogram-v2a-turbo";

    // One shared client for image downloads: creating an HttpClient per call can exhaust sockets.
    private static readonly HttpClient s_imageDownloadClient = new();

    private readonly IGeminiImageDescriptionService _geminiService;
    private readonly IReplicateImageGenerationService _replicateService;
    private readonly IImageStorageService _storageService;
    private readonly DramaLingDbContext _dbContext;
    private readonly ILogger<ImageGenerationOrchestrator> _logger;
    private readonly IServiceScopeFactory? _scopeFactory;

    /// <summary>Creates the orchestrator.</summary>
    /// <param name="geminiService">Stage 1: produces the image description / prompt.</param>
    /// <param name="replicateService">Stage 2: renders the image from the prompt.</param>
    /// <param name="storageService">Persists downloaded images and resolves their URLs.</param>
    /// <param name="dbContext">Database context used for request-scoped operations.</param>
    /// <param name="logger">Logger.</param>
    /// <param name="scopeFactory">
    /// Optional. When supplied (normally by DI), the background pipeline runs in its own DI scope
    /// with its own <see cref="DramaLingDbContext"/>, instead of reusing the caller's scoped context,
    /// which is disposed when the HTTP request ends and is not safe for concurrent use.
    /// </param>
    /// <exception cref="ArgumentNullException">Any required dependency is null.</exception>
    public ImageGenerationOrchestrator(
        IGeminiImageDescriptionService geminiService,
        IReplicateImageGenerationService replicateService,
        IImageStorageService storageService,
        DramaLingDbContext dbContext,
        ILogger<ImageGenerationOrchestrator> logger,
        IServiceScopeFactory? scopeFactory = null)
    {
        _geminiService = geminiService ?? throw new ArgumentNullException(nameof(geminiService));
        _replicateService = replicateService ?? throw new ArgumentNullException(nameof(replicateService));
        _storageService = storageService ?? throw new ArgumentNullException(nameof(storageService));
        _dbContext = dbContext ?? throw new ArgumentNullException(nameof(dbContext));
        _logger = logger ?? throw new ArgumentNullException(nameof(logger));
        _scopeFactory = scopeFactory;
    }

    /// <summary>
    /// Records a new generation request for <paramref name="flashcardId"/> and starts the
    /// two-stage pipeline in the background. Returns immediately with the request id so the
    /// caller can poll <see cref="GetGenerationStatusAsync"/>.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="request"/> is null.</exception>
    /// <exception cref="ArgumentException">The flashcard does not exist.</exception>
    public async Task<GenerationRequestResult> StartGenerationAsync(Guid flashcardId, GenerationRequest request)
    {
        ArgumentNullException.ThrowIfNull(request);

        try
        {
            // Make sure the flashcard exists before recording anything.
            var flashcard = await _dbContext.Flashcards.FindAsync(flashcardId);
            if (flashcard == null)
            {
                throw new ArgumentException($"Flashcard {flashcardId} not found");
            }

            // Persist the request record that tracks progress of both stages.
            var generationRequest = new ImageGenerationRequest
            {
                Id = Guid.NewGuid(),
                UserId = request.UserId,
                FlashcardId = flashcardId,
                OverallStatus = "pending",
                GeminiStatus = "pending",
                ReplicateStatus = "pending",
                OriginalRequest = JsonSerializer.Serialize(request),
                CreatedAt = DateTime.UtcNow
            };

            _dbContext.ImageGenerationRequests.Add(generationRequest);
            await _dbContext.SaveChangesAsync();

            _logger.LogInformation("Created generation request {RequestId} for flashcard {FlashcardId}",
                generationRequest.Id, flashcardId);

            // Run both stages detached from this call; the client polls for status.
            var requestId = generationRequest.Id;
            _ = Task.Run(() => RunPipelineInBackgroundAsync(requestId));

            return new GenerationRequestResult
            {
                RequestId = generationRequest.Id,
                OverallStatus = "pending",
                CurrentStage = "description_generation",
                EstimatedTimeMinutes = new EstimatedTimeDto
                {
                    Gemini = 0.5,
                    Replicate = 2.0,
                    Total = 2.5
                },
                CostEstimate = new CostEstimateDto
                {
                    Gemini = 0.002m,
                    Replicate = 0.025m,
                    Total = 0.027m
                }
            };
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to start generation for flashcard {FlashcardId}", flashcardId);
            throw;
        }
    }

    /// <summary>
    /// Returns the current status of a generation request, including per-stage details and, once
    /// an image has been generated, its URL and metadata.
    /// </summary>
    /// <exception cref="ArgumentException">No request with <paramref name="requestId"/> exists.</exception>
    public async Task<GenerationStatusResponse> GetGenerationStatusAsync(Guid requestId)
    {
        var request = await _dbContext.ImageGenerationRequests
            .Include(r => r.GeneratedImage)
            .FirstOrDefaultAsync(r => r.Id == requestId);

        if (request == null)
        {
            throw new ArgumentException($"Generation request {requestId} not found");
        }

        return new GenerationStatusResponse
        {
            RequestId = request.Id,
            OverallStatus = request.OverallStatus,
            Stages = new StageStatusDto
            {
                Gemini = new GeminiStageDto
                {
                    Status = request.GeminiStatus,
                    StartedAt = request.GeminiStartedAt,
                    CompletedAt = request.GeminiCompletedAt,
                    ProcessingTimeMs = request.GeminiProcessingTimeMs,
                    Cost = request.GeminiCost,
                    GeneratedDescription = request.GeneratedDescription
                },
                Replicate = new ReplicateStageDto
                {
                    Status = request.ReplicateStatus,
                    StartedAt = request.ReplicateStartedAt,
                    CompletedAt = request.ReplicateCompletedAt,
                    ProcessingTimeMs = request.ReplicateProcessingTimeMs,
                    Cost = request.ReplicateCost
                }
            },
            TotalCost = request.TotalCost,
            CompletedAt = request.CompletedAt,
            Result = request.GeneratedImage != null
                ? new GenerationResultDto
                {
                    ImageUrl = await _storageService.GetImageUrlAsync(request.GeneratedImage.RelativePath),
                    ImageId = request.GeneratedImage.Id.ToString(),
                    QualityScore = request.GeneratedImage.QualityScore,
                    // Fall back to 512x512 when dimensions were not recorded.
                    Dimensions = new DimensionsDto
                    {
                        Width = request.GeneratedImage.ImageWidth ?? 512,
                        Height = request.GeneratedImage.ImageHeight ?? 512
                    },
                    FileSize = request.GeneratedImage.FileSize
                }
                : null
        };
    }

    /// <summary>
    /// Marks an in-flight generation request as cancelled. The background pipeline checks for
    /// this status between stages and stops instead of overwriting it.
    /// </summary>
    /// <returns>
    /// <c>true</c> if the request was marked cancelled; <c>false</c> if it does not exist, has
    /// already finished (completed or failed), or the update failed.
    /// </returns>
    public async Task<bool> CancelGenerationAsync(Guid requestId)
    {
        try
        {
            var request = await _dbContext.ImageGenerationRequests.FindAsync(requestId);
            // A finished request keeps its outcome; overwriting "failed" would hide the error.
            if (request == null || request.OverallStatus is "completed" or "failed")
            {
                return false;
            }

            request.OverallStatus = "cancelled";
            await _dbContext.SaveChangesAsync();

            _logger.LogInformation("Generation request {RequestId} cancelled", requestId);
            return true;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to cancel generation request {RequestId}", requestId);
            return false;
        }
    }

    /// <summary>
    /// Entry point of the fire-and-forget background task. When a scope factory is available the
    /// pipeline runs on a fresh orchestrator built in its own async DI scope, so it gets its own
    /// <see cref="DramaLingDbContext"/> and services. Never throws: nobody observes this task, so
    /// any escaping exception is logged here.
    /// </summary>
    private async Task RunPipelineInBackgroundAsync(Guid requestId)
    {
        try
        {
            if (_scopeFactory is null)
            {
                // Constructed without a scope factory (e.g. manually): reuse this instance's services.
                await ExecuteGenerationPipelineAsync(requestId);
                return;
            }

            await using var scope = _scopeFactory.CreateAsyncScope();
            var scopedOrchestrator =
                ActivatorUtilities.CreateInstance<ImageGenerationOrchestrator>(scope.ServiceProvider);
            await scopedOrchestrator.ExecuteGenerationPipelineAsync(requestId);
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Background generation pipeline crashed for request {RequestId}", requestId);
        }
    }

    /// <summary>
    /// Runs both stages for a request: Gemini description, then Replicate image, then download,
    /// storage and completion bookkeeping. Stage failures are recorded on the request row. The
    /// pipeline stops early, without touching the status, if the request has been cancelled.
    /// </summary>
    private async Task ExecuteGenerationPipelineAsync(Guid requestId)
    {
        var totalStopwatch = Stopwatch.StartNew();

        try
        {
            var request = await _dbContext.ImageGenerationRequests
                .Include(r => r.Flashcard)
                .FirstOrDefaultAsync(r => r.Id == requestId);

            if (request == null)
            {
                _logger.LogError("Generation request {RequestId} not found in pipeline", requestId);
                return;
            }

            var originalRequest = JsonSerializer.Deserialize<GenerationRequest>(request.OriginalRequest);
            var generationOptions = originalRequest?.Options ?? new GenerationOptionsDto();
            var replicateModel = originalRequest?.ReplicateModel ?? DefaultReplicateModel;

            // Stage 1: Gemini description generation.
            if (await IsCancelledAsync(requestId)) return;
            _logger.LogInformation("Starting Gemini description generation for request {RequestId}", requestId);
            await UpdateRequestStatusAsync(requestId, "description_generating", "processing", "pending");

            var descriptionResult = await _geminiService.GenerateDescriptionAsync(request.Flashcard, generationOptions);
            if (!descriptionResult.Success)
            {
                await MarkRequestAsFailedAsync(requestId, "gemini", descriptionResult.Error);
                return;
            }

            // Record the Gemini result.
            await UpdateGeminiResultAsync(requestId, descriptionResult);

            // Stage 2: Replicate image generation.
            if (await IsCancelledAsync(requestId)) return;
            _logger.LogInformation("Starting Replicate image generation for request {RequestId}", requestId);
            await UpdateRequestStatusAsync(requestId, "image_generating", "completed", "processing");

            var replicatePrompt = BuildReplicatePrompt(descriptionResult);
            var replicateStopwatch = Stopwatch.StartNew();
            var imageResult = await _replicateService.GenerateImageAsync(replicatePrompt, replicateModel, generationOptions);
            replicateStopwatch.Stop();

            if (!imageResult.Success)
            {
                await MarkRequestAsFailedAsync(requestId, "replicate", imageResult.Error);
                return;
            }

            // Download and store the image, then finalize the request.
            if (await IsCancelledAsync(requestId)) return;
            var savedImage = await SaveGeneratedImageAsync(request, descriptionResult, imageResult, replicatePrompt, replicateModel);

            await CompleteRequestAsync(requestId, savedImage.Id, imageResult,
                (int)replicateStopwatch.ElapsedMilliseconds, totalStopwatch.ElapsedMilliseconds);

            _logger.LogInformation("Generation pipeline completed successfully for request {RequestId} in {ElapsedMs}ms",
                requestId, totalStopwatch.ElapsedMilliseconds);
        }
        catch (Exception ex)
        {
            totalStopwatch.Stop();
            _logger.LogError(ex, "Generation pipeline failed for request {RequestId}", requestId);

            // Discard half-saved entities (e.g. image rows whose insert failed) so the SaveChanges
            // inside MarkRequestAsFailedAsync does not retry them and fail a second time.
            _dbContext.ChangeTracker.Clear();
            await MarkRequestAsFailedAsync(requestId, "system", ex.Message);
        }
    }

    /// <summary>
    /// Prompt actually sent to Replicate: Gemini's optimized prompt, falling back to its raw
    /// description, then to an empty string.
    /// </summary>
    private static string BuildReplicatePrompt(ImageDescriptionResult descriptionResult) =>
        descriptionResult.OptimizedPrompt ?? descriptionResult.Description ?? "";

    /// <summary>
    /// Reads the request's current status straight from the database and reports whether it has
    /// been cancelled, logging when it has. <c>AsNoTracking</c> is required because
    /// <c>FindAsync</c> would return this context's cached entity and miss a cancel written by
    /// another request.
    /// </summary>
    private async Task<bool> IsCancelledAsync(Guid requestId)
    {
        var status = await _dbContext.ImageGenerationRequests
            .AsNoTracking()
            .Where(r => r.Id == requestId)
            .Select(r => r.OverallStatus)
            .FirstOrDefaultAsync();

        if (status != "cancelled")
        {
            return false;
        }

        _logger.LogInformation("Generation request {RequestId} was cancelled; stopping pipeline", requestId);
        return true;
    }

    /// <summary>
    /// Sets the overall and per-stage statuses. The first time a stage enters "processing", its
    /// start timestamp is stamped.
    /// </summary>
    private async Task UpdateRequestStatusAsync(Guid requestId, string overallStatus, string geminiStatus, string replicateStatus)
    {
        var request = await _dbContext.ImageGenerationRequests.FindAsync(requestId);
        if (request == null) return;

        request.OverallStatus = overallStatus;
        request.GeminiStatus = geminiStatus;
        request.ReplicateStatus = replicateStatus;

        if (geminiStatus == "processing" && request.GeminiStartedAt == null)
        {
            request.GeminiStartedAt = DateTime.UtcNow;
        }

        if (replicateStatus == "processing" && request.ReplicateStartedAt == null)
        {
            request.ReplicateStartedAt = DateTime.UtcNow;
        }

        await _dbContext.SaveChangesAsync();
    }

    /// <summary>
    /// Stores a successful Gemini result on the request: description, final prompt, cost and timing.
    /// </summary>
    private async Task UpdateGeminiResultAsync(Guid requestId, ImageDescriptionResult result)
    {
        var request = await _dbContext.ImageGenerationRequests.FindAsync(requestId);
        if (request == null) return;

        request.GeminiStatus = "completed";
        request.GeminiCompletedAt = DateTime.UtcNow;
        request.GeneratedDescription = result.Description;
        // Record the prompt that will actually be sent, including the description fallback.
        request.FinalReplicatePrompt = BuildReplicatePrompt(result);
        request.GeminiCost = result.Cost;
        request.GeminiProcessingTimeMs = result.ProcessingTimeMs;

        await _dbContext.SaveChangesAsync();
    }

    /// <summary>
    /// Downloads the rendered image, saves it through the storage service, and creates the
    /// <see cref="ExampleImage"/> row plus its primary <see cref="FlashcardExampleImage"/> link.
    /// </summary>
    /// <param name="replicatePrompt">Prompt that was sent to Replicate.</param>
    /// <param name="replicateModel">Replicate model that rendered the image.</param>
    private async Task<ExampleImage> SaveGeneratedImageAsync(
        ImageGenerationRequest request,
        ImageDescriptionResult descriptionResult,
        ImageGenerationResult imageResult,
        string replicatePrompt,
        string replicateModel)
    {
        // Download the image from the URL Replicate returned.
        var imageBytes = await s_imageDownloadClient.GetByteArrayAsync(imageResult.ImageUrl);
        using var imageStream = new MemoryStream(imageBytes);

        // Unique file name per generation.
        var fileName = $"{request.FlashcardId}_{Guid.NewGuid()}.png";

        // Save to the configured (local or cloud) storage.
        var relativePath = await _storageService.SaveImageAsync(imageStream, fileName);

        var now = DateTime.UtcNow;
        var exampleImage = new ExampleImage
        {
            Id = Guid.NewGuid(),
            RelativePath = relativePath,
            AltText = $"Example image for {request.Flashcard?.Word}",
            GeminiPrompt = request.GeminiPrompt,
            GeminiDescription = descriptionResult.Description,
            ReplicatePrompt = replicatePrompt,
            ReplicateModel = replicateModel,
            GeminiCost = descriptionResult.Cost,
            ReplicateCost = imageResult.Cost,
            TotalGenerationCost = descriptionResult.Cost + imageResult.Cost,
            FileSize = imageBytes.Length,
            // NOTE(review): dimensions are assumed, not read from the image — confirm the model's output size.
            ImageWidth = 512,
            ImageHeight = 512,
            ContentHash = ComputeHash(imageBytes),
            ModerationStatus = "pending",
            CreatedAt = now,
            UpdatedAt = now
        };
        _dbContext.ExampleImages.Add(exampleImage);

        // Link the image to the flashcard as its primary example image.
        var flashcardImage = new FlashcardExampleImage
        {
            FlashcardId = request.FlashcardId,
            ExampleImageId = exampleImage.Id,
            DisplayOrder = 1,
            IsPrimary = true,
            ContextRelevance = 1.0m,
            CreatedAt = now
        };
        _dbContext.FlashcardExampleImages.Add(flashcardImage);

        await _dbContext.SaveChangesAsync();
        return exampleImage;
    }

    /// <summary>
    /// Marks the request completed and records the generated image. It also stores the Replicate
    /// cost and timing, which previously were never saved, so <c>TotalCost</c> only ever reflected
    /// Gemini.
    /// </summary>
    private async Task CompleteRequestAsync(
        Guid requestId,
        Guid imageId,
        ImageGenerationResult imageResult,
        int replicateProcessingTimeMs,
        long totalProcessingTimeMs)
    {
        var request = await _dbContext.ImageGenerationRequests.FindAsync(requestId);
        if (request == null) return;

        var now = DateTime.UtcNow;
        request.OverallStatus = "completed";
        request.ReplicateStatus = "completed";
        request.ReplicateCost = imageResult.Cost;
        request.ReplicateProcessingTimeMs = replicateProcessingTimeMs;
        request.GeneratedImageId = imageId;
        request.CompletedAt = now;
        request.ReplicateCompletedAt = now;
        request.TotalProcessingTimeMs = (int)totalProcessingTimeMs;
        request.TotalCost = (request.GeminiCost ?? 0) + (request.ReplicateCost ?? 0);

        await _dbContext.SaveChangesAsync();
    }

    /// <summary>
    /// Records a failure. <paramref name="stage"/> is "gemini", "replicate" or anything else for
    /// system errors, which are recorded on both stages and fail whichever stage was still
    /// processing. A request the user has already cancelled is left untouched.
    /// </summary>
    private async Task MarkRequestAsFailedAsync(Guid requestId, string stage, string? errorMessage)
    {
        if (await IsCancelledAsync(requestId)) return;

        var request = await _dbContext.ImageGenerationRequests.FindAsync(requestId);
        if (request == null) return;

        var now = DateTime.UtcNow;
        request.OverallStatus = "failed";

        switch (stage.ToLowerInvariant())
        {
            case "gemini":
                request.GeminiStatus = "failed";
                request.GeminiErrorMessage = errorMessage;
                request.GeminiCompletedAt = now;
                break;
            case "replicate":
                request.ReplicateStatus = "failed";
                request.ReplicateErrorMessage = errorMessage;
                request.ReplicateCompletedAt = now;
                break;
            default:
                request.GeminiErrorMessage = errorMessage;
                request.ReplicateErrorMessage = errorMessage;
                // Don't leave a stage reporting "processing" on a failed request.
                if (request.GeminiStatus == "processing") request.GeminiStatus = "failed";
                if (request.ReplicateStatus == "processing") request.ReplicateStatus = "failed";
                break;
        }

        request.CompletedAt = now;
        await _dbContext.SaveChangesAsync();

        _logger.LogError("Generation request {RequestId} marked as failed at stage {Stage}: {Error}",
            requestId, stage, errorMessage);
    }

    /// <summary>Uppercase hex SHA-256 of <paramref name="bytes"/>, used as the image content hash.</summary>
    private static string ComputeHash(byte[] bytes) =>
        Convert.ToHexString(System.Security.Cryptography.SHA256.HashData(bytes));
}