// dramaling-vocab-learning/backend/DramaLing.Api/Services/ImageGenerationOrchestrator.cs
// (386 lines, 15 KiB, C# as reported by the repository viewer.)
//
// NOTE: The repository viewer flagged this file as containing ambiguous Unicode characters,
// i.e. characters that might be confused with other characters. If this is intentional,
// the warning can be safely ignored.


using DramaLing.Api.Data;
using DramaLing.Api.Models.DTOs;
using DramaLing.Api.Models.Entities;
using DramaLing.Api.Services.AI;
using DramaLing.Api.Services.Storage;
using Microsoft.EntityFrameworkCore;
using System.Diagnostics;
using System.Text.Json;
namespace DramaLing.Api.Services;
/// <summary>
/// Coordinates the two-stage example-image pipeline for a flashcard: a Gemini call that
/// produces an image description/prompt, followed by a Replicate call that renders the
/// image. Progress, costs and errors are persisted on an <see cref="ImageGenerationRequest"/>
/// row so that callers can poll <see cref="GetGenerationStatusAsync"/>.
/// </summary>
public class ImageGenerationOrchestrator : IImageGenerationOrchestrator
{
    // Per-stage and overall status values persisted on ImageGenerationRequest.
    private const string StatusPending = "pending";
    private const string StatusProcessing = "processing";
    private const string StatusCompleted = "completed";
    private const string StatusFailed = "failed";
    private const string StatusCancelled = "cancelled";
    private const string StatusDescriptionGenerating = "description_generating";
    private const string StatusImageGenerating = "image_generating";

    // Stage identifiers accepted by MarkRequestAsFailedAsync.
    private const string StageGemini = "gemini";
    private const string StageReplicate = "replicate";
    private const string StageSystem = "system";

    // Model used when the original request does not name one.
    private const string DefaultReplicateModel = "ideogram-v2a-turbo";

    // Fallback dimensions reported/stored when the real image size is not known.
    private const int DefaultImageSize = 512;

    // One shared client for image downloads. Creating a new HttpClient per download can
    // exhaust sockets under load; the pooled-connection lifetime bounds DNS staleness.
    private static readonly HttpClient s_httpClient = new(new SocketsHttpHandler
    {
        PooledConnectionLifetime = TimeSpan.FromMinutes(15)
    });

    private readonly IGeminiImageDescriptionService _geminiService;
    private readonly IReplicateImageGenerationService _replicateService;
    private readonly IImageStorageService _storageService;
    private readonly DramaLingDbContext _dbContext;
    private readonly ILogger<ImageGenerationOrchestrator> _logger;

    /// <summary>Creates the orchestrator; every dependency is required.</summary>
    /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
    public ImageGenerationOrchestrator(
        IGeminiImageDescriptionService geminiService,
        IReplicateImageGenerationService replicateService,
        IImageStorageService storageService,
        DramaLingDbContext dbContext,
        ILogger<ImageGenerationOrchestrator> logger)
    {
        _geminiService = geminiService ?? throw new ArgumentNullException(nameof(geminiService));
        _replicateService = replicateService ?? throw new ArgumentNullException(nameof(replicateService));
        _storageService = storageService ?? throw new ArgumentNullException(nameof(storageService));
        _dbContext = dbContext ?? throw new ArgumentNullException(nameof(dbContext));
        _logger = logger ?? throw new ArgumentNullException(nameof(logger));
    }

