dramaling-vocab-learning/backend/DramaLing.Api/Services/AI/Generation/ImageGenerationOrchestrator.cs

425 lines
17 KiB
C#
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

using DramaLing.Api.Data;
using DramaLing.Api.Models.DTOs;
using DramaLing.Api.Models.Entities;
// Services.AI namespace removed
using DramaLing.Api.Services;
using DramaLing.Api.Services.Storage;
using Microsoft.EntityFrameworkCore;
using System.Diagnostics;
using System.Text.Json;
namespace DramaLing.Api.Services;
public class ImageGenerationOrchestrator : IImageGenerationOrchestrator
{
/// <summary>Service provider from which each operation creates its own DI scope.</summary>
private readonly IServiceProvider _serviceProvider;

/// <summary>Logger used for request lifecycle and failure messages.</summary>
private readonly ILogger<ImageGenerationOrchestrator> _logger;

/// <summary>
/// Creates the orchestrator from its two required dependencies.
/// </summary>
/// <param name="serviceProvider">Provider used to create scopes and resolve scoped services.</param>
/// <param name="logger">Logger for this orchestrator.</param>
/// <exception cref="ArgumentNullException">Either argument is null.</exception>
public ImageGenerationOrchestrator(
    IServiceProvider serviceProvider,
    ILogger<ImageGenerationOrchestrator> logger)
{
    if (serviceProvider is null)
    {
        throw new ArgumentNullException(nameof(serviceProvider));
    }

    if (logger is null)
    {
        throw new ArgumentNullException(nameof(logger));
    }

    _serviceProvider = serviceProvider;
    _logger = logger;
}
/// <summary>
/// Stores a new image generation request for a flashcard and starts the two-stage
/// pipeline (Gemini description, then Replicate image) on a background task.
/// </summary>
/// <param name="flashcardId">Id of the flashcard the image is generated for.</param>
/// <param name="request">Caller's request; it is serialized as JSON into the stored record.</param>
/// <returns>
/// The id of the stored request together with fixed time and cost estimates.
/// The background pipeline has not completed when this method returns.
/// </returns>
/// <exception cref="ArgumentNullException"><paramref name="request"/> is null.</exception>
/// <exception cref="ArgumentException">No flashcard with <paramref name="flashcardId"/> exists.</exception>
public async Task<GenerationRequestResult> StartGenerationAsync(Guid flashcardId, GenerationRequest request)
{
    // Fail fast with a descriptive exception instead of a NullReferenceException
    // raised at request.UserId after a scope and a database round trip.
    if (request is null)
    {
        throw new ArgumentNullException(nameof(request));
    }

    using var scope = _serviceProvider.CreateScope();
    var dbContext = scope.ServiceProvider.GetRequiredService<DramaLingDbContext>();

    try
    {
        // Make sure the flashcard exists.
        var flashcard = await dbContext.Flashcards.FindAsync(flashcardId);
        if (flashcard == null)
        {
            throw new ArgumentException($"Flashcard {flashcardId} not found");
        }

        // Create the generation request record.
        var generationRequest = new ImageGenerationRequest
        {
            Id = Guid.NewGuid(),
            UserId = request.UserId,
            FlashcardId = flashcardId,
            OverallStatus = "pending",
            GeminiStatus = "pending",
            ReplicateStatus = "pending",
            OriginalRequest = JsonSerializer.Serialize(request),
            CreatedAt = DateTime.UtcNow
        };

        dbContext.ImageGenerationRequests.Add(generationRequest);
        await dbContext.SaveChangesAsync();

        _logger.LogInformation("Created generation request {RequestId} for flashcard {FlashcardId}",
            generationRequest.Id, flashcardId);

        // Capture only the id so the background task does not hold on to an entity
        // that belongs to this method's (soon disposed) scope.
        var requestId = generationRequest.Id;

        // Run the two-stage pipeline in the background; it creates its own scope.
        _ = Task.Run(async () =>
        {
            try
            {
                await ExecuteGenerationPipelineAsync(requestId);
            }
            catch (Exception ex)
            {
                _logger.LogError(ex, "Background generation pipeline failed for request {RequestId}", requestId);
            }
        });

        return new GenerationRequestResult
        {
            RequestId = requestId,
            OverallStatus = "pending",
            CurrentStage = "description_generation",
            EstimatedTimeMinutes = new EstimatedTimeDto
            {
                Gemini = 0.5,
                Replicate = 2.0,
                Total = 2.5
            },
            CostEstimate = new CostEstimateDto
            {
                Gemini = 0.002m,
                Replicate = 0.025m,
                Total = 0.027m
            }
        };
    }
    catch (Exception ex)
    {
        _logger.LogError(ex, "Failed to start generation for flashcard {FlashcardId}", flashcardId);
        throw;
    }
}
/// <summary>
/// Loads a generation request (with its generated image, if any) and maps it to a
/// status response containing overall status, per-stage details and the result.
/// </summary>
/// <param name="requestId">Id of the generation request to look up.</param>
/// <returns>The status response; <c>Result</c> is null when no generated image is attached.</returns>
/// <exception cref="ArgumentException">No request with <paramref name="requestId"/> exists.</exception>
public async Task<GenerationStatusResponse> GetGenerationStatusAsync(Guid requestId)
{
    using var scope = _serviceProvider.CreateScope();
    var services = scope.ServiceProvider;
    var db = services.GetRequiredService<DramaLingDbContext>();
    var imageStorage = services.GetRequiredService<IImageStorageService>();

    var entity = await db.ImageGenerationRequests
        .Include(r => r.GeneratedImage)
        .FirstOrDefaultAsync(r => r.Id == requestId);

    if (entity == null)
    {
        throw new ArgumentException($"Generation request {requestId} not found");
    }

    var geminiStage = new GeminiStageDto
    {
        Status = entity.GeminiStatus,
        StartedAt = entity.GeminiStartedAt,
        CompletedAt = entity.GeminiCompletedAt,
        ProcessingTimeMs = entity.GeminiProcessingTimeMs,
        Cost = entity.GeminiCost,
        GeneratedDescription = entity.GeneratedDescription
    };

    var replicateStage = new ReplicateStageDto
    {
        Status = entity.ReplicateStatus,
        StartedAt = entity.ReplicateStartedAt,
        CompletedAt = entity.ReplicateCompletedAt,
        ProcessingTimeMs = entity.ReplicateProcessingTimeMs,
        Cost = entity.ReplicateCost
    };

    // The result section is only populated once an image has been attached to the request.
    GenerationResultDto? result = null;
    var image = entity.GeneratedImage;
    if (image != null)
    {
        result = new GenerationResultDto
        {
            ImageUrl = await imageStorage.GetImageUrlAsync(image.RelativePath),
            ImageId = image.Id.ToString(),
            QualityScore = image.QualityScore,
            Dimensions = new DimensionsDto
            {
                // Fall back to 512 when the stored dimensions are missing.
                Width = image.ImageWidth ?? 512,
                Height = image.ImageHeight ?? 512
            },
            FileSize = image.FileSize
        };
    }

    return new GenerationStatusResponse
    {
        RequestId = entity.Id,
        OverallStatus = entity.OverallStatus,
        Stages = new StageStatusDto
        {
            Gemini = geminiStage,
            Replicate = replicateStage
        },
        TotalCost = entity.TotalCost,
        CompletedAt = entity.CompletedAt,
        Result = result
    };
}
/// <summary>
/// Marks a generation request as cancelled by setting its overall status to "cancelled".
/// </summary>
/// <param name="requestId">Id of the generation request to cancel.</param>
/// <returns>
/// <c>true</c> when the status was saved as "cancelled"; <c>false</c> when the request
/// does not exist, has already completed or failed, or the update threw.
/// </returns>
public async Task<bool> CancelGenerationAsync(Guid requestId)
{
    using var scope = _serviceProvider.CreateScope();
    var dbContext = scope.ServiceProvider.GetRequiredService<DramaLingDbContext>();

    try
    {
        var request = await dbContext.ImageGenerationRequests.FindAsync(requestId);
        if (request == null)
        {
            return false;
        }

        // "completed" and "failed" are both terminal outcomes. Overwriting "failed"
        // with "cancelled" would hide the fact that the request actually failed.
        if (request.OverallStatus is "completed" or "failed")
        {
            return false;
        }

        request.OverallStatus = "cancelled";
        await dbContext.SaveChangesAsync();

        _logger.LogInformation("Generation request {RequestId} cancelled", requestId);
        return true;
    }
    catch (Exception ex)
    {
        // Deliberate best-effort: callers get false instead of an exception.
        _logger.LogError(ex, "Failed to cancel generation request {RequestId}", requestId);
        return false;
    }
}
/// <summary>
/// Runs the two-stage generation pipeline for a stored request: Gemini produces the
/// image prompt, Replicate generates the image, then the image is saved and the
/// request is marked completed. Failures are recorded on the request.
/// </summary>
/// <param name="requestId">Id of the stored generation request to process.</param>
private async Task ExecuteGenerationPipelineAsync(Guid requestId)
{
    var totalStopwatch = Stopwatch.StartNew();

    // Use a dedicated scope to avoid DbContext lifetime issues.
    using var scope = _serviceProvider.CreateScope();
    var dbContext = scope.ServiceProvider.GetRequiredService<DramaLingDbContext>();
    var geminiService = scope.ServiceProvider.GetRequiredService<IGeminiService>();
    var replicateService = scope.ServiceProvider.GetRequiredService<IReplicateService>();
    var storageService = scope.ServiceProvider.GetRequiredService<IImageStorageService>();
    var imageProcessingService = scope.ServiceProvider.GetRequiredService<IImageProcessingService>();

    try
    {
        _logger.LogInformation("Starting generation pipeline for request {RequestId}", requestId);

        var request = await dbContext.ImageGenerationRequests
            .Include(r => r.Flashcard)
            .FirstOrDefaultAsync(r => r.Id == requestId);
        if (request == null)
        {
            _logger.LogError("Generation request {RequestId} not found in pipeline", requestId);
            return;
        }

        // The request may have been cancelled before the background task started.
        if (request.OverallStatus == "cancelled")
        {
            _logger.LogInformation("Generation request {RequestId} was cancelled before the pipeline started", requestId);
            return;
        }

        var options = JsonSerializer.Deserialize<GenerationRequest>(request.OriginalRequest);

        // Stage 1: Gemini description generation.
        _logger.LogInformation("Starting Gemini description generation for request {RequestId}", requestId);
        await UpdateRequestStatusAsync(dbContext, requestId, "description_generating", "processing", "pending");

        var optimizedPrompt = await geminiService.GenerateImageDescriptionAsync(
            request.Flashcard,
            options?.Options ?? new GenerationOptionsDto());
        if (string.IsNullOrWhiteSpace(optimizedPrompt))
        {
            await MarkRequestAsFailedAsync(dbContext, requestId, "gemini", "Generated prompt is empty");
            return;
        }

        // Record the Gemini result.
        await UpdateGeminiResultAsync(dbContext, requestId, optimizedPrompt);

        // Honour a cancellation issued while Gemini was running, before paying for
        // the Replicate call. A small race window remains between this check and
        // the status update below.
        if (await IsCancelledAsync(dbContext, requestId))
        {
            _logger.LogInformation("Generation request {RequestId} was cancelled; skipping image generation", requestId);
            return;
        }

        // Stage 2: Replicate image generation.
        _logger.LogInformation("Starting Replicate image generation for request {RequestId}", requestId);
        await UpdateRequestStatusAsync(dbContext, requestId, "image_generating", "completed", "processing");

        // Force the known-good model name to avoid parameter passing errors.
        var modelName = "ideogram-v2a-turbo";
        _logger.LogInformation("Using Replicate model: {ModelName}", modelName);

        var imageResult = await replicateService.GenerateImageAsync(
            optimizedPrompt,
            modelName,
            new ReplicateGenerationOptions
            {
                Width = options?.Width ?? 512,
                Height = options?.Height ?? 512,
                TimeoutMinutes = 5
            });
        if (!imageResult.Success)
        {
            await MarkRequestAsFailedAsync(dbContext, requestId, "replicate", imageResult.Error);
            return;
        }

        // Do not attach an image or flip the status to "completed" for a request
        // that was cancelled while Replicate was running.
        if (await IsCancelledAsync(dbContext, requestId))
        {
            _logger.LogInformation("Generation request {RequestId} was cancelled; discarding generated image", requestId);
            return;
        }

        // Download and store the image.
        var savedImage = await SaveGeneratedImageAsync(dbContext, storageService, imageProcessingService, request, optimizedPrompt, imageResult);

        // Complete the request.
        await CompleteRequestAsync(dbContext, requestId, savedImage.Id, totalStopwatch.ElapsedMilliseconds);

        _logger.LogInformation("Generation pipeline completed successfully for request {RequestId} in {ElapsedMs}ms",
            requestId, totalStopwatch.ElapsedMilliseconds);
    }
    catch (Exception ex)
    {
        totalStopwatch.Stop();
        _logger.LogError(ex, "Generation pipeline failed for request {RequestId}", requestId);

        try
        {
            // If the failure came from SaveChangesAsync, the entities that could not be
            // saved are still tracked; without clearing them the save below would retry
            // the same writes, fail again and leave the request in a non-terminal status.
            dbContext.ChangeTracker.Clear();
            await MarkRequestAsFailedAsync(dbContext, requestId, "system", ex.Message);
        }
        catch (Exception markEx)
        {
            _logger.LogError(markEx, "Failed to mark generation request {RequestId} as failed", requestId);
        }
    }
}

/// <summary>
/// Reads the stored overall status of a request without tracking, so the value comes
/// from the database rather than from an entity instance already tracked by
/// <paramref name="dbContext"/>.
/// </summary>
/// <param name="dbContext">Context used for the query.</param>
/// <param name="requestId">Id of the generation request to check.</param>
/// <returns><c>true</c> when the stored overall status is "cancelled".</returns>
private static async Task<bool> IsCancelledAsync(DramaLingDbContext dbContext, Guid requestId)
{
    var status = await dbContext.ImageGenerationRequests
        .AsNoTracking()
        .Where(r => r.Id == requestId)
        .Select(r => r.OverallStatus)
        .FirstOrDefaultAsync();

    return status == "cancelled";
}
/// <summary>
/// Writes the overall and per-stage statuses of a request and stamps a stage's start
/// time the first time that stage is set to "processing". Does nothing when the
/// request cannot be found.
/// </summary>
/// <param name="dbContext">Context used to load and save the request.</param>
/// <param name="requestId">Id of the generation request to update.</param>
/// <param name="overallStatus">New overall status value.</param>
/// <param name="geminiStatus">New Gemini stage status value.</param>
/// <param name="replicateStatus">New Replicate stage status value.</param>
private async Task UpdateRequestStatusAsync(DramaLingDbContext dbContext, Guid requestId, string overallStatus, string geminiStatus, string replicateStatus)
{
    var entity = await dbContext.ImageGenerationRequests.FindAsync(requestId);
    if (entity == null)
    {
        return;
    }

    entity.OverallStatus = overallStatus;
    entity.GeminiStatus = geminiStatus;
    entity.ReplicateStatus = replicateStatus;

    // Start timestamps are written once and never overwritten afterwards.
    if (geminiStatus == "processing")
    {
        entity.GeminiStartedAt ??= DateTime.UtcNow;
    }

    if (replicateStatus == "processing")
    {
        entity.ReplicateStartedAt ??= DateTime.UtcNow;
    }

    await dbContext.SaveChangesAsync();
}
/// <summary>
/// Marks the Gemini stage of a request as completed and stores the prompt that will
/// be sent to Replicate. Does nothing when the request cannot be found.
/// </summary>
/// <param name="dbContext">Context used to load and save the request.</param>
/// <param name="requestId">Id of the generation request to update.</param>
/// <param name="optimizedPrompt">Prompt produced by the Gemini stage.</param>
private async Task UpdateGeminiResultAsync(DramaLingDbContext dbContext, Guid requestId, string optimizedPrompt)
{
    var entity = await dbContext.ImageGenerationRequests.FindAsync(requestId);
    if (entity == null)
    {
        return;
    }

    entity.FinalReplicatePrompt = optimizedPrompt;

    // NOTE(review): the description, cost and processing time below are fixed
    // placeholder values, not measurements taken from the Gemini call.
    entity.GeneratedDescription = "Gemini generated description"; // simplified version
    entity.GeminiCost = 0.002m;                                   // default cost
    entity.GeminiProcessingTimeMs = 30000;                        // default time

    entity.GeminiStatus = "completed";
    entity.GeminiCompletedAt = DateTime.UtcNow;

    await dbContext.SaveChangesAsync();
}
/// <summary>
/// Shared client for downloading generated images. Creating an HttpClient per call
/// can exhaust sockets under load; a pooled-connection lifetime lets long-lived
/// instances pick up DNS changes.
/// </summary>
private static readonly HttpClient ImageDownloadClient = new(new SocketsHttpHandler
{
    PooledConnectionLifetime = TimeSpan.FromMinutes(15)
});

/// <summary>
/// Downloads the generated image, resizes it to 512x512, stores it, and creates the
/// <see cref="ExampleImage"/> record plus its link to the flashcard.
/// </summary>
/// <param name="dbContext">Context used to add and save the new records.</param>
/// <param name="storageService">Storage used to persist the resized image.</param>
/// <param name="imageProcessingService">Service used to resize the downloaded image.</param>
/// <param name="request">Generation request the image belongs to; its Replicate cost is recorded here.</param>
/// <param name="optimizedPrompt">Prompt that was sent to Replicate.</param>
/// <param name="imageResult">Replicate result holding the image URL and cost.</param>
/// <returns>The saved <see cref="ExampleImage"/> record.</returns>
/// <exception cref="InvalidOperationException">The Replicate result has no image URL.</exception>
private async Task<ExampleImage> SaveGeneratedImageAsync(
    DramaLingDbContext dbContext,
    IImageStorageService storageService,
    IImageProcessingService imageProcessingService,
    ImageGenerationRequest request,
    string optimizedPrompt,
    ReplicateImageResult imageResult)
{
    if (imageResult.ImageUrl is null)
    {
        throw new InvalidOperationException("Replicate result does not contain an image URL");
    }

    // Download the original image (1024x1024).
    var originalBytes = await ImageDownloadClient.GetByteArrayAsync(imageResult.ImageUrl);
    _logger.LogInformation("Downloaded original image: {OriginalSize}KB", originalBytes.Length / 1024);

    // Downscale to 512x512.
    var resizedBytes = await imageProcessingService.ResizeImageAsync(originalBytes, 512, 512);
    using var imageStream = new MemoryStream(resizedBytes);

    // Build the file name.
    var fileName = $"{request.FlashcardId}_{Guid.NewGuid()}.png";

    // Save to local/cloud storage.
    var relativePath = await storageService.SaveImageAsync(imageStream, fileName);

    // Record the Replicate cost on the request as well; nothing else sets it, so the
    // request's total cost would otherwise only ever contain the Gemini cost.
    request.ReplicateCost = imageResult.Cost;

    // Create the ExampleImage record.
    var exampleImage = new ExampleImage
    {
        Id = Guid.NewGuid(),
        RelativePath = relativePath,
        AltText = $"Example image for {request.Flashcard?.Word}",
        GeminiPrompt = request.GeminiPrompt,
        GeminiDescription = request.GeneratedDescription,
        ReplicatePrompt = optimizedPrompt,
        ReplicateModel = "ideogram-v2a-turbo",
        GeminiCost = request.GeminiCost ?? 0.002m,
        ReplicateCost = imageResult.Cost,
        TotalGenerationCost = (request.GeminiCost ?? 0.002m) + imageResult.Cost,
        FileSize = resizedBytes.Length, // size of the resized file
        ImageWidth = 512,
        ImageHeight = 512,
        ContentHash = ComputeHash(resizedBytes), // hash of the resized file
        ModerationStatus = "pending",
        CreatedAt = DateTime.UtcNow,
        UpdatedAt = DateTime.UtcNow
    };
    dbContext.ExampleImages.Add(exampleImage);

    // Link the image to the flashcard.
    var flashcardImage = new FlashcardExampleImage
    {
        FlashcardId = request.FlashcardId,
        ExampleImageId = exampleImage.Id,
        DisplayOrder = 1,
        IsPrimary = true,
        ContextRelevance = 1.0m,
        CreatedAt = DateTime.UtcNow
    };
    dbContext.FlashcardExampleImages.Add(flashcardImage);

    await dbContext.SaveChangesAsync();
    return exampleImage;
}
/// <summary>
/// Marks a request and its Replicate stage as completed, attaches the generated image
/// and stores total processing time and total cost. Does nothing when the request
/// cannot be found.
/// </summary>
/// <param name="dbContext">Context used to load and save the request.</param>
/// <param name="requestId">Id of the generation request to complete.</param>
/// <param name="imageId">Id of the generated image to attach.</param>
/// <param name="totalProcessingTimeMs">Elapsed pipeline time in milliseconds; stored as an int.</param>
private async Task CompleteRequestAsync(DramaLingDbContext dbContext, Guid requestId, Guid imageId, long totalProcessingTimeMs)
{
    var entity = await dbContext.ImageGenerationRequests.FindAsync(requestId);
    if (entity == null)
    {
        return;
    }

    // Status flags.
    entity.OverallStatus = "completed";
    entity.ReplicateStatus = "completed";

    // Result reference and completion timestamps.
    entity.GeneratedImageId = imageId;
    entity.CompletedAt = DateTime.UtcNow;
    entity.ReplicateCompletedAt = DateTime.UtcNow;

    // Totals; a missing stage cost counts as zero.
    entity.TotalProcessingTimeMs = (int)totalProcessingTimeMs;
    var geminiCost = entity.GeminiCost ?? 0;
    var replicateCost = entity.ReplicateCost ?? 0;
    entity.TotalCost = geminiCost + replicateCost;

    await dbContext.SaveChangesAsync();
}
/// <summary>
/// Marks a request as failed and records the error message on the stage that failed.
/// Does nothing when the request cannot be found.
/// </summary>
/// <param name="dbContext">Context used to load and save the request.</param>
/// <param name="requestId">Id of the generation request to mark as failed.</param>
/// <param name="stage">
/// "gemini" or "replicate" (case-insensitive) to fail that stage; any other value
/// writes the error message to both stages without changing their statuses.
/// </param>
/// <param name="errorMessage">Error text to store; may be null.</param>
private async Task MarkRequestAsFailedAsync(DramaLingDbContext dbContext, Guid requestId, string stage, string? errorMessage)
{
    var request = await dbContext.ImageGenerationRequests.FindAsync(requestId);
    if (request == null) return;

    request.OverallStatus = "failed";

    // Stage names are identifiers, not user-facing text, so lower-case them with the
    // invariant culture; culture-sensitive ToLower() can map letters differently
    // (e.g. 'I' under a Turkish culture) and miss the intended case.
    switch (stage.ToLowerInvariant())
    {
        case "gemini":
            request.GeminiStatus = "failed";
            request.GeminiErrorMessage = errorMessage;
            request.GeminiCompletedAt = DateTime.UtcNow;
            break;
        case "replicate":
            request.ReplicateStatus = "failed";
            request.ReplicateErrorMessage = errorMessage;
            request.ReplicateCompletedAt = DateTime.UtcNow;
            break;
        default:
            // NOTE(review): stage statuses are left as they were here, so a stage that
            // was "processing" keeps that value even though the request is failed.
            request.GeminiErrorMessage = errorMessage;
            request.ReplicateErrorMessage = errorMessage;
            break;
    }

    request.CompletedAt = DateTime.UtcNow;
    await dbContext.SaveChangesAsync();

    _logger.LogError("Generation request {RequestId} marked as failed at stage {Stage}: {Error}",
        requestId, stage, errorMessage);
}
/// <summary>
/// Computes the SHA-256 digest of a byte array.
/// </summary>
/// <param name="bytes">Data to hash.</param>
/// <returns>The digest as an upper-case hexadecimal string.</returns>
private static string ComputeHash(byte[] bytes)
{
    byte[] digest = System.Security.Cryptography.SHA256.HashData(bytes);
    return Convert.ToHexString(digest);
}
}