dramaling-vocab-learning/backend/DramaLing.Api/Services/ReplicateService.cs

249 lines
9.1 KiB
C#

using DramaLing.Api.Models.Configuration;
using DramaLing.Api.Models.DTOs;
using Microsoft.Extensions.Options;
using System.Diagnostics;
using System.Text;
using System.Text.Json;
namespace DramaLing.Api.Services;
/// <summary>
/// Abstraction over the Replicate prediction API used for image generation.
/// </summary>
public interface IReplicateService
{
/// <summary>
/// Generates an image for <paramref name="prompt"/> with the named <paramref name="model"/>.
/// The <c>ReplicateService</c> implementation reports failures through
/// <see cref="ReplicateImageResult.Success"/> / <see cref="ReplicateImageResult.Error"/> instead of throwing.
/// </summary>
/// <param name="prompt">Text prompt sent to the model.</param>
/// <param name="model">Model key; must be present in the configured model table.</param>
/// <param name="options">Per-call overrides (size, seed, polling timeout).</param>
Task<ReplicateImageResult> GenerateImageAsync(string prompt, string model, ReplicateGenerationOptions options);
/// <summary>
/// Fetches the current status of a Replicate prediction by its id.
/// </summary>
/// <param name="predictionId">Id returned by Replicate when the prediction was created.</param>
Task<ReplicatePredictionStatus> GetPredictionStatusAsync(string predictionId);
}
/// <summary>
/// Replicate-backed image generator. Starts a prediction for a configured model, then polls the
/// prediction endpoint until it reaches a terminal status or the caller's timeout elapses.
/// </summary>
public class ReplicateService : IReplicateService
{
    private readonly HttpClient _httpClient;
    private readonly ILogger<ReplicateService> _logger;
    private readonly ReplicateOptions _options;

    /// <summary>
    /// Configures <paramref name="httpClient"/> (timeout, Authorization and User-Agent headers)
    /// from the bound <see cref="ReplicateOptions"/>.
    /// </summary>
    /// <exception cref="ArgumentNullException">An argument, or <c>options.Value</c>, is null.</exception>
    public ReplicateService(HttpClient httpClient, IOptions<ReplicateOptions> options, ILogger<ReplicateService> logger)
    {
        _httpClient = httpClient ?? throw new ArgumentNullException(nameof(httpClient));
        _logger = logger ?? throw new ArgumentNullException(nameof(logger));
        // Null-conditional so a null IOptions surfaces as ArgumentNullException, not NullReferenceException.
        _options = options?.Value ?? throw new ArgumentNullException(nameof(options));

        _logger.LogInformation("ReplicateService initialized with default model: {Model}, timeout: {Timeout}s",
            _options.DefaultModel, _options.TimeoutSeconds);

        // NOTE(review): this mutates the injected client; assumes a per-service (typed) HttpClient,
        // otherwise the headers below would be added repeatedly to a shared instance.
        _httpClient.Timeout = TimeSpan.FromSeconds(_options.TimeoutSeconds);
        _httpClient.DefaultRequestHeaders.Add("Authorization", $"Token {_options.ApiKey}");
        _httpClient.DefaultRequestHeaders.Add("User-Agent", "DramaLing/1.0");
    }

    /// <summary>
    /// Starts a prediction and waits for it to finish.
    /// </summary>
    /// <param name="prompt">Text prompt sent to the model.</param>
    /// <param name="model">Model key; must exist in <c>_options.Models</c>.</param>
    /// <param name="options">Per-call overrides; <c>null</c> is treated as all defaults.</param>
    /// <returns>
    /// A result whose <see cref="ReplicateImageResult.ProcessingTimeMs"/> is always populated.
    /// Any exception is caught and reported as <c>Success = false</c> with the exception message.
    /// </returns>
    public async Task<ReplicateImageResult> GenerateImageAsync(string prompt, string model, ReplicateGenerationOptions options)
    {
        var stopwatch = Stopwatch.StartNew();
        options ??= new ReplicateGenerationOptions();
        try
        {
            _logger.LogInformation("Starting Replicate image generation with model {Model}", model);

            // Create the Replicate prediction.
            var prediction = await StartPredictionAsync(prompt, model, options);

            // Poll until the prediction reaches a terminal state or the timeout elapses.
            var result = await WaitForCompletionAsync(prediction.Id, model, options.TimeoutMinutes);

            stopwatch.Stop();
            result.ProcessingTimeMs = (int)stopwatch.ElapsedMilliseconds;

            if (result.Success)
            {
                _logger.LogInformation("Replicate image generation completed in {ElapsedMs}ms", stopwatch.ElapsedMilliseconds);
            }
            else
            {
                _logger.LogWarning("Replicate image generation did not succeed after {ElapsedMs}ms: {Error}",
                    stopwatch.ElapsedMilliseconds, result.Error);
            }

            return result;
        }
        catch (Exception ex)
        {
            stopwatch.Stop();
            _logger.LogError(ex, "Replicate image generation failed");
            return new ReplicateImageResult
            {
                Success = false,
                Error = ex.Message,
                ProcessingTimeMs = (int)stopwatch.ElapsedMilliseconds
            };
        }
    }

