using DramaLing.Api.Models.Configuration;
using DramaLing.Api.Models.DTOs;
using Microsoft.Extensions.Options;
using System.Diagnostics;
using System.Net.Http.Headers;
using System.Text;
using System.Text.Json;

namespace DramaLing.Api.Services;

/// <summary>
/// Image generation through the Replicate HTTP API.
/// </summary>
public interface IReplicateService
{
    /// <summary>
    /// Creates a prediction for <paramref name="model"/> and waits for it to finish.
    /// </summary>
    Task<ReplicateImageResult> GenerateImageAsync(string prompt, string model, ReplicateGenerationOptions options);

    /// <summary>
    /// Fetches the current status of an existing prediction.
    /// </summary>
    Task<ReplicatePredictionStatus> GetPredictionStatusAsync(string predictionId);
}

/// <summary>
/// <see cref="IReplicateService"/> implementation that talks to Replicate through the injected
/// <see cref="HttpClient"/>, configured from <see cref="ReplicateOptions"/>.
/// </summary>
public class ReplicateService : IReplicateService
{
    // Used when ReplicateOptions.BaseUrl is not configured.
    private const string DefaultApiBaseUrl = "https://api.replicate.com/v1";

    // Cost reported when neither the requested model nor the default model has a configured price
    // (the original value was the Ideogram price).
    private const decimal FallbackCostPerGeneration = 0.025m;

    // Delay between status polls while a prediction is still running.
    private static readonly TimeSpan PollInterval = TimeSpan.FromSeconds(3);

    private readonly HttpClient _httpClient;
    private readonly ILogger<ReplicateService> _logger;
    private readonly ReplicateOptions _options;

    /// <summary>
    /// Configures <paramref name="httpClient"/> with the request timeout, the Replicate API token
    /// and a User-Agent header.
    /// </summary>
    /// <exception cref="ArgumentNullException">Any argument (or <c>options.Value</c>) is null.</exception>
    public ReplicateService(HttpClient httpClient, IOptions<ReplicateOptions> options, ILogger<ReplicateService> logger)
    {
        _httpClient = httpClient ?? throw new ArgumentNullException(nameof(httpClient));
        _logger = logger ?? throw new ArgumentNullException(nameof(logger));
        // options?.Value: a null IOptions used to surface as a NullReferenceException.
        _options = options?.Value ?? throw new ArgumentNullException(nameof(options));

        _logger.LogInformation("ReplicateService initialized with default model: {Model}, timeout: {Timeout}s",
            _options.DefaultModel, _options.TimeoutSeconds);

        _httpClient.Timeout = TimeSpan.FromSeconds(_options.TimeoutSeconds);
        // The typed property replaces any existing value, so configuring the same client twice
        // does not throw the way DefaultRequestHeaders.Add("Authorization", ...) would.
        _httpClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Token", _options.ApiKey);
        _httpClient.DefaultRequestHeaders.Add("User-Agent", "DramaLing/1.0");
    }

    /// <summary>
    /// API root without a trailing slash; falls back to the public Replicate v1 endpoint.
    /// </summary>
    private string ApiBaseUrl =>
        string.IsNullOrWhiteSpace(_options.BaseUrl) ? DefaultApiBaseUrl : _options.BaseUrl.TrimEnd('/');

    /// <summary>
    /// Starts a prediction and polls it until it succeeds, fails, is canceled or exceeds
    /// <see cref="ReplicateGenerationOptions.TimeoutMinutes"/>.
    /// </summary>
    /// <returns>
    /// The generation result. Any exception raised along the way (HTTP errors, unsupported model,
    /// timeouts) is caught, logged and returned as <c>Success = false</c> with the exception message.
    /// </returns>
    public async Task<ReplicateImageResult> GenerateImageAsync(string prompt, string model, ReplicateGenerationOptions options)
    {
        var stopwatch = Stopwatch.StartNew();

        try
        {
            _logger.LogInformation("Starting Replicate image generation with model {Model}", model);

            // Start the Replicate prediction.
            var prediction = await StartPredictionAsync(prompt, model, options);

            // Poll until the prediction reaches a terminal status.
            var result = await WaitForCompletionAsync(prediction.Id, model, options.TimeoutMinutes);

            stopwatch.Stop();
            result.ProcessingTimeMs = (int)stopwatch.ElapsedMilliseconds;

            if (result.Success)
            {
                _logger.LogInformation("Replicate image generation completed in {ElapsedMs}ms", stopwatch.ElapsedMilliseconds);
            }
            else
            {
                _logger.LogWarning("Replicate image generation finished without an image after {ElapsedMs}ms: {Error}",
                    stopwatch.ElapsedMilliseconds, result.Error);
            }

            return result;
        }
        catch (Exception ex)
        {
            stopwatch.Stop();
            _logger.LogError(ex, "Replicate image generation failed");

            return new ReplicateImageResult
            {
                Success = false,
                Error = ex.Message,
                ProcessingTimeMs = (int)stopwatch.ElapsedMilliseconds
            };
        }
    }

