using DramaLing.Api.Models.Configuration;
using DramaLing.Api.Models.DTOs;
using Microsoft.Extensions.Options;
using System.Diagnostics;
using System.Text;
using System.Text.Json;

namespace DramaLing.Api.Services.AI;

// NOTE(review): in the reviewed copy of this file several generic type arguments appear to be
// missing (on Task, IOptions, ILogger, JsonSerializer.Deserialize and Dictionary). They are kept
// exactly as found here; confirm them against the repository before merging.

/// <summary>
/// Generates images through the Replicate HTTP API: it submits a prediction for a configured
/// model and then polls that prediction until it reaches a terminal state or the polling
/// budget runs out.
/// </summary>
public class ReplicateImageGenerationService : IReplicateImageGenerationService
{
    private readonly HttpClient _httpClient;
    private readonly ReplicateOptions _options;
    private readonly ILogger _logger;

    /// <summary>
    /// Creates the service and configures the supplied HTTP client with the request timeout,
    /// the <c>Authorization</c> token header and the <c>User-Agent</c> header.
    /// </summary>
    /// <param name="httpClient">HTTP client used for every Replicate call.</param>
    /// <param name="options">Options wrapper whose value holds the Replicate configuration.</param>
    /// <param name="logger">Logger for diagnostics.</param>
    /// <exception cref="ArgumentNullException">
    /// Any argument is null, or <paramref name="options"/> carries a null value.
    /// </exception>
    public ReplicateImageGenerationService(
        HttpClient httpClient,
        IOptions options,
        ILogger logger)
    {
        _httpClient = httpClient ?? throw new ArgumentNullException(nameof(httpClient));
        // Null-conditional access: a null wrapper must surface as ArgumentNullException,
        // not as a NullReferenceException raised while reading .Value.
        _options = options?.Value ?? throw new ArgumentNullException(nameof(options));
        _logger = logger ?? throw new ArgumentNullException(nameof(logger));

        _httpClient.Timeout = TimeSpan.FromSeconds(_options.TimeoutSeconds);
        _httpClient.DefaultRequestHeaders.Add("Authorization", $"Token {_options.ApiKey}");
        _httpClient.DefaultRequestHeaders.Add("User-Agent", "DramaLing/1.0");
    }

    /// <summary>
    /// Starts a prediction for <paramref name="prompt"/> on <paramref name="model"/> and waits
    /// for it to finish.
    /// </summary>
    /// <param name="prompt">Text prompt sent to the model.</param>
    /// <param name="model">Model key; it must exist in the configured model dictionary.</param>
    /// <param name="options">Generation options; <c>MaxRetries</c> scales the polling budget.</param>
    /// <returns>
    /// An <c>ImageGenerationResult</c>. Exceptions are not propagated: any failure is logged and
    /// reported as a result with <c>Success = false</c> and the exception message in <c>Error</c>.
    /// </returns>
    public async Task GenerateImageAsync(
        string prompt,
        string model,
        GenerationOptionsDto options)
    {
        var stopwatch = Stopwatch.StartNew();

        try
        {
            _logger.LogInformation("Starting Replicate image generation with model {Model}", model);

            // 1. Start the Replicate prediction.
            var prediction = await StartPredictionAsync(prompt, model, options);

            // 2. Poll until the generation finishes.
            // NOTE(review): the second argument is interpreted as MINUTES by
            // WaitForCompletionAsync, so the budget is MaxRetries * 60 minutes (MaxRetries hours),
            // and MaxRetries == 0 yields an immediate timeout without a single poll.
            // Confirm that these units are intended.
            var result = await WaitForCompletionAsync(prediction.Id, options.MaxRetries * 60);
            result.ProcessingTimeMs = (int)stopwatch.ElapsedMilliseconds;

            _logger.LogInformation("Replicate image generation completed in {ElapsedMs}ms",
                stopwatch.ElapsedMilliseconds);

            return result;
        }
        catch (Exception ex)
        {
            stopwatch.Stop();
            _logger.LogError(ex, "Replicate image generation failed");

            return new ImageGenerationResult
            {
                Success = false,
                Error = ex.Message,
                ProcessingTimeMs = (int)stopwatch.ElapsedMilliseconds
            };
        }
    }