    /// <summary>
    /// Persists a new pending generation request for the flashcard and starts the
    /// generation pipeline in the background.
    /// </summary>
    /// <param name="flashcardId">Id of the flashcard to generate an image for.</param>
    /// <param name="request">Caller-supplied request; serialized into the stored record.</param>
    /// <returns>The new request id together with fixed time and cost estimates.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="request"/> is <c>null</c>.</exception>
    /// <exception cref="ArgumentException">The flashcard does not exist.</exception>
    public async Task<GenerationRequestResult> StartGenerationAsync(Guid flashcardId, GenerationRequest request)
    {
        try
        {
            // Fail with a descriptive argument error instead of a NullReferenceException below.
            ArgumentNullException.ThrowIfNull(request);

            // Make sure the flashcard exists.
            var flashcard = await _dbContext.Flashcards.FindAsync(flashcardId);
            if (flashcard is null)
            {
                throw new ArgumentException($"Flashcard {flashcardId} not found");
            }

            // Create the generation request record.
            var generationRequest = new ImageGenerationRequest
            {
                Id = Guid.NewGuid(),
                UserId = request.UserId,
                FlashcardId = flashcardId,
                OverallStatus = StatusPending,
                GeminiStatus = StatusPending,
                ReplicateStatus = StatusPending,
                OriginalRequest = JsonSerializer.Serialize(request),
                CreatedAt = DateTime.UtcNow
            };
            _dbContext.ImageGenerationRequests.Add(generationRequest);
            await _dbContext.SaveChangesAsync();

            _logger.LogInformation("Created generation request {RequestId} for flashcard {FlashcardId}",
                generationRequest.Id, flashcardId);

            // Run the two-stage pipeline in the background (fire-and-forget).
            // NOTE(review): the pipeline reuses this instance's DbContext and services. If this
            // orchestrator is registered with a scoped lifetime they may be disposed (or used
            // concurrently) once the originating HTTP request ends; consider running the
            // pipeline inside its own DI scope or a hosted background worker. Confirm against
            // the DI registration.
            var pipelineRequestId = generationRequest.Id;
            _ = Task.Run(() => ExecuteGenerationPipelineAsync(pipelineRequestId));

            return new GenerationRequestResult
            {
                RequestId = generationRequest.Id,
                OverallStatus = StatusPending,
                CurrentStage = "description_generation",
                EstimatedTimeMinutes = new EstimatedTimeDto
                {
                    Gemini = 0.5,
                    Replicate = 2.0,
                    Total = 2.5
                },
                CostEstimate = new CostEstimateDto
                {
                    Gemini = 0.002m,
                    Replicate = 0.025m,
                    Total = 0.027m
                }
            };
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to start generation for flashcard {FlashcardId}", flashcardId);
            throw;
        }
    }

    /// <summary>
    /// Loads a generation request (with its generated image, if any) and maps it to a
    /// status response containing per-stage details, costs and the final result.
    /// </summary>
    /// <param name="requestId">Id returned by <see cref="StartGenerationAsync"/>.</param>
    /// <returns>The current status; <c>Result</c> is <c>null</c> until an image is attached.</returns>
    /// <exception cref="ArgumentException">No request with that id exists.</exception>
    public async Task<GenerationStatusResponse> GetGenerationStatusAsync(Guid requestId)
    {
        var request = await _dbContext.ImageGenerationRequests
            .Include(r => r.GeneratedImage)
            .FirstOrDefaultAsync(r => r.Id == requestId);
        if (request is null)
        {
            throw new ArgumentException($"Generation request {requestId} not found");
        }

        // The result section only exists once an image has been attached to the request.
        GenerationResultDto? result = null;
        var image = request.GeneratedImage;
        if (image is not null)
        {
            result = new GenerationResultDto
            {
                ImageUrl = await _storageService.GetImageUrlAsync(image.RelativePath),
                ImageId = image.Id.ToString(),
                QualityScore = image.QualityScore,
                Dimensions = new DimensionsDto
                {
                    Width = image.ImageWidth ?? DefaultImageSize,
                    Height = image.ImageHeight ?? DefaultImageSize
                },
                FileSize = image.FileSize
            };
        }

        return new GenerationStatusResponse
        {
            RequestId = request.Id,
            OverallStatus = request.OverallStatus,
            Stages = new StageStatusDto
            {
                Gemini = new GeminiStageDto
                {
                    Status = request.GeminiStatus,
                    StartedAt = request.GeminiStartedAt,
                    CompletedAt = request.GeminiCompletedAt,
                    ProcessingTimeMs = request.GeminiProcessingTimeMs,
                    Cost = request.GeminiCost,
                    GeneratedDescription = request.GeneratedDescription
                },
                Replicate = new ReplicateStageDto
                {
                    Status = request.ReplicateStatus,
                    StartedAt = request.ReplicateStartedAt,
                    CompletedAt = request.ReplicateCompletedAt,
                    ProcessingTimeMs = request.ReplicateProcessingTimeMs,
                    Cost = request.ReplicateCost
                }
            },
            TotalCost = request.TotalCost,
            CompletedAt = request.CompletedAt,
            Result = result
        };
    }

