116 lines
4.4 KiB
C#
116 lines
4.4 KiB
C#
using DramaLing.Api.Data;
|
|
using Microsoft.EntityFrameworkCore;
|
|
|
|
namespace DramaLing.Api.Services.AI.Generation;
|
|
|
|
/// <summary>
/// Records status changes, results and failures of image generation requests in the database.
/// </summary>
/// <remarks>
/// Every operation creates its own dependency-injection scope and resolves a
/// <see cref="DramaLingDbContext"/> from it, so no context instance is kept between calls.
/// When no request with the given id exists, the operation logs a warning and changes nothing.
/// </remarks>
public class GenerationStateManager : IGenerationStateManager
{
    // Status values written to the request's status columns.
    private const string StatusProcessing = "processing";
    private const string StatusCompleted = "completed";
    private const string StatusFailed = "failed";

    // Stage identifiers accepted by MarkRequestAsFailedAsync (matched case-insensitively).
    private const string StageGemini = "gemini";
    private const string StageReplicate = "replicate";

    // UpdateGeminiResultAsync receives only the optimized prompt, so these fixed values
    // are stored for every request.
    // NOTE(review): these look like placeholders rather than measured values - confirm.
    private const string GeminiDescriptionPlaceholder = "Gemini generated description";
    private const decimal GeminiCostPlaceholder = 0.002m;
    private const int GeminiProcessingTimeMsPlaceholder = 30000;

    private readonly IServiceProvider _serviceProvider;
    private readonly ILogger<GenerationStateManager> _logger;

    /// <summary>
    /// Creates the manager.
    /// </summary>
    /// <param name="serviceProvider">Provider used to create a scope for each operation.</param>
    /// <param name="logger">Logger for failures and missing requests.</param>
    /// <exception cref="ArgumentNullException">Thrown when either argument is null.</exception>
    public GenerationStateManager(
        IServiceProvider serviceProvider,
        ILogger<GenerationStateManager> logger)
    {
        _serviceProvider = serviceProvider ?? throw new ArgumentNullException(nameof(serviceProvider));
        _logger = logger ?? throw new ArgumentNullException(nameof(logger));
    }

    /// <summary>
    /// Stores the overall, Gemini and Replicate status of a request.
    /// </summary>
    /// <param name="requestId">Id of the request to update.</param>
    /// <param name="overallStatus">New overall status.</param>
    /// <param name="geminiStatus">New Gemini stage status.</param>
    /// <param name="replicateStatus">New Replicate stage status.</param>
    /// <remarks>
    /// The first time a stage is reported as "processing", its start time is set to the
    /// current UTC time; an already recorded start time is never overwritten.
    /// </remarks>
    public async Task UpdateRequestStatusAsync(Guid requestId, string overallStatus, string geminiStatus, string replicateStatus)
    {
        using var scope = _serviceProvider.CreateScope();
        var dbContext = scope.ServiceProvider.GetRequiredService<DramaLingDbContext>();

        var request = await dbContext.ImageGenerationRequests.FindAsync(requestId);
        if (request is null)
        {
            LogRequestNotFound(requestId, nameof(UpdateRequestStatusAsync));
            return;
        }

        request.OverallStatus = overallStatus;
        request.GeminiStatus = geminiStatus;
        request.ReplicateStatus = replicateStatus;

        // One clock read so stages that start in the same update share a timestamp.
        var now = DateTime.UtcNow;

        if (geminiStatus == StatusProcessing && request.GeminiStartedAt == null)
        {
            request.GeminiStartedAt = now;
        }

        if (replicateStatus == StatusProcessing && request.ReplicateStartedAt == null)
        {
            request.ReplicateStartedAt = now;
        }

        await dbContext.SaveChangesAsync();
    }

    /// <summary>
    /// Marks the Gemini stage of a request as completed and stores the prompt it produced.
    /// </summary>
    /// <param name="requestId">Id of the request to update.</param>
    /// <param name="optimizedPrompt">Prompt stored as the final Replicate prompt.</param>
    /// <remarks>
    /// Description, cost and processing time are written from fixed constants because this
    /// method receives no values for them.
    /// </remarks>
    public async Task UpdateGeminiResultAsync(Guid requestId, string optimizedPrompt)
    {
        using var scope = _serviceProvider.CreateScope();
        var dbContext = scope.ServiceProvider.GetRequiredService<DramaLingDbContext>();

        var request = await dbContext.ImageGenerationRequests.FindAsync(requestId);
        if (request is null)
        {
            LogRequestNotFound(requestId, nameof(UpdateGeminiResultAsync));
            return;
        }

        request.GeminiStatus = StatusCompleted;
        request.GeminiCompletedAt = DateTime.UtcNow;
        request.GeneratedDescription = GeminiDescriptionPlaceholder;
        request.FinalReplicatePrompt = optimizedPrompt;
        request.GeminiCost = GeminiCostPlaceholder;
        request.GeminiProcessingTimeMs = GeminiProcessingTimeMsPlaceholder;

        await dbContext.SaveChangesAsync();
    }