    /// <summary>
    /// GETs <c>/predictions/{id}</c> and maps the response onto a <see cref="ReplicatePredictionStatus"/>.
    /// A missing or unparseable body yields status <c>"unknown"</c>.
    /// </summary>
    /// <exception cref="HttpRequestException">The API returned a non-success status (logged, then rethrown).</exception>
    public async Task<ReplicatePredictionStatus> GetPredictionStatusAsync(string predictionId)
    {
        try
        {
            using var response = await _httpClient.GetAsync(
                $"{ApiBaseUrl}/predictions/{Uri.EscapeDataString(predictionId)}");
            var json = await response.Content.ReadAsStringAsync();
            ThrowIfUnsuccessful(response, json);

            var prediction = JsonSerializer.Deserialize<ReplicatePredictionResponse>(json);

            return new ReplicatePredictionStatus
            {
                Status = prediction?.Status ?? "unknown",
                Output = prediction?.Output,
                Error = prediction?.Error,
                Version = prediction?.Version,
                Metrics = prediction?.Metrics,
                CompletedAt = prediction?.CompletedAt
            };
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to get prediction status for {PredictionId}", predictionId);
            throw;
        }
    }

    /// <summary>
    /// POSTs the model-specific request body and returns the created prediction.
    /// </summary>
    /// <exception cref="HttpRequestException">The API rejected the request; the message includes the response body.</exception>
    /// <exception cref="InvalidOperationException">The response could not be parsed or has no prediction id.</exception>
    private async Task<ReplicatePredictionResponse> StartPredictionAsync(string prompt, string model, ReplicateGenerationOptions options)
    {
        var requestBody = BuildModelRequest(prompt, model, options);
        var apiUrl = GetModelApiUrl(model);

        using var content = new StringContent(JsonSerializer.Serialize(requestBody), Encoding.UTF8, "application/json");

        _logger.LogDebug("Replicate API request to {ApiUrl}", apiUrl);

        using var response = await _httpClient.PostAsync(apiUrl, content);
        var responseJson = await response.Content.ReadAsStringAsync();
        ThrowIfUnsuccessful(response, responseJson);

        var prediction = JsonSerializer.Deserialize<ReplicatePredictionResponse>(responseJson);
        if (prediction is null || string.IsNullOrWhiteSpace(prediction.Id))
        {
            throw new InvalidOperationException("Failed to parse Replicate prediction response");
        }

        return prediction;
    }

    /// <summary>
    /// Throws an <see cref="HttpRequestException"/> that carries the response body, which
    /// <see cref="HttpResponseMessage.EnsureSuccessStatusCode"/> would discard, when the request failed.
    /// </summary>
    private static void ThrowIfUnsuccessful(HttpResponseMessage response, string body)
    {
        if (response.IsSuccessStatusCode)
        {
            return;
        }

        throw new HttpRequestException(
            $"Replicate API request failed with status {(int)response.StatusCode} ({response.ReasonPhrase}): {body}",
            null,
            response.StatusCode);
    }

    /// <summary>
    /// Returns the prediction-creation endpoint for <paramref name="model"/>. Known official models use
    /// <c>/models/{owner}/{name}/predictions</c>, which does not require a <c>version</c> in the body;
    /// any other model falls back to the generic <c>/predictions</c> endpoint.
    /// </summary>
    private string GetModelApiUrl(string model) => model.ToLowerInvariant() switch
    {
        "ideogram-v2a-turbo" => $"{ApiBaseUrl}/models/ideogram-ai/ideogram-v2a-turbo/predictions",
        "flux-1-dev" => $"{ApiBaseUrl}/models/black-forest-labs/flux-dev/predictions",
        _ => $"{ApiBaseUrl}/predictions"
    };

    /// <summary>
    /// Builds the JSON-serializable request body for <paramref name="model"/>, using per-call
    /// overrides from <paramref name="options"/> and defaults from the model's configuration.
    /// A random seed is used when none is supplied.
    /// </summary>
    /// <exception cref="ArgumentException">The model has no entry in <c>ReplicateOptions.Models</c>.</exception>
    /// <exception cref="NotSupportedException">The model is configured but has no request builder.</exception>
    private object BuildModelRequest(string prompt, string model, ReplicateGenerationOptions options)
    {
        if (!_options.Models.TryGetValue(model, out var modelConfig))
        {
            throw new ArgumentException($"Model {model} is not configured");
        }