    /// <summary>
    /// Marks a generation request as cancelled. The running pipeline checks for this
    /// status between stages and stops when it sees it.
    /// </summary>
    /// <param name="requestId">Id of the request to cancel.</param>
    /// <returns>
    /// <c>true</c> when the request is now cancelled; <c>false</c> when it does not exist,
    /// has already completed or failed, or the update itself failed.
    /// </returns>
    public async Task<bool> CancelGenerationAsync(Guid requestId)
    {
        try
        {
            var request = await _dbContext.ImageGenerationRequests.FindAsync(requestId);

            // Completed and failed are terminal states; relabelling them as cancelled
            // would hide the real outcome (and the recorded error messages).
            if (request is null || request.OverallStatus is StatusCompleted or StatusFailed)
            {
                return false;
            }

            request.OverallStatus = StatusCancelled;
            await _dbContext.SaveChangesAsync();
            _logger.LogInformation("Generation request {RequestId} cancelled", requestId);
            return true;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to cancel generation request {RequestId}", requestId);
            return false;
        }
    }

    /// <summary>
    /// Runs both generation stages for a stored request and records the outcome.
    /// Never throws: every failure is logged and, where possible, persisted on the request.
    /// </summary>
    private async Task ExecuteGenerationPipelineAsync(Guid requestId)
    {
        var totalStopwatch = Stopwatch.StartNew();
        try
        {
            var request = await _dbContext.ImageGenerationRequests
                .Include(r => r.Flashcard)
                .FirstOrDefaultAsync(r => r.Id == requestId);
            if (request is null)
            {
                _logger.LogError("Generation request {RequestId} not found in pipeline", requestId);
                return;
            }

            // The flashcard may have been removed since the request was created.
            if (request.Flashcard is null)
            {
                await MarkRequestAsFailedAsync(requestId, StageSystem,
                    $"Flashcard {request.FlashcardId} not found");
                return;
            }

            var originalRequest = JsonSerializer.Deserialize<GenerationRequest>(request.OriginalRequest);
            var generationOptions = originalRequest?.Options ?? new GenerationOptionsDto();
            var replicateModel = originalRequest?.ReplicateModel ?? DefaultReplicateModel;

            if (await StopIfCancelledAsync(requestId))
            {
                return;
            }

            // Stage 1: Gemini description generation.
            _logger.LogInformation("Starting Gemini description generation for request {RequestId}", requestId);
            await UpdateRequestStatusAsync(requestId, StatusDescriptionGenerating, StatusProcessing, StatusPending);
            var descriptionResult = await _geminiService.GenerateDescriptionAsync(
                request.Flashcard,
                generationOptions);
            if (!descriptionResult.Success)
            {
                await MarkRequestAsFailedAsync(requestId, StageGemini, descriptionResult.Error);
                return;
            }

            // Persist the Gemini output.
            await UpdateGeminiResultAsync(requestId, descriptionResult);

            if (await StopIfCancelledAsync(requestId))
            {
                return;
            }

            // Stage 2: Replicate image generation.
            _logger.LogInformation("Starting Replicate image generation for request {RequestId}", requestId);
            await UpdateRequestStatusAsync(requestId, StatusImageGenerating, StatusCompleted, StatusProcessing);
            var imageResult = await _replicateService.GenerateImageAsync(
                descriptionResult.OptimizedPrompt ?? descriptionResult.Description ?? "",
                replicateModel,
                generationOptions);
            if (!imageResult.Success)
            {
                await MarkRequestAsFailedAsync(requestId, StageReplicate, imageResult.Error);
                return;
            }

            if (await StopIfCancelledAsync(requestId))
            {
                return;
            }

            // Download and store the image, then close out the request.
            var savedImage = await SaveGeneratedImageAsync(request, descriptionResult, imageResult, replicateModel);
            await CompleteRequestAsync(requestId, savedImage.Id, imageResult, totalStopwatch.ElapsedMilliseconds);

            _logger.LogInformation("Generation pipeline completed successfully for request {RequestId} in {ElapsedMs}ms",
                requestId, totalStopwatch.ElapsedMilliseconds);
        }
        catch (Exception ex)
        {
            totalStopwatch.Stop();
            _logger.LogError(ex, "Generation pipeline failed for request {RequestId}", requestId);
            try
            {
                await MarkRequestAsFailedAsync(requestId, StageSystem, ex.Message);
            }
            catch (Exception markEx)
            {
                // This runs as a fire-and-forget task; an exception escaping here would go
                // unobserved, so log it instead of letting it propagate.
                _logger.LogError(markEx, "Failed to record failure for generation request {RequestId}", requestId);
            }
        }
    }

