using DramaLing.Api.Models.Configuration;
using DramaLing.Api.Models.DTOs;
using Microsoft.Extensions.Options;
using System.Diagnostics;
using System.Text;
using System.Text.Json;

namespace DramaLing.Api.Services;

/// <summary>
/// Image generation through the Replicate HTTP API.
/// </summary>
public interface IReplicateService
{
    /// <summary>
    /// Creates a Replicate prediction for <paramref name="model"/> and polls it until it reaches a
    /// terminal state or <see cref="ReplicateGenerationOptions.TimeoutMinutes"/> elapses.
    /// Failures are reported through <see cref="ReplicateImageResult.Success"/> and
    /// <see cref="ReplicateImageResult.Error"/> rather than thrown.
    /// </summary>
    Task<ReplicateImageResult> GenerateImageAsync(string prompt, string model, ReplicateGenerationOptions options);

    /// <summary>
    /// Fetches the current state of an existing prediction.
    /// Any failure (non-success HTTP status, malformed JSON, ...) is logged and rethrown.
    /// </summary>
    /// <exception cref="HttpRequestException">The API returned a non-success status code.</exception>
    Task<ReplicatePredictionStatus> GetPredictionStatusAsync(string predictionId);
}

public class ReplicateService : IReplicateService
{
    // Cost used when neither the requested model nor the default model has a configured cost
    // (originally documented as the default Ideogram cost).
    private const decimal DefaultGenerationCost = 0.025m;

    // Output size used for flux-1-dev when the caller does not specify Width/Height.
    private const int DefaultFluxSize = 512;

    // Maximum number of characters of a raw API payload written to logs / exception messages.
    private const int MaxLoggedResponseChars = 500;

    private static readonly TimeSpan PollInterval = TimeSpan.FromSeconds(3);

    private readonly HttpClient _httpClient;
    private readonly ILogger<ReplicateService> _logger;
    private readonly ReplicateOptions _options;

    /// <summary>
    /// Configures <paramref name="httpClient"/> (timeout and default headers) from <see cref="ReplicateOptions"/>.
    /// </summary>
    /// <exception cref="ArgumentNullException">Any argument, or <c>options.Value</c>, is null.</exception>
    public ReplicateService(HttpClient httpClient, IOptions<ReplicateOptions> options, ILogger<ReplicateService> logger)
    {
        _httpClient = httpClient ?? throw new ArgumentNullException(nameof(httpClient));
        _logger = logger ?? throw new ArgumentNullException(nameof(logger));
        // Guard the wrapper itself; reading options.Value on a null wrapper would throw NullReferenceException.
        ArgumentNullException.ThrowIfNull(options);
        _options = options.Value ?? throw new ArgumentNullException(nameof(options));

        _logger.LogInformation("ReplicateService initialized with default model: {Model}, timeout: {Timeout}s",
            _options.DefaultModel, _options.TimeoutSeconds);

        _httpClient.Timeout = TimeSpan.FromSeconds(_options.TimeoutSeconds);
        _httpClient.DefaultRequestHeaders.Add("Authorization", $"Token {_options.ApiKey}");
        _httpClient.DefaultRequestHeaders.Add("User-Agent", "DramaLing/1.0");
        // NOTE(review): presumably enables Replicate's synchronous "wait" mode on the create call;
        // the polling loop below still handles predictions that are not yet finished.
        _httpClient.DefaultRequestHeaders.Add("Prefer", "wait");
    }