        return model.ToLowerInvariant() switch
        {
            "ideogram-v2a-turbo" => new
            {
                input = new
                {
                    prompt = prompt,
                    width = options.Width ?? modelConfig.DefaultWidth,
                    height = options.Height ?? modelConfig.DefaultHeight,
                    magic_prompt_option = "Auto",
                    style_type = modelConfig.StyleType ?? "General",
                    aspect_ratio = modelConfig.AspectRatio ?? "ASPECT_1_1",
                    model = modelConfig.Model ?? "V_2_TURBO",
                    seed = options.Seed ?? Random.Shared.Next()
                }
            },
            "flux-1-dev" => new
            {
                input = new
                {
                    prompt = prompt,
                    // Honor caller overrides the same way the ideogram branch does.
                    width = options.Width ?? modelConfig.DefaultWidth,
                    height = options.Height ?? modelConfig.DefaultHeight,
                    num_outputs = 1,
                    guidance_scale = 3.5,
                    num_inference_steps = 28,
                    seed = options.Seed ?? Random.Shared.Next()
                }
            },
            _ => throw new NotSupportedException($"Model {model} not supported")
        };
    }

    /// <summary>
    /// Polls the prediction every <see cref="PollInterval"/> until it reaches a terminal status
    /// ("succeeded", "failed", "canceled") or <paramref name="timeoutMinutes"/> elapses.
    /// </summary>
    private async Task<ReplicateImageResult> WaitForCompletionAsync(string predictionId, string model, int timeoutMinutes)
    {
        var timeout = TimeSpan.FromMinutes(timeoutMinutes);
        // Monotonic clock: unaffected by system clock adjustments during the wait.
        var elapsed = Stopwatch.StartNew();

        while (elapsed.Elapsed < timeout)
        {
            var status = await GetPredictionStatusAsync(predictionId);

            switch (status.Status.ToLowerInvariant())
            {
                case "succeeded":
                    return new ReplicateImageResult
                    {
                        Success = true,
                        ImageUrl = status.Output?.FirstOrDefault(),
                        Cost = CalculateReplicateCost(model),
                        ModelVersion = status.Version,
                        Metadata = status.Metrics
                    };

                case "failed":
                    return new ReplicateImageResult
                    {
                        Success = false,
                        Error = status.Error ?? "Generation failed with unknown error"
                    };

                // Terminal as well: previously fell into the default branch and polled until timeout.
                case "canceled":
                    return new ReplicateImageResult
                    {
                        Success = false,
                        Error = status.Error ?? "Generation was canceled"
                    };

                case "processing":
                case "starting":
                    _logger.LogDebug("Replicate prediction {PredictionId} still processing", predictionId);
                    break;

                default:
                    _logger.LogWarning("Unknown prediction status: {Status}", status.Status);
                    break;
            }

            await Task.Delay(PollInterval);
        }

        return new ReplicateImageResult
        {
            Success = false,
            Error = "Generation timeout exceeded"
        };
    }

    /// <summary>
    /// Returns the configured flat per-generation cost of <paramref name="model"/>. Falls back to the
    /// default model's cost, then to <see cref="FallbackCostPerGeneration"/>.
    /// </summary>
    private decimal CalculateReplicateCost(string model)
    {
        if (_options.Models.TryGetValue(model, out var modelConfig)
            || _options.Models.TryGetValue(_options.DefaultModel, out modelConfig))
        {
            return modelConfig.CostPerGeneration;
        }

        return FallbackCostPerGeneration;
    }
}

// Response models for ReplicateService

/// <summary>
/// Outcome of <see cref="IReplicateService.GenerateImageAsync"/>.
/// </summary>
public class ReplicateImageResult
{
    /// <summary>True when the prediction succeeded.</summary>
    public bool Success { get; set; }

    /// <summary>First output URL of a successful prediction.</summary>
    public string? ImageUrl { get; set; }

    /// <summary>Configured per-generation cost of the model used.</summary>
    public decimal Cost { get; set; }

    /// <summary>Wall time spent in GenerateImageAsync, in milliseconds.</summary>
    public int ProcessingTimeMs { get; set; }

    /// <summary>Model version reported by Replicate.</summary>
    public string? ModelVersion { get; set; }

    /// <summary>Failure description when <see cref="Success"/> is false.</summary>
    public string? Error { get; set; }

    /// <summary>Prediction metrics reported by Replicate.</summary>
    public Dictionary<string, object>? Metadata { get; set; }
}

/// <summary>
/// Per-call overrides for image generation.
/// </summary>
public class ReplicateGenerationOptions
{
    /// <summary>Image width; the model's configured default when null.</summary>
    public int? Width { get; set; }

    /// <summary>Image height; the model's configured default when null.</summary>
    public int? Height { get; set; }

    /// <summary>Random seed; a random value is used when null.</summary>
    public int? Seed { get; set; }

    /// <summary>Maximum time to wait for the prediction to finish.</summary>
    public int TimeoutMinutes { get; set; } = 5;
}