    /// <summary>
    /// GETs <c>{BaseUrl}/predictions/{id}</c> and maps the response to a <see cref="ReplicatePredictionStatus"/>.
    /// A missing or unparseable status is reported as <c>"unknown"</c>.
    /// </summary>
    /// <exception cref="HttpRequestException">Non-success HTTP status (logged, then rethrown).</exception>
    public async Task<ReplicatePredictionStatus> GetPredictionStatusAsync(string predictionId)
    {
        try
        {
            // Escape the id so a caller-supplied value cannot change the request path.
            using var response = await _httpClient.GetAsync(
                $"{_options.BaseUrl}/predictions/{Uri.EscapeDataString(predictionId)}");
            response.EnsureSuccessStatusCode();

            var json = await response.Content.ReadAsStringAsync();
            var prediction = JsonSerializer.Deserialize<ReplicatePrediction>(json);

            return new ReplicatePredictionStatus
            {
                Status = prediction?.Status ?? "unknown",
                Output = prediction?.Output,
                Error = prediction?.Error,
                Version = prediction?.Version,
                Metrics = prediction?.Metrics,
                CompletedAt = prediction?.CompletedAt
            };
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to get prediction status for {PredictionId}", predictionId);
            throw;
        }
    }

    // POSTs the model-specific request body and returns the created prediction (guaranteed to carry an id).
    private async Task<ReplicatePrediction> StartPredictionAsync(string prompt, string model, ReplicateGenerationOptions options)
    {
        var requestBody = BuildModelRequest(prompt, model, options);
        var apiUrl = GetModelApiUrl(model);
        var json = JsonSerializer.Serialize(requestBody);
        using var content = new StringContent(json, Encoding.UTF8, "application/json");

        _logger.LogDebug("Replicate API request to {ApiUrl}", apiUrl);
        using var response = await _httpClient.PostAsync(apiUrl, content);
        var responseJson = await response.Content.ReadAsStringAsync();

        if (!response.IsSuccessStatusCode)
        {
            // EnsureSuccessStatusCode would drop Replicate's error body (e.g. input validation details).
            throw new HttpRequestException(
                $"Replicate prediction request failed with status {(int)response.StatusCode} ({response.ReasonPhrase}): {responseJson}",
                null,
                response.StatusCode);
        }

        var prediction = JsonSerializer.Deserialize<ReplicatePrediction>(responseJson);
        // Without an id there is nothing to poll; polling "/predictions/" would hit the list endpoint.
        if (prediction == null || string.IsNullOrEmpty(prediction.Id))
        {
            throw new InvalidOperationException("Failed to parse Replicate prediction response");
        }

        return prediction;
    }

    // Maps a model key to its create-prediction endpoint. Uses BaseUrl for every model, consistent with
    // the status polling URL (which already assumes BaseUrl is the ".../v1" API root).
    private string GetModelApiUrl(string model) =>
        model.ToLowerInvariant() switch
        {
            "ideogram-v2a-turbo" => $"{_options.BaseUrl}/models/ideogram-ai/ideogram-v2a-turbo/predictions",
            _ => $"{_options.BaseUrl}/predictions"
        };

    // Builds the anonymous JSON request body for a supported model from its configuration plus per-call overrides.
    // Throws ArgumentException when the model is not configured, NotSupportedException when it has no builder.
    private object BuildModelRequest(string prompt, string model, ReplicateGenerationOptions options)
    {
        if (!_options.Models.TryGetValue(model, out var modelConfig))
        {
            throw new ArgumentException($"Model {model} is not configured");
        }

        return model.ToLowerInvariant() switch
        {
            "ideogram-v2a-turbo" => new
            {
                input = new
                {
                    prompt = prompt,
                    width = options.Width ?? modelConfig.DefaultWidth,
                    height = options.Height ?? modelConfig.DefaultHeight,
                    magic_prompt_option = "Auto",
                    style_type = modelConfig.StyleType ?? "General",
                    aspect_ratio = modelConfig.AspectRatio ?? "ASPECT_1_1",
                    model = modelConfig.Model ?? "V_2_TURBO",
                    seed = options.Seed ?? Random.Shared.Next()
                }
            },
            "flux-1-dev" => new
            {
                input = new
                {
                    prompt = prompt,
                    width = modelConfig.DefaultWidth,
                    height = modelConfig.DefaultHeight,
                    num_outputs = 1,
                    guidance_scale = 3.5,
                    num_inference_steps = 28,
                    seed = options.Seed ?? Random.Shared.Next()
                }
            },
            _ => throw new NotSupportedException($"Model {model} not supported")
        };
    }

