// Listing metadata: 179 lines, 6.9 KiB, C#
using DramaLing.Api.Data;
|
|
using DramaLing.Api.Models.DTOs;
|
|
using DramaLing.Api.Models.Entities;
|
|
using DramaLing.Api.Services.Storage;
|
|
using Microsoft.EntityFrameworkCore;
|
|
using System.Text.Json;
|
|
|
|
namespace DramaLing.Api.Services.AI.Generation;
|
|
|
|
/// <summary>
/// Entry point for flashcard image generation requests: records a new request and
/// hands it to the generation pipeline, reports the stored progress of a request,
/// and marks a request as cancelled.
/// </summary>
/// <remarks>
/// Every public method creates its own dependency-injection scope and resolves a
/// fresh <see cref="DramaLingDbContext"/> from it, so no database context is shared
/// between calls.
/// </remarks>
public class ImageGenerationWorkflow : IImageGenerationWorkflow
{
    // Status / stage values persisted to the database and returned to callers.
    private const string PendingStatus = "pending";
    private const string CompletedStatus = "completed";
    private const string CancelledStatus = "cancelled";
    private const string InitialStage = "description_generation";

    // Fixed per-stage estimates reported when a request is accepted.
    // Totals are derived from these so they cannot drift out of sync.
    private const double GeminiEstimatedMinutes = 0.5;
    private const double ReplicateEstimatedMinutes = 2.0;
    private const decimal GeminiEstimatedCost = 0.002m;
    private const decimal ReplicateEstimatedCost = 0.025m;

    // Fallback dimensions reported when the stored image has no width/height.
    private const int DefaultImageSize = 512;

    private readonly IServiceProvider _serviceProvider;
    private readonly IGenerationPipelineService _pipelineService;
    private readonly ILogger<ImageGenerationWorkflow> _logger;

    /// <summary>
    /// Creates the workflow.
    /// </summary>
    /// <param name="serviceProvider">Provider used to create a scope per operation.</param>
    /// <param name="pipelineService">Service that executes the generation pipeline for a request id.</param>
    /// <param name="logger">Logger for workflow events and failures.</param>
    /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
    public ImageGenerationWorkflow(
        IServiceProvider serviceProvider,
        IGenerationPipelineService pipelineService,
        ILogger<ImageGenerationWorkflow> logger)
    {
        _serviceProvider = serviceProvider ?? throw new ArgumentNullException(nameof(serviceProvider));
        _pipelineService = pipelineService ?? throw new ArgumentNullException(nameof(pipelineService));
        _logger = logger ?? throw new ArgumentNullException(nameof(logger));
    }

    /// <summary>
    /// Persists a new pending generation request for the given flashcard, starts the
    /// generation pipeline in the background, and returns immediately with the new
    /// request id plus fixed time and cost estimates.
    /// </summary>
    /// <param name="flashcardId">Key of the flashcard to generate an image for.</param>
    /// <param name="request">Caller-supplied generation parameters; serialized to JSON and stored with the request.</param>
    /// <returns>The id of the created request with its initial status, stage and estimates.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="request"/> is <c>null</c>.</exception>
    /// <exception cref="ArgumentException">No flashcard with <paramref name="flashcardId"/> exists.</exception>
    public async Task<GenerationRequestResult> StartGenerationAsync(Guid flashcardId, GenerationRequest request)
    {
        // Fail fast: without this guard a null request only surfaced as a
        // NullReferenceException after a scope and a database lookup had been spent.
        if (request is null)
        {
            throw new ArgumentNullException(nameof(request));
        }

        using var scope = _serviceProvider.CreateScope();
        var dbContext = scope.ServiceProvider.GetRequiredService<DramaLingDbContext>();

        try
        {
            // Make sure the flashcard exists.
            var flashcard = await dbContext.Flashcards.FindAsync(flashcardId);
            if (flashcard == null)
            {
                throw new ArgumentException($"Flashcard {flashcardId} not found");
            }

            // Create the generation request record.
            var generationRequest = new ImageGenerationRequest
            {
                Id = Guid.NewGuid(),
                UserId = request.UserId,
                FlashcardId = flashcardId,
                OverallStatus = PendingStatus,
                GeminiStatus = PendingStatus,
                ReplicateStatus = PendingStatus,
                OriginalRequest = JsonSerializer.Serialize(request),
                CreatedAt = DateTime.UtcNow
            };

            dbContext.ImageGenerationRequests.Add(generationRequest);
            await dbContext.SaveChangesAsync();

            // Copy the id into a local so the background closure does not keep the
            // tracked entity (and its soon-to-be-disposed context) alive.
            var requestId = generationRequest.Id;

            _logger.LogInformation("Created generation request {RequestId} for flashcard {FlashcardId}",
                requestId, flashcardId);

            // Run the generation pipeline in the background; the caller is not made
            // to wait for it and observes progress through GetGenerationStatusAsync.
            // NOTE(review): this fire-and-forget task is not tied to the host
            // lifetime — confirm the pipeline service outlives the calling request.
            _ = Task.Run(() => RunPipelineInBackgroundAsync(requestId));

            return new GenerationRequestResult
            {
                RequestId = requestId,
                OverallStatus = PendingStatus,
                CurrentStage = InitialStage,
                EstimatedTimeMinutes = new EstimatedTimeDto
                {
                    Gemini = GeminiEstimatedMinutes,
                    Replicate = ReplicateEstimatedMinutes,
                    Total = GeminiEstimatedMinutes + ReplicateEstimatedMinutes
                },
                CostEstimate = new CostEstimateDto
                {
                    Gemini = GeminiEstimatedCost,
                    Replicate = ReplicateEstimatedCost,
                    Total = GeminiEstimatedCost + ReplicateEstimatedCost
                }
            };
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to start generation for flashcard {FlashcardId}", flashcardId);
            throw;
        }
    }

