386 lines
15 KiB
C#
386 lines
15 KiB
C#
using DramaLing.Api.Data;
|
||
using DramaLing.Api.Models.DTOs;
|
||
using DramaLing.Api.Models.Entities;
|
||
using DramaLing.Api.Services.AI;
|
||
using DramaLing.Api.Services.Storage;
|
||
using Microsoft.EntityFrameworkCore;
|
||
using System.Diagnostics;
|
||
using System.Text.Json;
|
||
|
||
namespace DramaLing.Api.Services;
|
||
|
||
public class ImageGenerationOrchestrator : IImageGenerationOrchestrator
|
||
{
|
||
// Collaborators supplied through constructor injection.
private readonly IGeminiImageDescriptionService _geminiService;
private readonly IReplicateImageGenerationService _replicateService;
private readonly IImageStorageService _storageService;
private readonly DramaLingDbContext _dbContext;
private readonly ILogger<ImageGenerationOrchestrator> _logger;

/// <summary>
/// Creates the orchestrator from its required collaborators.
/// </summary>
/// <param name="geminiService">Service that produces the image description.</param>
/// <param name="replicateService">Service that generates the image.</param>
/// <param name="storageService">Service that stores images and resolves their URLs.</param>
/// <param name="dbContext">Database context used for request and image records.</param>
/// <param name="logger">Logger for this orchestrator.</param>
/// <exception cref="ArgumentNullException">Thrown when any argument is <c>null</c>.</exception>
public ImageGenerationOrchestrator(
    IGeminiImageDescriptionService geminiService,
    IReplicateImageGenerationService replicateService,
    IImageStorageService storageService,
    DramaLingDbContext dbContext,
    ILogger<ImageGenerationOrchestrator> logger)
{
    // Validate in parameter order so the first missing dependency is the one reported.
    if (geminiService is null)
    {
        throw new ArgumentNullException(nameof(geminiService));
    }

    if (replicateService is null)
    {
        throw new ArgumentNullException(nameof(replicateService));
    }

    if (storageService is null)
    {
        throw new ArgumentNullException(nameof(storageService));
    }

    if (dbContext is null)
    {
        throw new ArgumentNullException(nameof(dbContext));
    }

    if (logger is null)
    {
        throw new ArgumentNullException(nameof(logger));
    }

    _geminiService = geminiService;
    _replicateService = replicateService;
    _storageService = storageService;
    _dbContext = dbContext;
    _logger = logger;
}
|
||
|
||
/// <summary>
/// Persists a new image generation request for a flashcard and starts the two-stage
/// pipeline (Gemini description, then Replicate image) in the background.
/// </summary>
/// <param name="flashcardId">Identifier of the flashcard the image is generated for.</param>
/// <param name="request">Caller-supplied request; it is serialized and stored with the record.</param>
/// <returns>The new request id together with fixed time and cost estimates.</returns>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="request"/> is <c>null</c>.</exception>
/// <exception cref="ArgumentException">Thrown when no flashcard with <paramref name="flashcardId"/> exists.</exception>
public async Task<GenerationRequestResult> StartGenerationAsync(Guid flashcardId, GenerationRequest request)
{
    try
    {
        // Fail with a descriptive error instead of a NullReferenceException further down.
        if (request == null)
        {
            throw new ArgumentNullException(nameof(request));
        }

        // Make sure the flashcard exists.
        var flashcard = await _dbContext.Flashcards.FindAsync(flashcardId);
        if (flashcard == null)
        {
            throw new ArgumentException($"Flashcard {flashcardId} not found");
        }

        // Create the generation request record.
        var generationRequest = new ImageGenerationRequest
        {
            Id = Guid.NewGuid(),
            UserId = request.UserId,
            FlashcardId = flashcardId,
            OverallStatus = "pending",
            GeminiStatus = "pending",
            ReplicateStatus = "pending",
            OriginalRequest = JsonSerializer.Serialize(request),
            CreatedAt = DateTime.UtcNow
        };

        _dbContext.ImageGenerationRequests.Add(generationRequest);
        await _dbContext.SaveChangesAsync();

        _logger.LogInformation("Created generation request {RequestId} for flashcard {FlashcardId}",
            generationRequest.Id, flashcardId);

        // Run the two-stage pipeline in the background (fire-and-forget).
        // NOTE(review): the pipeline keeps using this instance's _dbContext after this method
        // returns. If DramaLingDbContext has a scoped lifetime it may be disposed while the
        // pipeline is still running - confirm the registration or resolve a context per run.
        var backgroundRequestId = generationRequest.Id;
        _ = Task.Run(() => RunPipelineInBackgroundAsync(backgroundRequestId));

        return new GenerationRequestResult
        {
            RequestId = generationRequest.Id,
            OverallStatus = "pending",
            CurrentStage = "description_generation",
            EstimatedTimeMinutes = new EstimatedTimeDto
            {
                Gemini = 0.5,
                Replicate = 2.0,
                Total = 2.5
            },
            CostEstimate = new CostEstimateDto
            {
                Gemini = 0.002m,
                Replicate = 0.025m,
                Total = 0.027m
            }
        };
    }
    catch (Exception ex)
    {
        _logger.LogError(ex, "Failed to start generation for flashcard {FlashcardId}", flashcardId);
        throw;
    }
}

// Runs the pipeline and logs anything that escapes it, so a failure inside the
// fire-and-forget task is never silently lost as an unobserved task exception.
private async Task RunPipelineInBackgroundAsync(Guid requestId)
{
    try
    {
        await ExecuteGenerationPipelineAsync(requestId);
    }
    catch (Exception ex)
    {
        _logger.LogError(ex, "Unhandled failure in background generation pipeline for request {RequestId}", requestId);
    }
}
|
||
|
||
/// <summary>
/// Loads a generation request (with its generated image, if any) and maps it to a status response.
/// </summary>
/// <param name="requestId">Identifier of the generation request.</param>
/// <returns>Overall and per-stage status, costs, and the image result when one is attached.</returns>
/// <exception cref="ArgumentException">Thrown when the request does not exist.</exception>
public async Task<GenerationStatusResponse> GetGenerationStatusAsync(Guid requestId)
{
    var record = await _dbContext.ImageGenerationRequests
        .Include(r => r.GeneratedImage)
        .FirstOrDefaultAsync(r => r.Id == requestId);

    if (record == null)
    {
        throw new ArgumentException($"Generation request {requestId} not found");
    }

    var geminiStage = new GeminiStageDto
    {
        Status = record.GeminiStatus,
        StartedAt = record.GeminiStartedAt,
        CompletedAt = record.GeminiCompletedAt,
        ProcessingTimeMs = record.GeminiProcessingTimeMs,
        Cost = record.GeminiCost,
        GeneratedDescription = record.GeneratedDescription
    };

    var replicateStage = new ReplicateStageDto
    {
        Status = record.ReplicateStatus,
        StartedAt = record.ReplicateStartedAt,
        CompletedAt = record.ReplicateCompletedAt,
        ProcessingTimeMs = record.ReplicateProcessingTimeMs,
        Cost = record.ReplicateCost
    };

    // The result section is only present once an image has been attached to the request.
    GenerationResultDto? generationResult = null;
    var image = record.GeneratedImage;
    if (image != null)
    {
        var imageUrl = await _storageService.GetImageUrlAsync(image.RelativePath);

        generationResult = new GenerationResultDto
        {
            ImageUrl = imageUrl,
            ImageId = image.Id.ToString(),
            QualityScore = image.QualityScore,
            Dimensions = new DimensionsDto
            {
                // Falls back to 512 when the stored dimension is missing.
                Width = image.ImageWidth ?? 512,
                Height = image.ImageHeight ?? 512
            },
            FileSize = image.FileSize
        };
    }

    return new GenerationStatusResponse
    {
        RequestId = record.Id,
        OverallStatus = record.OverallStatus,
        Stages = new StageStatusDto
        {
            Gemini = geminiStage,
            Replicate = replicateStage
        },
        TotalCost = record.TotalCost,
        CompletedAt = record.CompletedAt,
        Result = generationResult
    };
}
|
||
|
||
/// <summary>
/// Marks a generation request as cancelled unless it is missing or already completed.
/// </summary>
/// <param name="requestId">Identifier of the generation request.</param>
/// <returns>
/// <c>true</c> when the status was set to "cancelled"; <c>false</c> when the request does not
/// exist, is already completed, or the update failed (the failure is logged, not thrown).
/// </returns>
public async Task<bool> CancelGenerationAsync(Guid requestId)
{
    try
    {
        var target = await _dbContext.ImageGenerationRequests.FindAsync(requestId);

        if (target == null)
        {
            return false;
        }

        if (target.OverallStatus == "completed")
        {
            return false;
        }

        target.OverallStatus = "cancelled";
        await _dbContext.SaveChangesAsync();

        _logger.LogInformation("Generation request {RequestId} cancelled", requestId);
        return true;
    }
    catch (Exception ex)
    {
        _logger.LogError(ex, "Failed to cancel generation request {RequestId}", requestId);
        return false;
    }
}
|
||
|
||
/// <summary>
/// Runs the two-stage pipeline for a stored request: Gemini produces the description and
/// prompt, Replicate generates the image, the image is downloaded and stored, and the request
/// is marked completed. Stage failures and unexpected exceptions are recorded on the request.
/// </summary>
/// <remarks>
/// The stored status is re-read before each external stage so that a request cancelled in the
/// meantime is not processed further and its "cancelled" status is not overwritten.
/// </remarks>
/// <param name="requestId">Identifier of the persisted generation request.</param>
private async Task ExecuteGenerationPipelineAsync(Guid requestId)
{
    var totalStopwatch = Stopwatch.StartNew();

    try
    {
        var request = await _dbContext.ImageGenerationRequests
            .Include(r => r.Flashcard)
            .FirstOrDefaultAsync(r => r.Id == requestId);

        if (request == null)
        {
            _logger.LogError("Generation request {RequestId} not found in pipeline", requestId);
            return;
        }

        var options = JsonSerializer.Deserialize<GenerationRequest>(request.OriginalRequest);

        if (await IsCancelledAsync(requestId))
        {
            _logger.LogInformation("Generation request {RequestId} was cancelled before description generation; pipeline stopped", requestId);
            return;
        }

        // Stage 1: Gemini description generation.
        _logger.LogInformation("Starting Gemini description generation for request {RequestId}", requestId);

        await UpdateRequestStatusAsync(requestId, "description_generating", "processing", "pending");

        var descriptionResult = await _geminiService.GenerateDescriptionAsync(
            request.Flashcard,
            options?.Options ?? new GenerationOptionsDto());

        if (!descriptionResult.Success)
        {
            await MarkRequestAsFailedAsync(requestId, "gemini", descriptionResult.Error);
            return;
        }

        // Store the Gemini result.
        await UpdateGeminiResultAsync(requestId, descriptionResult);

        // Do not start the image stage for a request that was cancelled during stage 1.
        if (await IsCancelledAsync(requestId))
        {
            _logger.LogInformation("Generation request {RequestId} was cancelled before image generation; pipeline stopped", requestId);
            return;
        }

        // Stage 2: Replicate image generation.
        _logger.LogInformation("Starting Replicate image generation for request {RequestId}", requestId);

        await UpdateRequestStatusAsync(requestId, "image_generating", "completed", "processing");

        var imageResult = await _replicateService.GenerateImageAsync(
            descriptionResult.OptimizedPrompt ?? descriptionResult.Description ?? "",
            options?.ReplicateModel ?? "ideogram-v2a-turbo",
            options?.Options ?? new GenerationOptionsDto());

        if (!imageResult.Success)
        {
            await MarkRequestAsFailedAsync(requestId, "replicate", imageResult.Error);
            return;
        }

        // Record the Replicate cost on the tracked request. Nothing else assigns it, and
        // without it the total cost and the status response would omit this stage.
        // It is written to the database by the next SaveChangesAsync on this context.
        request.ReplicateCost = imageResult.Cost;

        // Download and store the image.
        var savedImage = await SaveGeneratedImageAsync(request, descriptionResult, imageResult);

        // Complete the request.
        await CompleteRequestAsync(requestId, savedImage.Id, totalStopwatch.ElapsedMilliseconds);

        _logger.LogInformation("Generation pipeline completed successfully for request {RequestId} in {ElapsedMs}ms",
            requestId, totalStopwatch.ElapsedMilliseconds);
    }
    catch (Exception ex)
    {
        totalStopwatch.Stop();
        _logger.LogError(ex, "Generation pipeline failed for request {RequestId}", requestId);
        await MarkRequestAsFailedAsync(requestId, "system", ex.Message);
    }
}

// Reads the stored overall status with a no-tracking query, so the value comes from the
// database rather than from an entity instance this context is already tracking.
private async Task<bool> IsCancelledAsync(Guid requestId)
{
    var storedStatus = await _dbContext.ImageGenerationRequests
        .AsNoTracking()
        .Where(r => r.Id == requestId)
        .Select(r => r.OverallStatus)
        .FirstOrDefaultAsync();

    return storedStatus == "cancelled";
}
|
||
|
||
/// <summary>
/// Writes the overall and per-stage statuses of a request and stamps a stage's start time
/// the first time that stage is reported as "processing". Does nothing if the request is missing.
/// </summary>
/// <param name="requestId">Identifier of the generation request.</param>
/// <param name="overallStatus">New overall status.</param>
/// <param name="geminiStatus">New Gemini stage status.</param>
/// <param name="replicateStatus">New Replicate stage status.</param>
private async Task UpdateRequestStatusAsync(Guid requestId, string overallStatus, string geminiStatus, string replicateStatus)
{
    var tracked = await _dbContext.ImageGenerationRequests.FindAsync(requestId);
    if (tracked != null)
    {
        tracked.OverallStatus = overallStatus;
        tracked.GeminiStatus = geminiStatus;
        tracked.ReplicateStatus = replicateStatus;

        // A start time is only set once; later "processing" updates keep the first value.
        var geminiJustStarted = geminiStatus == "processing" && tracked.GeminiStartedAt == null;
        if (geminiJustStarted)
        {
            tracked.GeminiStartedAt = DateTime.UtcNow;
        }

        var replicateJustStarted = replicateStatus == "processing" && tracked.ReplicateStartedAt == null;
        if (replicateJustStarted)
        {
            tracked.ReplicateStartedAt = DateTime.UtcNow;
        }

        await _dbContext.SaveChangesAsync();
    }
}
|
||
|
||
/// <summary>
/// Stores the outcome of the Gemini stage on the request and marks that stage completed.
/// Does nothing if the request is missing.
/// </summary>
/// <param name="requestId">Identifier of the generation request.</param>
/// <param name="result">Description result whose text, prompt, cost and timing are copied.</param>
private async Task UpdateGeminiResultAsync(Guid requestId, ImageDescriptionResult result)
{
    var tracked = await _dbContext.ImageGenerationRequests.FindAsync(requestId);
    if (tracked != null)
    {
        // Generated content.
        tracked.GeneratedDescription = result.Description;
        tracked.FinalReplicatePrompt = result.OptimizedPrompt;

        // Stage metrics.
        tracked.GeminiCost = result.Cost;
        tracked.GeminiProcessingTimeMs = result.ProcessingTimeMs;

        // Stage state.
        tracked.GeminiStatus = "completed";
        tracked.GeminiCompletedAt = DateTime.UtcNow;

        await _dbContext.SaveChangesAsync();
    }
}
|
||
|
||
// One shared client for image downloads. Creating and disposing an HttpClient per call
// can exhaust sockets under load.
private static readonly HttpClient _imageDownloadClient = new HttpClient();

// Model used when the original request does not name one (same fallback the pipeline uses).
private const string DefaultReplicateModel = "ideogram-v2a-turbo";

/// <summary>
/// Downloads the generated image, stores it through the storage service, and persists an
/// <see cref="ExampleImage"/> record plus its link to the flashcard.
/// </summary>
/// <param name="request">Generation request the image belongs to.</param>
/// <param name="descriptionResult">Gemini stage result (description, prompt, cost).</param>
/// <param name="imageResult">Replicate stage result (image URL, cost).</param>
/// <returns>The newly persisted image record.</returns>
private async Task<ExampleImage> SaveGeneratedImageAsync(
    ImageGenerationRequest request,
    ImageDescriptionResult descriptionResult,
    ImageGenerationResult imageResult)
{
    // Download the image.
    var imageBytes = await _imageDownloadClient.GetByteArrayAsync(imageResult.ImageUrl);
    using var imageStream = new MemoryStream(imageBytes);

    // Build the file name.
    var fileName = $"{request.FlashcardId}_{Guid.NewGuid()}.png";

    // Store locally / in the cloud.
    var relativePath = await _storageService.SaveImageAsync(imageStream, fileName);

    // Use one timestamp so CreatedAt and UpdatedAt of the new rows agree exactly.
    var now = DateTime.UtcNow;

    // Create the ExampleImage record.
    var exampleImage = new ExampleImage
    {
        Id = Guid.NewGuid(),
        RelativePath = relativePath,
        AltText = $"Example image for {request.Flashcard?.Word}",
        GeminiPrompt = request.GeminiPrompt,
        GeminiDescription = descriptionResult.Description,
        ReplicatePrompt = descriptionResult.OptimizedPrompt,
        // Record the model that was requested rather than always the default.
        ReplicateModel = ResolveReplicateModel(request),
        GeminiCost = descriptionResult.Cost,
        ReplicateCost = imageResult.Cost,
        TotalGenerationCost = descriptionResult.Cost + imageResult.Cost,
        FileSize = imageBytes.Length,
        // NOTE(review): dimensions are hard-coded and not read from the downloaded image -
        // confirm that every supported model/option combination really yields 512x512.
        ImageWidth = 512,
        ImageHeight = 512,
        ContentHash = ComputeHash(imageBytes),
        ModerationStatus = "pending",
        CreatedAt = now,
        UpdatedAt = now
    };

    _dbContext.ExampleImages.Add(exampleImage);

    // Link the image to the flashcard.
    var flashcardImage = new FlashcardExampleImage
    {
        FlashcardId = request.FlashcardId,
        ExampleImageId = exampleImage.Id,
        DisplayOrder = 1,
        IsPrimary = true,
        ContextRelevance = 1.0m,
        CreatedAt = now
    };

    _dbContext.FlashcardExampleImages.Add(flashcardImage);
    await _dbContext.SaveChangesAsync();

    return exampleImage;
}

// Returns the Replicate model named in the stored original request, or the default when
// the request names none or its stored JSON cannot be read.
private static string ResolveReplicateModel(ImageGenerationRequest request)
{
    if (string.IsNullOrWhiteSpace(request.OriginalRequest))
    {
        return DefaultReplicateModel;
    }

    try
    {
        var original = JsonSerializer.Deserialize<GenerationRequest>(request.OriginalRequest);
        return original?.ReplicateModel ?? DefaultReplicateModel;
    }
    catch (JsonException)
    {
        // The model name is descriptive metadata only; an unreadable stored request must
        // not fail the save after the image file has already been written.
        return DefaultReplicateModel;
    }
}
|
||
|
||
/// <summary>
/// Marks a request and its Replicate stage as completed, attaches the generated image, and
/// records the total processing time and total cost. Does nothing if the request is missing.
/// </summary>
/// <param name="requestId">Identifier of the generation request.</param>
/// <param name="imageId">Identifier of the stored image to attach.</param>
/// <param name="totalProcessingTimeMs">Elapsed pipeline time in milliseconds.</param>
private async Task CompleteRequestAsync(Guid requestId, Guid imageId, long totalProcessingTimeMs)
{
    var tracked = await _dbContext.ImageGenerationRequests.FindAsync(requestId);
    if (tracked != null)
    {
        tracked.OverallStatus = "completed";
        tracked.ReplicateStatus = "completed";
        tracked.GeneratedImageId = imageId;

        tracked.CompletedAt = DateTime.UtcNow;
        tracked.ReplicateCompletedAt = DateTime.UtcNow;

        tracked.TotalProcessingTimeMs = (int)totalProcessingTimeMs;

        // A stage cost that was never recorded counts as zero.
        var geminiCost = tracked.GeminiCost ?? 0;
        var replicateCost = tracked.ReplicateCost ?? 0;
        tracked.TotalCost = geminiCost + replicateCost;

        await _dbContext.SaveChangesAsync();
    }
}
|
||
|
||
/// <summary>
/// Marks a request as failed and records the error on the stage that failed.
/// Does nothing if the request is missing.
/// </summary>
/// <param name="requestId">Identifier of the generation request.</param>
/// <param name="stage">
/// "gemini" or "replicate" (compared case-insensitively). Any other value, including
/// <c>null</c>, stores the error message on both stages without changing their status.
/// </param>
/// <param name="errorMessage">Error text to store; may be <c>null</c>.</param>
private async Task MarkRequestAsFailedAsync(Guid requestId, string stage, string? errorMessage)
{
    var request = await _dbContext.ImageGenerationRequests.FindAsync(requestId);
    if (request == null)
    {
        return;
    }

    // One timestamp so the stage and the request report the same failure time.
    var failedAt = DateTime.UtcNow;

    request.OverallStatus = "failed";

    // Ordinal, case-insensitive comparison: independent of the current culture, and a null
    // stage falls through to the generic branch instead of throwing before the save.
    if (string.Equals(stage, "gemini", StringComparison.OrdinalIgnoreCase))
    {
        request.GeminiStatus = "failed";
        request.GeminiErrorMessage = errorMessage;
        request.GeminiCompletedAt = failedAt;
    }
    else if (string.Equals(stage, "replicate", StringComparison.OrdinalIgnoreCase))
    {
        request.ReplicateStatus = "failed";
        request.ReplicateErrorMessage = errorMessage;
        request.ReplicateCompletedAt = failedAt;
    }
    else
    {
        request.GeminiErrorMessage = errorMessage;
        request.ReplicateErrorMessage = errorMessage;
    }

    request.CompletedAt = failedAt;

    await _dbContext.SaveChangesAsync();

    _logger.LogError("Generation request {RequestId} marked as failed at stage {Stage}: {Error}",
        requestId, stage, errorMessage);
}
|
||
|
||
/// <summary>
/// Computes the SHA-256 digest of the given bytes.
/// </summary>
/// <param name="bytes">Content to hash.</param>
/// <returns>The digest as an uppercase hexadecimal string (64 characters).</returns>
private static string ComputeHash(byte[] bytes)
{
    using (var hasher = System.Security.Cryptography.SHA256.Create())
    {
        return Convert.ToHexString(hasher.ComputeHash(bytes));
    }
}
|
||
} |