    /// <inheritdoc />
    public async Task<ReplicateImageResult> GenerateImageAsync(string prompt, string model, ReplicateGenerationOptions options)
    {
        var stopwatch = Stopwatch.StartNew();
        try
        {
            // Inside the try so invalid arguments surface as a failed result, like every other error here.
            ArgumentNullException.ThrowIfNull(model);
            ArgumentNullException.ThrowIfNull(options);

            _logger.LogInformation("Starting Replicate image generation with model {Model}", model);

            // Create the Replicate prediction.
            var prediction = await StartPredictionAsync(prompt, model, options);

            // Poll until the prediction reaches a terminal state or the timeout elapses.
            var result = await WaitForCompletionAsync(prediction.Id, model, options.TimeoutMinutes);

            stopwatch.Stop();
            result.ProcessingTimeMs = (int)stopwatch.ElapsedMilliseconds;

            if (result.Success)
            {
                _logger.LogInformation("Replicate image generation completed in {ElapsedMs}ms", stopwatch.ElapsedMilliseconds);
            }
            else
            {
                _logger.LogWarning("Replicate image generation finished unsuccessfully in {ElapsedMs}ms: {Error}",
                    stopwatch.ElapsedMilliseconds, result.Error);
            }

            return result;
        }
        catch (Exception ex)
        {
            stopwatch.Stop();
            _logger.LogError(ex, "Replicate image generation failed");
            return new ReplicateImageResult
            {
                Success = false,
                Error = ex.Message,
                ProcessingTimeMs = (int)stopwatch.ElapsedMilliseconds
            };
        }
    }

    /// <inheritdoc />
    public async Task<ReplicatePredictionStatus> GetPredictionStatusAsync(string predictionId)
    {
        try
        {
            using var response = await _httpClient.GetAsync(
                $"{_options.BaseUrl}/predictions/{Uri.EscapeDataString(predictionId)}");
            var json = await response.Content.ReadAsStringAsync();
            EnsureSuccess(response, json);

            // Log (a truncated copy of) the raw payload to help diagnose response-format mismatches.
            _logger.LogDebug("Replicate API response for prediction {PredictionId}: {Response}",
                predictionId, Truncate(json));

            // NOTE(review): the original type argument was lost from this file; ReplicatePredictionResponse is
            // assumed to be the raw prediction DTO in DramaLing.Api.Models.DTOs — confirm the name.
            var prediction = JsonSerializer.Deserialize<ReplicatePredictionResponse>(json);

            return new ReplicatePredictionStatus
            {
                Status = prediction?.Status ?? "unknown",
                Output = prediction?.Output,
                Error = prediction?.Error,
                Version = prediction?.Version,
                Metrics = prediction?.Metrics,
                CompletedAt = prediction?.CompletedAt
            };
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to get prediction status for {PredictionId}", predictionId);
            throw;
        }
    }

    /// <summary>
    /// POSTs the model-specific request body and returns the newly created prediction.
    /// </summary>
    /// <exception cref="NotSupportedException"><paramref name="model"/> is not a supported model.</exception>
    /// <exception cref="HttpRequestException">The API returned a non-success status code.</exception>
    /// <exception cref="InvalidOperationException">The response could not be parsed or carried no prediction id.</exception>
    private async Task<ReplicatePredictionResponse> StartPredictionAsync(string prompt, string model, ReplicateGenerationOptions options)
    {
        var requestBody = BuildModelRequest(prompt, model, options);
        var apiUrl = GetModelApiUrl(model);
        var json = JsonSerializer.Serialize(requestBody);

        using var content = new StringContent(json, Encoding.UTF8, "application/json");
        _logger.LogDebug("Replicate API request to {ApiUrl}", apiUrl);

        using var response = await _httpClient.PostAsync(apiUrl, content);
        var responseJson = await response.Content.ReadAsStringAsync();
        EnsureSuccess(response, responseJson);

        // NOTE(review): DTO type name assumed (see GetPredictionStatusAsync).
        var prediction = JsonSerializer.Deserialize<ReplicatePredictionResponse>(responseJson);
        if (prediction == null)
        {
            throw new InvalidOperationException("Failed to parse Replicate prediction response");
        }

        // Without an id the status polling would hit the wrong endpoint.
        if (string.IsNullOrEmpty(prediction.Id))
        {
            throw new InvalidOperationException("Replicate prediction response did not contain a prediction id");
        }

        return prediction;
    }

