425 lines
17 KiB
C#
425 lines
17 KiB
C#
using DramaLing.Api.Data;
|
||
using DramaLing.Api.Models.DTOs;
|
||
using DramaLing.Api.Models.Entities;
|
||
using DramaLing.Api.Services.AI;
|
||
using DramaLing.Api.Services;
|
||
using DramaLing.Api.Services.Storage;
|
||
using Microsoft.EntityFrameworkCore;
|
||
using System.Diagnostics;
|
||
using System.Text.Json;
|
||
|
||
namespace DramaLing.Api.Services;
|
||
|
||
/// <summary>
/// Coordinates the two-stage example-image pipeline for a flashcard:
/// Gemini produces an image prompt, Replicate renders it, and the result is resized,
/// stored, and linked to the flashcard. Progress is persisted on an
/// <c>ImageGenerationRequest</c> row so callers can poll <see cref="GetGenerationStatusAsync"/>.
/// </summary>
public class ImageGenerationOrchestrator : IImageGenerationOrchestrator
{
    // Forced rather than taken from the request to avoid wrong model names being passed through;
    // also recorded on the saved ExampleImage so both places always agree.
    private const string ReplicateModelName = "ideogram-v2a-turbo";

    // Edge length (pixels) downloaded images are resized to before storage; also the stored dimensions.
    private const int StoredImageSize = 512;

    // One shared client for image downloads: creating an HttpClient per call can exhaust sockets under load.
    private static readonly HttpClient s_httpClient = new();

    private readonly IServiceProvider _serviceProvider;
    private readonly ILogger<ImageGenerationOrchestrator> _logger;

    /// <summary>Creates the orchestrator.</summary>
    /// <param name="serviceProvider">Used to create a fresh DI scope per operation (DbContext, AI and storage services).</param>
    /// <param name="logger">Logger for pipeline progress and failures.</param>
    /// <exception cref="ArgumentNullException">Either argument is null.</exception>
    public ImageGenerationOrchestrator(
        IServiceProvider serviceProvider,
        ILogger<ImageGenerationOrchestrator> logger)
    {
        _serviceProvider = serviceProvider ?? throw new ArgumentNullException(nameof(serviceProvider));
        _logger = logger ?? throw new ArgumentNullException(nameof(logger));
    }

    /// <summary>
    /// Records a new generation request for a flashcard and starts the pipeline in the background.
    /// </summary>
    /// <param name="flashcardId">Flashcard to generate an example image for.</param>
    /// <param name="request">Caller options; serialized verbatim into the request record.</param>
    /// <returns>
    /// The new request id with static time/cost estimates. The pipeline is still running when this returns.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="request"/> is null.</exception>
    /// <exception cref="ArgumentException">No flashcard with <paramref name="flashcardId"/> exists.</exception>
    public async Task<GenerationRequestResult> StartGenerationAsync(Guid flashcardId, GenerationRequest request)
    {
        ArgumentNullException.ThrowIfNull(request);

        using var scope = _serviceProvider.CreateScope();
        var dbContext = scope.ServiceProvider.GetRequiredService<DramaLingDbContext>();

        try
        {
            // Make sure the flashcard exists before recording anything.
            var flashcard = await dbContext.Flashcards.FindAsync(flashcardId);
            if (flashcard == null)
            {
                throw new ArgumentException($"Flashcard {flashcardId} not found");
            }

            // Persist the request record that the pipeline and status polling work from.
            var generationRequest = new ImageGenerationRequest
            {
                Id = Guid.NewGuid(),
                UserId = request.UserId,
                FlashcardId = flashcardId,
                OverallStatus = "pending",
                GeminiStatus = "pending",
                ReplicateStatus = "pending",
                OriginalRequest = JsonSerializer.Serialize(request),
                CreatedAt = DateTime.UtcNow
            };

            dbContext.ImageGenerationRequests.Add(generationRequest);
            await dbContext.SaveChangesAsync();

            _logger.LogInformation("Created generation request {RequestId} for flashcard {FlashcardId}",
                generationRequest.Id, flashcardId);

            // Capture only the id: the entity belongs to this method's DbContext, which is disposed on return.
            var requestId = generationRequest.Id;

            // Fire-and-forget: the pipeline creates its own DI scope, independent of this one.
            // NOTE(review): if _serviceProvider is itself a request-scoped provider, creating a scope from it
            // after the HTTP request ends may fail — confirm this class is resolved from the root provider.
            // Work in progress is also lost if the host shuts down mid-pipeline.
            _ = Task.Run(async () =>
            {
                try
                {
                    await ExecuteGenerationPipelineAsync(requestId);
                }
                catch (Exception ex)
                {
                    _logger.LogError(ex, "Background generation pipeline failed for request {RequestId}", requestId);
                }
            });

            return new GenerationRequestResult
            {
                RequestId = requestId,
                OverallStatus = "pending",
                CurrentStage = "description_generation",
                EstimatedTimeMinutes = new EstimatedTimeDto
                {
                    Gemini = 0.5,
                    Replicate = 2.0,
                    Total = 2.5
                },
                CostEstimate = new CostEstimateDto
                {
                    Gemini = 0.002m,
                    Replicate = 0.025m,
                    Total = 0.027m
                }
            };
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to start generation for flashcard {FlashcardId}", flashcardId);
            throw;
        }
    }