    /// <summary>
    /// Marks a request and its Replicate stage as completed and links the generated image.
    /// </summary>
    /// <param name="requestId">Id of the request to update.</param>
    /// <param name="imageId">Id of the generated image.</param>
    /// <param name="totalProcessingTimeMs">
    /// Total processing time in milliseconds; values outside the <see cref="int"/> range
    /// are saturated to the nearest bound instead of wrapping around.
    /// </param>
    /// <remarks>
    /// The total cost is the sum of the Gemini and Replicate costs, treating a missing cost as 0.
    /// </remarks>
    public async Task CompleteRequestAsync(Guid requestId, Guid imageId, long totalProcessingTimeMs)
    {
        using var scope = _serviceProvider.CreateScope();
        var dbContext = scope.ServiceProvider.GetRequiredService<DramaLingDbContext>();

        var request = await dbContext.ImageGenerationRequests.FindAsync(requestId);
        if (request is null)
        {
            LogRequestNotFound(requestId, nameof(CompleteRequestAsync));
            return;
        }

        // One clock read so both completion timestamps are identical.
        var now = DateTime.UtcNow;

        request.OverallStatus = StatusCompleted;
        request.ReplicateStatus = StatusCompleted;
        request.GeneratedImageId = imageId;
        request.CompletedAt = now;
        request.ReplicateCompletedAt = now;

        // A plain (int) cast wraps silently on overflow; clamp first so an oversized
        // duration is stored as the nearest representable value.
        request.TotalProcessingTimeMs = (int)Math.Clamp(totalProcessingTimeMs, int.MinValue, int.MaxValue);
        request.TotalCost = (request.GeminiCost ?? 0) + (request.ReplicateCost ?? 0);

        await dbContext.SaveChangesAsync();
    }

    /// <summary>
    /// Marks a request as failed and records the error against the stage that failed.
    /// </summary>
    /// <param name="requestId">Id of the request to update.</param>
    /// <param name="stage">
    /// Failing stage: "gemini" or "replicate", compared case-insensitively. Any other value
    /// (including null) stores the error message on both stages without changing their status.
    /// </param>
    /// <param name="errorMessage">Error text to store; may be null.</param>
    public async Task MarkRequestAsFailedAsync(Guid requestId, string stage, string? errorMessage)
    {
        using var scope = _serviceProvider.CreateScope();
        var dbContext = scope.ServiceProvider.GetRequiredService<DramaLingDbContext>();

        var request = await dbContext.ImageGenerationRequests.FindAsync(requestId);
        if (request is null)
        {
            // Keep the failure details in the log even though there is no row to store them on.
            _logger.LogWarning(
                "Generation request {RequestId} not found; could not record failure at stage {Stage}: {Error}",
                requestId, stage, errorMessage);
            return;
        }

        // One clock read so the stage and overall completion timestamps are identical.
        var now = DateTime.UtcNow;

        request.OverallStatus = StatusFailed;

        // Ordinal, case-insensitive comparison: stage names are identifiers, so matching
        // must not depend on the current culture, and a null stage must not throw.
        if (string.Equals(stage, StageGemini, StringComparison.OrdinalIgnoreCase))
        {
            request.GeminiStatus = StatusFailed;
            request.GeminiErrorMessage = errorMessage;
            request.GeminiCompletedAt = now;
        }
        else if (string.Equals(stage, StageReplicate, StringComparison.OrdinalIgnoreCase))
        {
            request.ReplicateStatus = StatusFailed;
            request.ReplicateErrorMessage = errorMessage;
            request.ReplicateCompletedAt = now;
        }
        else
        {
            // Unknown stage: attach the message to both stages, leave their statuses as they are.
            request.GeminiErrorMessage = errorMessage;
            request.ReplicateErrorMessage = errorMessage;
        }

        request.CompletedAt = now;

        await dbContext.SaveChangesAsync();

        _logger.LogError("Generation request {RequestId} marked as failed at stage {Stage}: {Error}",
            requestId, stage, errorMessage);
    }

    /// <summary>
    /// Logs that an operation was skipped because its request does not exist.
    /// </summary>
    private void LogRequestNotFound(Guid requestId, string operation)
    {
        _logger.LogWarning(
            "Generation request {RequestId} not found; {Operation} skipped",
            requestId, operation);
    }
}