    /// <summary>
    /// Loads a generation request together with its generated image (if any) and maps
    /// the stored per-stage status, timing and cost fields into a status response.
    /// </summary>
    /// <param name="requestId">Id of the generation request to look up.</param>
    /// <returns>
    /// The stored status of the request; <c>Result</c> is populated only when a
    /// generated image is attached to the request, otherwise it is <c>null</c>.
    /// </returns>
    /// <exception cref="ArgumentException">No request with <paramref name="requestId"/> exists.</exception>
    public async Task<GenerationStatusResponse> GetGenerationStatusAsync(Guid requestId)
    {
        using var scope = _serviceProvider.CreateScope();
        var dbContext = scope.ServiceProvider.GetRequiredService<DramaLingDbContext>();
        var storageService = scope.ServiceProvider.GetRequiredService<IImageStorageService>();

        // Read-only projection: nothing is modified or saved, so change tracking is skipped.
        var request = await dbContext.ImageGenerationRequests
            .AsNoTracking()
            .Include(r => r.GeneratedImage)
            .FirstOrDefaultAsync(r => r.Id == requestId);

        if (request == null)
        {
            throw new ArgumentException($"Generation request {requestId} not found");
        }

        GenerationResultDto? result = null;
        var image = request.GeneratedImage;
        if (image != null)
        {
            result = new GenerationResultDto
            {
                ImageUrl = await storageService.GetImageUrlAsync(image.RelativePath),
                ImageId = image.Id.ToString(),
                QualityScore = image.QualityScore,
                Dimensions = new DimensionsDto
                {
                    Width = image.ImageWidth ?? DefaultImageSize,
                    Height = image.ImageHeight ?? DefaultImageSize
                },
                FileSize = image.FileSize
            };
        }

        return new GenerationStatusResponse
        {
            RequestId = request.Id,
            OverallStatus = request.OverallStatus,
            Stages = new StageStatusDto
            {
                Gemini = new GeminiStageDto
                {
                    Status = request.GeminiStatus,
                    StartedAt = request.GeminiStartedAt,
                    CompletedAt = request.GeminiCompletedAt,
                    ProcessingTimeMs = request.GeminiProcessingTimeMs,
                    Cost = request.GeminiCost,
                    GeneratedDescription = request.GeneratedDescription
                },
                Replicate = new ReplicateStageDto
                {
                    Status = request.ReplicateStatus,
                    StartedAt = request.ReplicateStartedAt,
                    CompletedAt = request.ReplicateCompletedAt,
                    ProcessingTimeMs = request.ReplicateProcessingTimeMs,
                    Cost = request.ReplicateCost
                }
            },
            TotalCost = request.TotalCost,
            CompletedAt = request.CompletedAt,
            Result = result
        };
    }

    /// <summary>
    /// Marks a generation request as cancelled by setting its overall status to
    /// "cancelled" and saving the change.
    /// </summary>
    /// <param name="requestId">Id of the generation request to cancel.</param>
    /// <returns>
    /// <c>true</c> when the status was updated and saved; <c>false</c> when the request
    /// does not exist, is already completed, or the update failed (the failure is logged,
    /// not rethrown).
    /// </returns>
    public async Task<bool> CancelGenerationAsync(Guid requestId)
    {
        using var scope = _serviceProvider.CreateScope();
        var dbContext = scope.ServiceProvider.GetRequiredService<DramaLingDbContext>();

        try
        {
            var request = await dbContext.ImageGenerationRequests.FindAsync(requestId);
            if (request == null || request.OverallStatus == CompletedStatus)
            {
                return false;
            }

            // NOTE(review): only the stored status changes here; nothing in this class
            // signals a pipeline that is already running — presumably the pipeline
            // checks this status itself; confirm.
            request.OverallStatus = CancelledStatus;
            await dbContext.SaveChangesAsync();

            _logger.LogInformation("Generation request {RequestId} cancelled", requestId);
            return true;
        }
        catch (Exception ex)
        {
            // Deliberate best-effort: callers get false instead of an exception.
            _logger.LogError(ex, "Failed to cancel generation request {RequestId}", requestId);
            return false;
        }
    }

    /// <summary>
    /// Executes the generation pipeline for a request and logs any failure, since no
    /// caller awaits this task and an escaping exception would go unobserved.
    /// </summary>
    /// <param name="requestId">Id of the persisted generation request to process.</param>
    private async Task RunPipelineInBackgroundAsync(Guid requestId)
    {
        try
        {
            await _pipelineService.ExecuteGenerationPipelineAsync(requestId);
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Background generation pipeline failed for request {RequestId}", requestId);
        }
    }
}