    /// <summary>
    /// Returns the persisted per-stage status of a generation request, including the image once generated.
    /// </summary>
    /// <param name="requestId">Id returned by <see cref="StartGenerationAsync"/>.</param>
    /// <exception cref="ArgumentException">No request with <paramref name="requestId"/> exists.</exception>
    public async Task<GenerationStatusResponse> GetGenerationStatusAsync(Guid requestId)
    {
        using var scope = _serviceProvider.CreateScope();
        var dbContext = scope.ServiceProvider.GetRequiredService<DramaLingDbContext>();
        var storageService = scope.ServiceProvider.GetRequiredService<IImageStorageService>();

        var request = await dbContext.ImageGenerationRequests
            .Include(r => r.GeneratedImage)
            .FirstOrDefaultAsync(r => r.Id == requestId);

        if (request == null)
        {
            throw new ArgumentException($"Generation request {requestId} not found");
        }

        return new GenerationStatusResponse
        {
            RequestId = request.Id,
            OverallStatus = request.OverallStatus,
            Stages = new StageStatusDto
            {
                Gemini = new GeminiStageDto
                {
                    Status = request.GeminiStatus,
                    StartedAt = request.GeminiStartedAt,
                    CompletedAt = request.GeminiCompletedAt,
                    ProcessingTimeMs = request.GeminiProcessingTimeMs,
                    Cost = request.GeminiCost,
                    GeneratedDescription = request.GeneratedDescription
                },
                Replicate = new ReplicateStageDto
                {
                    Status = request.ReplicateStatus,
                    StartedAt = request.ReplicateStartedAt,
                    CompletedAt = request.ReplicateCompletedAt,
                    ProcessingTimeMs = request.ReplicateProcessingTimeMs,
                    Cost = request.ReplicateCost
                }
            },
            TotalCost = request.TotalCost,
            CompletedAt = request.CompletedAt,
            Result = request.GeneratedImage != null ? new GenerationResultDto
            {
                ImageUrl = await storageService.GetImageUrlAsync(request.GeneratedImage.RelativePath),
                ImageId = request.GeneratedImage.Id.ToString(),
                QualityScore = request.GeneratedImage.QualityScore,
                Dimensions = new DimensionsDto
                {
                    // Fall back to 512 when the stored image has no recorded dimensions.
                    Width = request.GeneratedImage.ImageWidth ?? 512,
                    Height = request.GeneratedImage.ImageHeight ?? 512
                },
                FileSize = request.GeneratedImage.FileSize
            } : null
        };
    }