    /// <summary>
    /// Fetches the current state of a prediction from
    /// <c>{BaseUrl}/predictions/{predictionId}</c>.
    /// </summary>
    /// <param name="predictionId">Identifier of the prediction to query.</param>
    /// <returns>
    /// A <c>ReplicatePredictionStatus</c> copied from the response; <c>Status</c> falls back to
    /// <c>"unknown"</c> when the payload deserializes to null or carries no status.
    /// </returns>
    /// <remarks>Failures are logged and then rethrown unchanged.</remarks>
    public async Task GetPredictionStatusAsync(string predictionId)
    {
        try
        {
            // Dispose the response so its content stream and connection are released promptly.
            using var response = await _httpClient.GetAsync($"{_options.BaseUrl}/predictions/{predictionId}");
            response.EnsureSuccessStatusCode();

            var json = await response.Content.ReadAsStringAsync();
            var prediction = JsonSerializer.Deserialize(json);

            return new ReplicatePredictionStatus
            {
                Status = prediction?.Status ?? "unknown",
                Output = prediction?.Output,
                Error = prediction?.Error,
                Version = prediction?.Version,
                Metrics = prediction?.Metrics,
                CompletedAt = prediction?.CompletedAt
            };
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to get prediction status for {PredictionId}", predictionId);
            throw;
        }
    }

    /// <summary>
    /// Posts a new prediction request and returns the deserialized response.
    /// </summary>
    /// <exception cref="InvalidOperationException">The response body deserialized to null.</exception>
    private async Task StartPredictionAsync(
        string prompt,
        string model,
        GenerationOptionsDto options)
    {
        var requestBody = BuildModelRequest(prompt, model, options);

        // Use the model-specific API endpoint.
        var apiUrl = GetModelApiUrl(model);

        var json = JsonSerializer.Serialize(requestBody);
        using var content = new StringContent(json, Encoding.UTF8, "application/json");

        _logger.LogDebug("Replicate API request to {ApiUrl}: {Request}", apiUrl, json);

        using var response = await _httpClient.PostAsync(apiUrl, content);
        response.EnsureSuccessStatusCode();

        var responseJson = await response.Content.ReadAsStringAsync();
        var prediction = JsonSerializer.Deserialize(responseJson);

        if (prediction == null)
        {
            throw new InvalidOperationException("Failed to parse Replicate prediction response");
        }

        return prediction;
    }

    /// <summary>
    /// Resolves the endpoint used to create a prediction for <paramref name="model"/>.
    /// Only <c>ideogram-v2a-turbo</c> has a dedicated endpoint; every other model uses
    /// <c>{BaseUrl}/predictions</c>.
    /// </summary>
    private string GetModelApiUrl(string model)
    {
        // Invariant casing: model keys are identifiers, not user-facing text, so matching
        // must not depend on the current culture.
        return model.ToLowerInvariant() switch
        {
            "ideogram-v2a-turbo" => "https://api.replicate.com/v1/models/ideogram-ai/ideogram-v2a-turbo/predictions",
            _ => $"{_options.BaseUrl}/predictions"
        };
    }

    /// <summary>
    /// Builds the request payload (an anonymous object with an <c>input</c> member) for the
    /// given model, filling model parameters from configuration.
    /// </summary>
    /// <exception cref="ArgumentException">The model has no entry in the configured models.</exception>
    /// <exception cref="NotSupportedException">The model is configured but has no payload template here.</exception>
    private object BuildModelRequest(string prompt, string model, GenerationOptionsDto options)
    {
        if (!_options.Models.TryGetValue(model, out var modelConfig))
        {
            throw new ArgumentException($"Model {model} is not configured");
        }