    /// <summary>
    /// Returns the endpoint used to create a prediction for <paramref name="model"/>.
    /// Uses <see cref="ReplicateOptions.BaseUrl"/> for every model, consistent with status polling.
    /// </summary>
    private string GetModelApiUrl(string model) => model.ToLowerInvariant() switch
    {
        "ideogram-v2a-turbo" => $"{_options.BaseUrl}/models/ideogram-ai/ideogram-v2a-turbo/predictions",
        // NOTE(review): Replicate's generic /predictions endpoint normally expects a "version" field,
        // which the flux-1-dev request body does not send — confirm this route actually works.
        _ => $"{_options.BaseUrl}/predictions"
    };

    /// <summary>
    /// Builds the JSON request body (an anonymous object with an <c>input</c> member) for <paramref name="model"/>.
    /// </summary>
    /// <exception cref="NotSupportedException"><paramref name="model"/> is not a supported model.</exception>
    private object BuildModelRequest(string prompt, string model, ReplicateGenerationOptions options) => model.ToLowerInvariant() switch
    {
        // Simplified input format confirmed to work for this model; Width/Height are not applied here.
        "ideogram-v2a-turbo" => new
        {
            input = new
            {
                prompt = prompt,
                aspect_ratio = "1:1"
            }
        },
        "flux-1-dev" => new
        {
            input = new
            {
                prompt = prompt,
                width = options.Width ?? DefaultFluxSize,
                height = options.Height ?? DefaultFluxSize,
                num_outputs = 1,
                guidance_scale = 3.5,
                num_inference_steps = 28,
                // A random seed is drawn when the caller does not pin one.
                seed = options.Seed ?? Random.Shared.Next()
            }
        },
        _ => throw new NotSupportedException($"Model {model} not supported")
    };

    /// <summary>
    /// Polls the prediction every <see cref="PollInterval"/> until it succeeds, fails, is canceled,
    /// or <paramref name="timeoutMinutes"/> elapses (measured with a monotonic clock).
    /// </summary>
    private async Task<ReplicateImageResult> WaitForCompletionAsync(string predictionId, string model, int timeoutMinutes)
    {
        var timeout = TimeSpan.FromMinutes(timeoutMinutes);
        var elapsed = Stopwatch.StartNew();

        while (elapsed.Elapsed < timeout)
        {
            var status = await GetPredictionStatusAsync(predictionId);

            switch (status.Status?.ToLowerInvariant())
            {
                case "succeeded":
                {
                    var imageUrl = ExtractImageUrl(status.Output);
                    var cost = CalculateReplicateCost(model);
                    if (string.IsNullOrEmpty(imageUrl))
                    {
                        // Reporting success with no usable URL would push the failure onto every caller.
                        return new ReplicateImageResult
                        {
                            Success = false,
                            Error = "Generation succeeded but no image URL was returned",
                            Cost = cost,
                            ModelVersion = status.Version,
                            Metadata = status.Metrics
                        };
                    }

                    return new ReplicateImageResult
                    {
                        Success = true,
                        ImageUrl = imageUrl,
                        Cost = cost,
                        ModelVersion = status.Version,
                        Metadata = status.Metrics
                    };
                }

                case "failed":
                    return new ReplicateImageResult
                    {
                        Success = false,
                        Error = status.Error ?? "Generation failed with unknown error"
                    };

                // Canceled predictions never finish; previously these were polled until the timeout.
                case "canceled":
                case "cancelled":
                    return new ReplicateImageResult
                    {
                        Success = false,
                        Error = status.Error ?? "Generation was canceled"
                    };

                case "processing":
                case "starting":
                    _logger.LogDebug("Replicate prediction {PredictionId} still processing", predictionId);
                    await Task.Delay(PollInterval);
                    break;

                default:
                    _logger.LogWarning("Unknown prediction status: {Status}", status.Status);
                    await Task.Delay(PollInterval);
                    break;
            }
        }

        return new ReplicateImageResult
        {
            Success = false,
            Error = "Generation timeout exceeded"
        };
    }

