dramaling-vocab-learning/backend/DramaLing.Api/Services/AI/Generation/GenerationPipelineService.cs

115 lines
5.0 KiB
C#
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

using DramaLing.Api.Data;
using DramaLing.Api.Models.DTOs;
using DramaLing.Api.Services.Storage;
using DramaLing.Api.Services;
using Microsoft.EntityFrameworkCore;
using System.Diagnostics;
using System.Text.Json;
namespace DramaLing.Api.Services.AI.Generation;
/// <summary>
/// Runs the two-stage image generation pipeline for a stored generation request:
/// a Gemini call produces an image description (prompt), a Replicate call turns that
/// prompt into an image, and the result is saved and the request marked as completed.
/// </summary>
public class GenerationPipelineService : IGenerationPipelineService
{
    // Replicate model used for every generation in this pipeline.
    private const string ReplicateModelName = "ideogram-v2a-turbo";

    // Fallback image edge length used when the stored request does not specify a size.
    private const int DefaultImageSize = 512;

    // Timeout value handed to the Replicate generation options.
    private const int ReplicateTimeoutMinutes = 5;

    private readonly IServiceProvider _serviceProvider;
    private readonly IGenerationStateManager _stateManager;
    private readonly IImageSaveManager _imageSaveManager;
    private readonly ILogger<GenerationPipelineService> _logger;

    /// <summary>
    /// Creates the pipeline service.
    /// </summary>
    /// <param name="serviceProvider">Root provider used to open a DI scope per pipeline run.</param>
    /// <param name="stateManager">Records status transitions, results and failures for a request.</param>
    /// <param name="imageSaveManager">Persists the generated image.</param>
    /// <param name="logger">Logger for pipeline progress and failures.</param>
    /// <exception cref="ArgumentNullException">Thrown when any argument is <c>null</c>.</exception>
    public GenerationPipelineService(
        IServiceProvider serviceProvider,
        IGenerationStateManager stateManager,
        IImageSaveManager imageSaveManager,
        ILogger<GenerationPipelineService> logger)
    {
        _serviceProvider = serviceProvider ?? throw new ArgumentNullException(nameof(serviceProvider));
        _stateManager = stateManager ?? throw new ArgumentNullException(nameof(stateManager));
        _imageSaveManager = imageSaveManager ?? throw new ArgumentNullException(nameof(imageSaveManager));
        _logger = logger ?? throw new ArgumentNullException(nameof(logger));
    }

    /// <summary>
    /// Executes the full generation pipeline for the request identified by
    /// <paramref name="requestId"/>.
    /// </summary>
    /// <remarks>
    /// Pipeline failures are not rethrown: they are logged and the request is marked as
    /// failed through the state manager. If recording the failure itself fails, that
    /// secondary error is logged as well instead of escaping and masking the original one.
    /// </remarks>
    /// <param name="requestId">Id of the image generation request to process.</param>
    public async Task ExecuteGenerationPipelineAsync(Guid requestId)
    {
        var totalStopwatch = Stopwatch.StartNew();
        using var scope = _serviceProvider.CreateScope();

        try
        {
            // Resolve scoped services inside the try block so that a DI resolution failure
            // is recorded as a failed request rather than leaving the request untouched.
            var dbContext = scope.ServiceProvider.GetRequiredService<DramaLingDbContext>();
            var geminiService = scope.ServiceProvider.GetRequiredService<IGeminiService>();
            var replicateService = scope.ServiceProvider.GetRequiredService<IReplicateService>();
            var storageService = scope.ServiceProvider.GetRequiredService<IImageStorageService>();
            var imageProcessingService = scope.ServiceProvider.GetRequiredService<IImageProcessingService>();

            _logger.LogInformation("Starting generation pipeline for request {RequestId}", requestId);

            var request = await dbContext.ImageGenerationRequests
                .Include(r => r.Flashcard)
                .FirstOrDefaultAsync(r => r.Id == requestId);

            if (request == null)
            {
                _logger.LogError("Generation request {RequestId} not found in pipeline", requestId);
                return;
            }

            // The original request payload is stored as JSON on the request row.
            var options = JsonSerializer.Deserialize<GenerationRequest>(request.OriginalRequest);

            // Stage 1: Gemini description generation.
            _logger.LogInformation("Starting Gemini description generation for request {RequestId}", requestId);
            await _stateManager.UpdateRequestStatusAsync(requestId, "description_generating", "processing", "pending");

            var optimizedPrompt = await geminiService.GenerateImageDescriptionAsync(
                request.Flashcard,
                options?.Options ?? new GenerationOptionsDto());

            if (string.IsNullOrWhiteSpace(optimizedPrompt))
            {
                await _stateManager.MarkRequestAsFailedAsync(requestId, "gemini", "Generated prompt is empty");
                return;
            }

            await _stateManager.UpdateGeminiResultAsync(requestId, optimizedPrompt);

            // Stage 2: Replicate image generation.
            _logger.LogInformation("Starting Replicate image generation for request {RequestId}", requestId);
            await _stateManager.UpdateRequestStatusAsync(requestId, "image_generating", "completed", "processing");

            var modelName = ReplicateModelName;
            _logger.LogInformation("Using Replicate model: {ModelName}", modelName);

            var imageResult = await replicateService.GenerateImageAsync(
                optimizedPrompt,
                modelName,
                new ReplicateGenerationOptions
                {
                    Width = options?.Width ?? DefaultImageSize,
                    Height = options?.Height ?? DefaultImageSize,
                    TimeoutMinutes = ReplicateTimeoutMinutes
                });

            if (!imageResult.Success)
            {
                await _stateManager.MarkRequestAsFailedAsync(requestId, "replicate", imageResult.Error);
                return;
            }

            // Download and save the image.
            var savedImage = await _imageSaveManager.SaveGeneratedImageAsync(
                dbContext, storageService, imageProcessingService, request, optimizedPrompt, imageResult);

            // Complete the request. The elapsed time is captured once so the persisted
            // duration and the logged duration are the same value.
            totalStopwatch.Stop();
            var elapsedMs = totalStopwatch.ElapsedMilliseconds;

            await _stateManager.CompleteRequestAsync(requestId, savedImage.Id, elapsedMs);

            _logger.LogInformation("Generation pipeline completed successfully for request {RequestId} in {ElapsedMs}ms",
                requestId, elapsedMs);
        }
        catch (Exception ex)
        {
            totalStopwatch.Stop();
            _logger.LogError(ex, "Generation pipeline failed for request {RequestId}", requestId);

            try
            {
                await _stateManager.MarkRequestAsFailedAsync(requestId, "system", ex.Message);
            }
            catch (Exception markFailedEx)
            {
                // Recording the failure can itself fail (for example when the state store is
                // the component that is down). Log it rather than letting it escape.
                _logger.LogError(markFailedEx,
                    "Could not mark request {RequestId} as failed after pipeline error", requestId);
            }
        }
    }
}