    /// <summary>
    /// Marks an in-flight generation request as cancelled. The background pipeline checks for this
    /// between stages and stops without writing further progress.
    /// </summary>
    /// <param name="requestId">Request to cancel.</param>
    /// <returns>
    /// <c>true</c> if the request is now cancelled; <c>false</c> if it does not exist, has already
    /// completed or failed, or the update could not be saved (the error is logged).
    /// </returns>
    public async Task<bool> CancelGenerationAsync(Guid requestId)
    {
        using var scope = _serviceProvider.CreateScope();
        var dbContext = scope.ServiceProvider.GetRequiredService<DramaLingDbContext>();

        try
        {
            var request = await dbContext.ImageGenerationRequests.FindAsync(requestId);

            // Finished requests (success or failure) keep their final state; re-cancelling is a no-op success.
            if (request == null || request.OverallStatus is "completed" or "failed")
            {
                return false;
            }

            request.OverallStatus = "cancelled";
            await dbContext.SaveChangesAsync();

            _logger.LogInformation("Generation request {RequestId} cancelled", requestId);
            return true;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to cancel generation request {RequestId}", requestId);
            return false;
        }
    }

    /// <summary>
    /// Runs both stages for a persisted request: Gemini prompt generation, then Replicate rendering,
    /// then download/resize/store and completion bookkeeping. Failures are recorded on the request row
    /// rather than thrown, except when recording the failure itself fails.
    /// </summary>
    /// <param name="requestId">Id of an existing <c>ImageGenerationRequest</c>.</param>
    private async Task ExecuteGenerationPipelineAsync(Guid requestId)
    {
        var totalStopwatch = Stopwatch.StartNew();

        // Own scope so the DbContext and services outlive the HTTP request that started this pipeline.
        using var scope = _serviceProvider.CreateScope();
        var dbContext = scope.ServiceProvider.GetRequiredService<DramaLingDbContext>();
        var geminiService = scope.ServiceProvider.GetRequiredService<IGeminiService>();
        var replicateService = scope.ServiceProvider.GetRequiredService<IReplicateService>();
        var storageService = scope.ServiceProvider.GetRequiredService<IImageStorageService>();
        var imageProcessingService = scope.ServiceProvider.GetRequiredService<IImageProcessingService>();

        try
        {
            _logger.LogInformation("Starting generation pipeline for request {RequestId}", requestId);

            var request = await dbContext.ImageGenerationRequests
                .Include(r => r.Flashcard)
                .FirstOrDefaultAsync(r => r.Id == requestId);

            if (request == null)
            {
                _logger.LogError("Generation request {RequestId} not found in pipeline", requestId);
                return;
            }

            // The request may have been cancelled before the background task got scheduled.
            if (request.OverallStatus == "cancelled")
            {
                _logger.LogInformation("Generation request {RequestId} was cancelled or removed; stopping pipeline", requestId);
                return;
            }

            var options = JsonSerializer.Deserialize<GenerationRequest>(request.OriginalRequest);

            // Stage 1: Gemini description / prompt generation.
            _logger.LogInformation("Starting Gemini description generation for request {RequestId}", requestId);

            await UpdateRequestStatusAsync(dbContext, requestId, "description_generating", "processing", "pending");

            var geminiStopwatch = Stopwatch.StartNew();
            var optimizedPrompt = await geminiService.GenerateImageDescriptionAsync(
                request.Flashcard,
                options?.Options ?? new GenerationOptionsDto());
            geminiStopwatch.Stop();

            if (string.IsNullOrWhiteSpace(optimizedPrompt))
            {
                await MarkRequestAsFailedAsync(dbContext, requestId, "gemini", "Generated prompt is empty");
                return;
            }

            await UpdateGeminiResultAsync(dbContext, requestId, optimizedPrompt, geminiStopwatch.ElapsedMilliseconds);

            // Check for cancellation before starting the paid image-generation stage.
            if (await IsCancelledOrMissingAsync(dbContext, requestId))
            {
                _logger.LogInformation("Generation request {RequestId} was cancelled or removed; stopping pipeline", requestId);
                return;
            }

            // Stage 2: Replicate image generation.
            _logger.LogInformation("Starting Replicate image generation for request {RequestId}", requestId);

            await UpdateRequestStatusAsync(dbContext, requestId, "image_generating", "completed", "processing");

            _logger.LogInformation("Using Replicate model: {ModelName}", ReplicateModelName);

            var replicateStopwatch = Stopwatch.StartNew();
            var imageResult = await replicateService.GenerateImageAsync(
                optimizedPrompt,
                ReplicateModelName,
                new ReplicateGenerationOptions
                {
                    Width = options?.Width ?? 512,
                    Height = options?.Height ?? 512,
                    TimeoutMinutes = 5
                });
            replicateStopwatch.Stop();

            if (!imageResult.Success)
            {
                await MarkRequestAsFailedAsync(dbContext, requestId, "replicate", imageResult.Error);
                return;
            }

            // Do not attach an image to a request the user cancelled while Replicate was running.
            if (await IsCancelledOrMissingAsync(dbContext, requestId))
            {
                _logger.LogInformation("Generation request {RequestId} was cancelled or removed; stopping pipeline", requestId);
                return;
            }

            // Download, resize, and store the image, then link it to the flashcard.
            var savedImage = await SaveGeneratedImageAsync(dbContext, storageService, imageProcessingService, request, optimizedPrompt, imageResult);

            await CompleteRequestAsync(
                dbContext,
                requestId,
                savedImage.Id,
                totalStopwatch.ElapsedMilliseconds,
                imageResult.Cost,
                replicateStopwatch.ElapsedMilliseconds);

            _logger.LogInformation("Generation pipeline completed successfully for request {RequestId} in {ElapsedMs}ms",
                requestId, totalStopwatch.ElapsedMilliseconds);
        }
        catch (Exception ex)
        {
            totalStopwatch.Stop();
            _logger.LogError(ex, "Generation pipeline failed for request {RequestId}", requestId);

            // Drop whatever the failed step left pending (e.g. unsaved ExampleImage rows); otherwise the
            // SaveChangesAsync in MarkRequestAsFailedAsync would retry it and fail the same way. This also
            // makes the FindAsync there re-read the row from the database.
            dbContext.ChangeTracker.Clear();

            try
            {
                await MarkRequestAsFailedAsync(dbContext, requestId, "system", ex.Message);
            }
            catch (Exception markEx)
            {
                _logger.LogError(markEx, "Could not record failure state for generation request {RequestId}", requestId);
            }
        }
    }