    /// <summary>
    /// Reads the request's current overall status straight from the database (bypassing
    /// any tracked entity) and reports whether it has been cancelled.
    /// </summary>
    /// <returns><c>true</c> when the pipeline should stop because the request was cancelled.</returns>
    private async Task<bool> StopIfCancelledAsync(Guid requestId)
    {
        // AsNoTracking forces a fresh read; FindAsync would return the already-tracked
        // entity and miss a cancellation written through another DbContext instance.
        var current = await _dbContext.ImageGenerationRequests
            .AsNoTracking()
            .FirstOrDefaultAsync(r => r.Id == requestId);
        if (current?.OverallStatus != StatusCancelled)
        {
            return false;
        }

        _logger.LogInformation("Generation request {RequestId} was cancelled; stopping pipeline", requestId);
        return true;
    }

    /// <summary>
    /// Writes the overall and per-stage statuses, stamping a stage's start time the first
    /// time that stage enters "processing". Does nothing if the request no longer exists.
    /// </summary>
    private async Task UpdateRequestStatusAsync(Guid requestId, string overallStatus, string geminiStatus, string replicateStatus)
    {
        var request = await _dbContext.ImageGenerationRequests.FindAsync(requestId);
        if (request is null)
        {
            return;
        }

        request.OverallStatus = overallStatus;
        request.GeminiStatus = geminiStatus;
        request.ReplicateStatus = replicateStatus;

        var now = DateTime.UtcNow;
        if (geminiStatus == StatusProcessing && request.GeminiStartedAt == null)
        {
            request.GeminiStartedAt = now;
        }
        if (replicateStatus == StatusProcessing && request.ReplicateStartedAt == null)
        {
            request.ReplicateStartedAt = now;
        }

        await _dbContext.SaveChangesAsync();
    }

    /// <summary>
    /// Stores the Gemini stage output (description, optimized prompt, cost, timing) and
    /// marks that stage completed. Does nothing if the request no longer exists.
    /// </summary>
    private async Task UpdateGeminiResultAsync(Guid requestId, ImageDescriptionResult result)
    {
        var request = await _dbContext.ImageGenerationRequests.FindAsync(requestId);
        if (request is null)
        {
            return;
        }

        request.GeminiStatus = StatusCompleted;
        request.GeminiCompletedAt = DateTime.UtcNow;
        request.GeneratedDescription = result.Description;
        request.FinalReplicatePrompt = result.OptimizedPrompt;
        request.GeminiCost = result.Cost;
        request.GeminiProcessingTimeMs = result.ProcessingTimeMs;
        await _dbContext.SaveChangesAsync();
    }

    /// <summary>
    /// Downloads the generated image, hands it to the storage service, and persists an
    /// <see cref="ExampleImage"/> plus its link to the flashcard.
    /// </summary>
    /// <param name="request">The generation request being fulfilled.</param>
    /// <param name="descriptionResult">Output of the Gemini stage.</param>
    /// <param name="imageResult">Output of the Replicate stage; its image URL is downloaded.</param>
    /// <param name="replicateModel">Model name actually passed to the Replicate service.</param>
    /// <returns>The persisted image record.</returns>
    /// <exception cref="InvalidOperationException">The Replicate result carries no image URL.</exception>
    private async Task<ExampleImage> SaveGeneratedImageAsync(
        ImageGenerationRequest request,
        ImageDescriptionResult descriptionResult,
        ImageGenerationResult imageResult,
        string replicateModel)
    {
        if (imageResult.ImageUrl is null)
        {
            throw new InvalidOperationException("Replicate result does not contain an image URL");
        }

        // Download the image.
        var imageBytes = await s_httpClient.GetByteArrayAsync(imageResult.ImageUrl);
        using var imageStream = new MemoryStream(imageBytes, writable: false);

        // Build a unique file name for this flashcard.
        var fileName = $"{request.FlashcardId}_{Guid.NewGuid()}.png";

        // Store locally or in the cloud, depending on the storage service.
        var relativePath = await _storageService.SaveImageAsync(imageStream, fileName);

        var now = DateTime.UtcNow;

        // Create the ExampleImage record.
        var exampleImage = new ExampleImage
        {
            Id = Guid.NewGuid(),
            RelativePath = relativePath,
            AltText = $"Example image for {request.Flashcard?.Word}",
            GeminiPrompt = request.GeminiPrompt,
            GeminiDescription = descriptionResult.Description,
            ReplicatePrompt = descriptionResult.OptimizedPrompt,
            // Record the model that actually produced the image, not just the default.
            ReplicateModel = replicateModel,
            GeminiCost = descriptionResult.Cost,
            ReplicateCost = imageResult.Cost,
            TotalGenerationCost = descriptionResult.Cost + imageResult.Cost,
            FileSize = imageBytes.Length,
            // NOTE(review): dimensions are assumed, not measured from the downloaded bytes.
            ImageWidth = DefaultImageSize,
            ImageHeight = DefaultImageSize,
            ContentHash = ComputeHash(imageBytes),
            ModerationStatus = StatusPending,
            CreatedAt = now,
            UpdatedAt = now
        };
        _dbContext.ExampleImages.Add(exampleImage);

        // Link the image to the flashcard as its primary example image.
        var flashcardImage = new FlashcardExampleImage
        {
            FlashcardId = request.FlashcardId,
            ExampleImageId = exampleImage.Id,
            DisplayOrder = 1,
            IsPrimary = true,
            ContextRelevance = 1.0m,
            CreatedAt = now
        };
        _dbContext.FlashcardExampleImages.Add(flashcardImage);

        await _dbContext.SaveChangesAsync();
        return exampleImage;
    }