    // Polls the prediction every 3 seconds until a terminal status or until timeoutMinutes elapse.
    private async Task<ReplicateImageResult> WaitForCompletionAsync(string predictionId, string model, int timeoutMinutes)
    {
        var timeout = TimeSpan.FromMinutes(timeoutMinutes);
        var pollInterval = TimeSpan.FromSeconds(3);
        // Monotonic clock: unaffected by wall-clock adjustments during a long poll.
        var elapsed = Stopwatch.StartNew();

        while (elapsed.Elapsed < timeout)
        {
            var status = await GetPredictionStatusAsync(predictionId);
            switch (status.Status.ToLowerInvariant())
            {
                case "succeeded":
                {
                    var imageUrl = status.Output?.FirstOrDefault();
                    if (string.IsNullOrEmpty(imageUrl))
                    {
                        // A "successful" prediction without output is unusable to callers expecting an image URL.
                        return new ReplicateImageResult
                        {
                            Success = false,
                            ModelVersion = status.Version,
                            Metadata = status.Metrics,
                            Error = "Generation succeeded but returned no image output"
                        };
                    }

                    return new ReplicateImageResult
                    {
                        Success = true,
                        ImageUrl = imageUrl,
                        Cost = CalculateReplicateCost(model),
                        ModelVersion = status.Version,
                        Metadata = status.Metrics
                    };
                }
                case "failed":
                    return new ReplicateImageResult
                    {
                        Success = false,
                        Error = status.Error ?? "Generation failed with unknown error"
                    };
                case "canceled":
                    // Terminal on Replicate's side; previously this fell through and polled until timeout.
                    return new ReplicateImageResult
                    {
                        Success = false,
                        Error = status.Error ?? "Generation was canceled"
                    };
                case "processing":
                case "starting":
                    _logger.LogDebug("Replicate prediction {PredictionId} still processing", predictionId);
                    await Task.Delay(pollInterval);
                    break;
                default:
                    _logger.LogWarning("Unknown prediction status: {Status}", status.Status);
                    await Task.Delay(pollInterval);
                    break;
            }
        }

        return new ReplicateImageResult
        {
            Success = false,
            Error = "Generation timeout exceeded"
        };
    }

    // Returns the configured per-generation cost for the model that actually ran, falling back to the
    // default model's configured cost, then to a hard-coded Ideogram price.
    private decimal CalculateReplicateCost(string model)
    {
        if (model is not null && _options.Models.TryGetValue(model, out var modelConfig))
        {
            return modelConfig.CostPerGeneration;
        }

        if (_options.DefaultModel is not null && _options.Models.TryGetValue(_options.DefaultModel, out var defaultConfig))
        {
            return defaultConfig.CostPerGeneration;
        }

        return 0.025m; // Fallback Ideogram cost
    }
}
// Response models for ReplicateService
/// <summary>
/// Outcome of a single Replicate image generation attempt.
/// </summary>
public class ReplicateImageResult
{
// True only when the prediction finished with status "succeeded".
public bool Success { get; set; }
// First output URL of a succeeded prediction; null otherwise.
public string? ImageUrl { get; set; }
// Configured per-generation cost; NOTE(review): currency is presumably USD — confirm against config.
public decimal Cost { get; set; }
// Wall-clock duration of the whole generation call, in milliseconds.
public int ProcessingTimeMs { get; set; }
// Replicate model version reported by the prediction.
public string? ModelVersion { get; set; }
// Failure description (Replicate error, timeout, or exception message) when Success is false.
public string? Error { get; set; }
// Raw prediction metrics as returned by Replicate.
public Dictionary<string, object>? Metadata { get; set; }
}
/// <summary>
/// Per-call overrides for <c>IReplicateService.GenerateImageAsync</c>.
/// </summary>
public class ReplicateGenerationOptions
{
// Output width override; null uses the model's configured default (currently only the ideogram request reads it).
public int? Width { get; set; }
// Output height override; null uses the model's configured default (currently only the ideogram request reads it).
public int? Height { get; set; }
// Fixed seed for reproducible output; null means a random seed is chosen per request.
public int? Seed { get; set; }
// How long to keep polling the prediction before reporting a timeout.
public int TimeoutMinutes { get; set; } = 5;
}