    /// <summary>
    /// Re-reads the request from the database, bypassing the change tracker. The tracked copy may be stale
    /// because <see cref="CancelGenerationAsync"/> writes through a different DbContext.
    /// </summary>
    /// <returns><c>true</c> when the pipeline should stop: the request was cancelled or no longer exists.</returns>
    private static async Task<bool> IsCancelledOrMissingAsync(DramaLingDbContext dbContext, Guid requestId)
    {
        var current = await dbContext.ImageGenerationRequests
            .AsNoTracking()
            .FirstOrDefaultAsync(r => r.Id == requestId);

        return current == null || current.OverallStatus == "cancelled";
    }

    /// <summary>
    /// Sets the overall and per-stage statuses. Stamps a stage's start time the first time it is moved to "processing".
    /// </summary>
    private async Task UpdateRequestStatusAsync(DramaLingDbContext dbContext, Guid requestId, string overallStatus, string geminiStatus, string replicateStatus)
    {
        var request = await dbContext.ImageGenerationRequests.FindAsync(requestId);
        if (request == null) return;

        request.OverallStatus = overallStatus;
        request.GeminiStatus = geminiStatus;
        request.ReplicateStatus = replicateStatus;

        if (geminiStatus == "processing" && request.GeminiStartedAt == null)
        {
            request.GeminiStartedAt = DateTime.UtcNow;
        }

        if (replicateStatus == "processing" && request.ReplicateStartedAt == null)
        {
            request.ReplicateStartedAt = DateTime.UtcNow;
        }

        await dbContext.SaveChangesAsync();
    }