        // Invariant casing for the same reason as in GetModelApiUrl.
        return model.ToLowerInvariant() switch
        {
            "ideogram-v2a-turbo" => new
            {
                input = new
                {
                    prompt = prompt,
                    // NOTE(review): MaxRetries decides whether the configured size or 512 is used;
                    // the reason is not visible here — confirm this is intended.
                    width = options.MaxRetries > 0 ? modelConfig.DefaultWidth : 512,
                    height = options.MaxRetries > 0 ? modelConfig.DefaultHeight : 512,
                    magic_prompt_option = "Auto",
                    style_type = modelConfig.StyleType ?? "General",
                    aspect_ratio = modelConfig.AspectRatio ?? "ASPECT_1_1",
                    model = modelConfig.Model ?? "V_2_TURBO",
                    seed = Random.Shared.Next()
                }
            },
            "flux-1-dev" => new
            {
                input = new
                {
                    prompt = prompt,
                    width = modelConfig.DefaultWidth,
                    height = modelConfig.DefaultHeight,
                    num_outputs = 1,
                    guidance_scale = 3.5,
                    num_inference_steps = 28,
                    seed = Random.Shared.Next()
                }
            },
            "stable-diffusion-xl" => new
            {
                input = new
                {
                    prompt = prompt,
                    width = modelConfig.DefaultWidth,
                    height = modelConfig.DefaultHeight,
                    num_outputs = 1,
                    scheduler = "K_EULER_ANCESTRAL",
                    num_inference_steps = 25,
                    guidance_scale = 7.5,
                    seed = Random.Shared.Next()
                }
            },
            _ => throw new NotSupportedException($"Model {model} not supported")
        };
    }

    /// <summary>
    /// Polls the prediction every three seconds until it reaches a terminal state or
    /// <paramref name="timeoutMinutes"/> minutes have elapsed.
    /// </summary>
    /// <param name="predictionId">Identifier of the prediction to poll.</param>
    /// <param name="timeoutMinutes">Polling budget, in minutes.</param>
    /// <returns>
    /// A successful result carrying the first output entry, or a failed result describing the
    /// failure, cancellation or timeout.
    /// </returns>
    private async Task WaitForCompletionAsync(string predictionId, int timeoutMinutes)
    {
        var timeout = TimeSpan.FromMinutes(timeoutMinutes);
        var pollInterval = TimeSpan.FromSeconds(3);
        // A Stopwatch is monotonic, so the deadline is unaffected by system clock adjustments.
        var elapsed = Stopwatch.StartNew();

        while (elapsed.Elapsed < timeout)
        {
            var status = await GetPredictionStatusAsync(predictionId);

            switch (status.Status.ToLowerInvariant())
            {
                case "succeeded":
                    return new ImageGenerationResult
                    {
                        Success = true,
                        ImageUrl = status.Output?.FirstOrDefault(),
                        Cost = CalculateReplicateCost(status.Metrics),
                        ModelVersion = status.Version,
                        Metadata = status.Metrics
                    };

                case "failed":
                    return new ImageGenerationResult
                    {
                        Success = false,
                        Error = status.Error ?? "Generation failed with unknown error"
                    };

                case "canceled":
                    // Terminal state: the prediction will never complete, so stop polling
                    // instead of waiting for the whole timeout to run out.
                    return new ImageGenerationResult
                    {
                        Success = false,
                        Error = status.Error ?? "Generation was canceled"
                    };

                case "processing":
                case "starting":
                    _logger.LogDebug("Replicate prediction {PredictionId} still processing", predictionId);
                    await Task.Delay(pollInterval);
                    break;

                default:
                    _logger.LogWarning("Unknown prediction status: {Status}", status.Status);
                    await Task.Delay(pollInterval);
                    break;
            }
        }

        return new ImageGenerationResult
        {
            Success = false,
            Error = "Generation timeout exceeded"
        };
    }

    /// <summary>
    /// Returns the cost charged for one generation.
    /// </summary>
    /// <param name="metrics">Prediction metrics; currently not used by the calculation.</param>
    /// <returns>
    /// The configured per-generation cost of the default model, or 0.025 when the default
    /// model has no configuration entry.
    /// </returns>
    private decimal CalculateReplicateCost(Dictionary? metrics)
    {
        // Take the default cost from configuration; a real deployment could compute it
        // precisely from the metrics.
        // NOTE(review): the lookup uses DefaultModel rather than the model that actually ran.
        if (_options.Models.TryGetValue(_options.DefaultModel, out var modelConfig))
        {
            return modelConfig.CostPerGeneration;
        }

        return 0.025m; // Default Ideogram cost.
    }
}