    /// <summary>
    /// Returns the configured per-generation cost for <paramref name="model"/>, falling back to the
    /// default model's configured cost, then to <see cref="DefaultGenerationCost"/>.
    /// </summary>
    private decimal CalculateReplicateCost(string model)
    {
        if (_options.Models.TryGetValue(model, out var modelConfig)
            || _options.Models.TryGetValue(_options.DefaultModel, out modelConfig))
        {
            return modelConfig.CostPerGeneration;
        }

        return DefaultGenerationCost;
    }

    /// <summary>
    /// Pulls an image URL out of a prediction's <c>output</c>, accepting an array of URLs (first element),
    /// a bare URL string, or an object with a <c>url</c> or <c>image</c> property.
    /// Returns null (and logs) for anything else or on parse errors.
    /// </summary>
    private string? ExtractImageUrl(JsonElement? output)
    {
        if (!output.HasValue || output.Value.ValueKind == JsonValueKind.Null)
        {
            return null;
        }

        try
        {
            var element = output.Value;

            // Array form: ["http://..."]
            if (element.ValueKind == JsonValueKind.Array && element.GetArrayLength() > 0)
            {
                return element[0].GetString();
            }

            // String form: "http://..."
            if (element.ValueKind == JsonValueKind.String)
            {
                return element.GetString();
            }

            // Object form: { "url": "http://..." } or { "image": "http://..." }
            if (element.ValueKind == JsonValueKind.Object)
            {
                if (element.TryGetProperty("url", out var urlElement))
                {
                    return urlElement.GetString();
                }

                if (element.TryGetProperty("image", out var imageElement))
                {
                    return imageElement.GetString();
                }
            }

            _logger.LogWarning("Unknown output format: {OutputKind}", element.ValueKind);
            return null;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to extract image URL from output");
            return null;
        }
    }

    /// <summary>
    /// Throws <see cref="HttpRequestException"/> (as <c>EnsureSuccessStatusCode</c> would) for a non-success
    /// response, but keeps a truncated copy of the API's error body in the message.
    /// </summary>
    private static void EnsureSuccess(HttpResponseMessage response, string body)
    {
        if (response.IsSuccessStatusCode)
        {
            return;
        }

        throw new HttpRequestException(
            $"Replicate API returned {(int)response.StatusCode} ({response.ReasonPhrase}): {Truncate(body)}",
            null,
            response.StatusCode);
    }

    /// <summary>Returns at most <see cref="MaxLoggedResponseChars"/> characters of <paramref name="text"/>.</summary>
    private static string Truncate(string text) =>
        text.Length <= MaxLoggedResponseChars ? text : text.Substring(0, MaxLoggedResponseChars);
}

// Response models for ReplicateService

/// <summary>
/// Outcome of <see cref="IReplicateService.GenerateImageAsync"/>.
/// </summary>
public class ReplicateImageResult
{
    public bool Success { get; set; }

    public string? ImageUrl { get; set; }

    public decimal Cost { get; set; }

    public int ProcessingTimeMs { get; set; }

    public string? ModelVersion { get; set; }

    public string? Error { get; set; }

    // NOTE(review): generic arguments were lost from this file; must match the type of
    // ReplicatePredictionStatus.Metrics — Dictionary<string, object> assumed, confirm.
    public Dictionary<string, object>? Metadata { get; set; }
}

/// <summary>
/// Per-request options for <see cref="IReplicateService.GenerateImageAsync"/>.
/// </summary>
public class ReplicateGenerationOptions
{
    /// <summary>Output width; applied to flux-1-dev (defaults to 512 there), not used for ideogram-v2a-turbo.</summary>
    public int? Width { get; set; }

    /// <summary>Output height; applied to flux-1-dev (defaults to 512 there), not used for ideogram-v2a-turbo.</summary>
    public int? Height { get; set; }

    /// <summary>Seed for flux-1-dev; a random seed is used when null.</summary>
    public int? Seed { get; set; }

    /// <summary>How long to poll for completion before giving up.</summary>
    public int TimeoutMinutes { get; set; } = 5;
}