    /// <summary>
    /// Records a successful Gemini stage: the prompt to send to Replicate and the measured processing time.
    /// </summary>
    /// <param name="processingTimeMs">Wall-clock duration of the Gemini call.</param>
    private async Task UpdateGeminiResultAsync(DramaLingDbContext dbContext, Guid requestId, string optimizedPrompt, long processingTimeMs)
    {
        var request = await dbContext.ImageGenerationRequests.FindAsync(requestId);
        if (request == null) return;

        request.GeminiStatus = "completed";
        request.GeminiCompletedAt = DateTime.UtcNow;
        request.GeneratedDescription = "Gemini generated description"; // simplified version: placeholder text, not a real description
        request.FinalReplicatePrompt = optimizedPrompt;
        request.GeminiCost = 0.002m; // default (estimated) cost; no actual cost is available from the prompt result
        request.GeminiProcessingTimeMs = (int)processingTimeMs;

        await dbContext.SaveChangesAsync();
    }

    /// <summary>
    /// Downloads the Replicate output, resizes it to <see cref="StoredImageSize"/> square, stores it, and
    /// creates the <c>ExampleImage</c> row plus its primary link to the request's flashcard.
    /// </summary>
    /// <returns>The saved <c>ExampleImage</c> entity.</returns>
    private async Task<ExampleImage> SaveGeneratedImageAsync(
        DramaLingDbContext dbContext,
        IImageStorageService storageService,
        IImageProcessingService imageProcessingService,
        ImageGenerationRequest request,
        string optimizedPrompt,
        ReplicateImageResult imageResult)
    {
        // Download the original image (expected to be 1024x1024 — TODO confirm against the model output).
        var originalBytes = await s_httpClient.GetByteArrayAsync(imageResult.ImageUrl);

        _logger.LogInformation("Downloaded original image: {OriginalSize}KB", originalBytes.Length / 1024);

        // Resize before storing; size and hash below are computed from the resized bytes.
        var resizedBytes = await imageProcessingService.ResizeImageAsync(originalBytes, StoredImageSize, StoredImageSize);
        using var imageStream = new MemoryStream(resizedBytes);

        var fileName = $"{request.FlashcardId}_{Guid.NewGuid()}.png";

        // Save to local or cloud storage, depending on the IImageStorageService implementation.
        // NOTE(review): if the SaveChangesAsync below fails, the stored file is left orphaned.
        var relativePath = await storageService.SaveImageAsync(imageStream, fileName);

        var now = DateTime.UtcNow;
        var geminiCost = request.GeminiCost ?? 0.002m;

        var exampleImage = new ExampleImage
        {
            Id = Guid.NewGuid(),
            RelativePath = relativePath,
            AltText = $"Example image for {request.Flashcard?.Word}",
            GeminiPrompt = request.GeminiPrompt,
            GeminiDescription = request.GeneratedDescription,
            ReplicatePrompt = optimizedPrompt,
            ReplicateModel = ReplicateModelName,
            GeminiCost = geminiCost,
            ReplicateCost = imageResult.Cost,
            TotalGenerationCost = geminiCost + imageResult.Cost,
            FileSize = resizedBytes.Length,
            ImageWidth = StoredImageSize,
            ImageHeight = StoredImageSize,
            ContentHash = ComputeHash(resizedBytes),
            ModerationStatus = "pending",
            CreatedAt = now,
            UpdatedAt = now
        };

        dbContext.ExampleImages.Add(exampleImage);

        // Link the image to the flashcard as its primary example image.
        var flashcardImage = new FlashcardExampleImage
        {
            FlashcardId = request.FlashcardId,
            ExampleImageId = exampleImage.Id,
            DisplayOrder = 1,
            IsPrimary = true,
            ContextRelevance = 1.0m,
            CreatedAt = now
        };

        dbContext.FlashcardExampleImages.Add(flashcardImage);
        await dbContext.SaveChangesAsync();

        return exampleImage;
    }