    /// <summary>
    /// Marks the request and its Replicate stage completed, attaches the generated image,
    /// and records the Replicate cost, total cost and total processing time.
    /// Does nothing if the request no longer exists.
    /// </summary>
    private async Task CompleteRequestAsync(
        Guid requestId,
        Guid imageId,
        ImageGenerationResult imageResult,
        long totalProcessingTimeMs)
    {
        var request = await _dbContext.ImageGenerationRequests.FindAsync(requestId);
        if (request is null)
        {
            return;
        }

        var now = DateTime.UtcNow;
        request.OverallStatus = StatusCompleted;
        request.ReplicateStatus = StatusCompleted;
        request.GeneratedImageId = imageId;
        request.CompletedAt = now;
        request.ReplicateCompletedAt = now;

        // Clamp rather than wrap if the elapsed time ever exceeds int.MaxValue milliseconds.
        request.TotalProcessingTimeMs = (int)Math.Min(totalProcessingTimeMs, int.MaxValue);

        // The Replicate cost must be stored before the total is computed; otherwise the
        // total silently omits it.
        request.ReplicateCost = imageResult.Cost;
        request.TotalCost = (request.GeminiCost ?? 0) + (request.ReplicateCost ?? 0);

        await _dbContext.SaveChangesAsync();
    }

    /// <summary>
    /// Marks the request failed and records the error on the stage that failed. For any
    /// stage other than "gemini" or "replicate" the message is stored on both stages and
    /// the per-stage statuses are left untouched. Does nothing if the request no longer exists.
    /// </summary>
    private async Task MarkRequestAsFailedAsync(Guid requestId, string stage, string? errorMessage)
    {
        var request = await _dbContext.ImageGenerationRequests.FindAsync(requestId);
        if (request is null)
        {
            return;
        }

        var now = DateTime.UtcNow;
        request.OverallStatus = StatusFailed;

        // Invariant lowering: stage names are identifiers, not culture-sensitive text.
        switch (stage.ToLowerInvariant())
        {
            case StageGemini:
                request.GeminiStatus = StatusFailed;
                request.GeminiErrorMessage = errorMessage;
                request.GeminiCompletedAt = now;
                break;
            case StageReplicate:
                request.ReplicateStatus = StatusFailed;
                request.ReplicateErrorMessage = errorMessage;
                request.ReplicateCompletedAt = now;
                break;
            default:
                request.GeminiErrorMessage = errorMessage;
                request.ReplicateErrorMessage = errorMessage;
                break;
        }

        request.CompletedAt = now;
        await _dbContext.SaveChangesAsync();

        _logger.LogError("Generation request {RequestId} marked as failed at stage {Stage}: {Error}",
            requestId, stage, errorMessage);
    }

    /// <summary>Returns the SHA-256 digest of <paramref name="bytes"/> as an uppercase hex string.</summary>
    private static string ComputeHash(byte[] bytes) =>
        Convert.ToHexString(System.Security.Cryptography.SHA256.HashData(bytes));
}