    /// <summary>
    /// Marks the request completed. Records the Replicate cost and duration, links the generated image,
    /// and totals the cost of both stages.
    /// </summary>
    /// <param name="totalProcessingTimeMs">Wall-clock duration of the whole pipeline so far.</param>
    /// <param name="replicateCost">Cost reported by the Replicate result.</param>
    /// <param name="replicateProcessingTimeMs">Wall-clock duration of the Replicate call.</param>
    private async Task CompleteRequestAsync(
        DramaLingDbContext dbContext,
        Guid requestId,
        Guid imageId,
        long totalProcessingTimeMs,
        decimal? replicateCost,
        long replicateProcessingTimeMs)
    {
        var request = await dbContext.ImageGenerationRequests.FindAsync(requestId);
        if (request == null) return;

        var now = DateTime.UtcNow;

        request.OverallStatus = "completed";
        request.ReplicateStatus = "completed";
        request.GeneratedImageId = imageId;
        request.CompletedAt = now;
        request.ReplicateCompletedAt = now;
        request.ReplicateCost = replicateCost;
        request.ReplicateProcessingTimeMs = (int)replicateProcessingTimeMs;
        request.TotalProcessingTimeMs = (int)totalProcessingTimeMs;
        // Both stage costs are now recorded, so the total includes Replicate.
        request.TotalCost = (request.GeminiCost ?? 0) + (request.ReplicateCost ?? 0);

        await dbContext.SaveChangesAsync();
    }

    /// <summary>
    /// Marks the request failed and attributes the error to a stage.
    /// </summary>
    /// <param name="stage">
    /// "gemini" or "replicate" (case-insensitive) for a stage-specific failure. Any other value (the
    /// pipeline uses "system") records the error on both stages and fails whichever stage was still processing.
    /// </param>
    /// <param name="errorMessage">Error text to store; may be null.</param>
    private async Task MarkRequestAsFailedAsync(DramaLingDbContext dbContext, Guid requestId, string stage, string? errorMessage)
    {
        var request = await dbContext.ImageGenerationRequests.FindAsync(requestId);
        if (request == null) return;

        request.OverallStatus = "failed";

        switch (stage.ToLowerInvariant())
        {
            case "gemini":
                request.GeminiStatus = "failed";
                request.GeminiErrorMessage = errorMessage;
                request.GeminiCompletedAt = DateTime.UtcNow;
                break;
            case "replicate":
                request.ReplicateStatus = "failed";
                request.ReplicateErrorMessage = errorMessage;
                request.ReplicateCompletedAt = DateTime.UtcNow;
                break;
            default:
                request.GeminiErrorMessage = errorMessage;
                request.ReplicateErrorMessage = errorMessage;

                // Don't leave an in-flight stage reporting "processing" on a failed request.
                if (request.GeminiStatus == "processing")
                {
                    request.GeminiStatus = "failed";
                }

                if (request.ReplicateStatus == "processing")
                {
                    request.ReplicateStatus = "failed";
                }
                break;
        }

        request.CompletedAt = DateTime.UtcNow;

        await dbContext.SaveChangesAsync();

        _logger.LogError("Generation request {RequestId} marked as failed at stage {Stage}: {Error}",
            requestId, stage, errorMessage);
    }

    /// <summary>Returns the SHA-256 of <paramref name="bytes"/> as an uppercase hex string.</summary>
    private static string ComputeHash(byte[] bytes)
    {
        return Convert.ToHexString(System.Security.Cryptography.SHA256.HashData(bytes));
    }
}