290 lines
10 KiB
C#
290 lines
10 KiB
C#
using DramaLing.Api.Models.Configuration;
|
|
using DramaLing.Api.Models.DTOs;
|
|
using Microsoft.Extensions.Options;
|
|
using System.Diagnostics;
|
|
using System.Text;
|
|
using System.Text.Json;
|
|
|
|
namespace DramaLing.Api.Services;
|
|
|
|
/// <summary>
/// Contract for generating images through Replicate and querying the state of a prediction.
/// </summary>
public interface IReplicateService
{
    /// <summary>
    /// Generates an image for <paramref name="prompt"/> using the given <paramref name="model"/>.
    /// </summary>
    /// <param name="prompt">Text prompt describing the image.</param>
    /// <param name="model">Name of the model to run.</param>
    /// <param name="options">Per-request generation options.</param>
    /// <returns>The outcome of the generation.</returns>
    Task<ReplicateImageResult> GenerateImageAsync(
        string prompt,
        string model,
        ReplicateGenerationOptions options);

    /// <summary>
    /// Retrieves the current status of the prediction identified by <paramref name="predictionId"/>.
    /// </summary>
    /// <param name="predictionId">Identifier of the prediction to look up.</param>
    /// <returns>The status reported for the prediction.</returns>
    Task<ReplicatePredictionStatus> GetPredictionStatusAsync(string predictionId);
}
|
|
|
|
/// <summary>
/// <see cref="IReplicateService"/> implementation that talks to the Replicate HTTP API through an
/// injected <see cref="HttpClient"/>: it creates a prediction and then polls it until it succeeds,
/// fails, is canceled, or the caller-supplied timeout elapses.
/// </summary>
public class ReplicateService : IReplicateService
{
    /// <summary>Cost returned when neither the requested nor the default model has a configured cost.</summary>
    private const decimal FallbackCostPerGeneration = 0.025m;

    /// <summary>Maximum number of response characters written to the debug log.</summary>
    private const int MaxLoggedResponseLength = 500;

    /// <summary>Delay between two consecutive status polls.</summary>
    private static readonly TimeSpan PollInterval = TimeSpan.FromSeconds(3);

    private readonly HttpClient _httpClient;
    private readonly ILogger<ReplicateService> _logger;
    private readonly ReplicateOptions _options;

    /// <summary>
    /// Creates the service and configures the supplied <paramref name="httpClient"/> with the
    /// timeout and default headers (Authorization, User-Agent, Prefer) used for every request.
    /// </summary>
    /// <param name="httpClient">HTTP client used for all Replicate calls.</param>
    /// <param name="options">Replicate configuration (API key, base URL, timeout, model table).</param>
    /// <param name="logger">Logger for diagnostics.</param>
    /// <exception cref="ArgumentNullException">
    /// Any argument is null, or <paramref name="options"/> carries a null value.
    /// </exception>
    public ReplicateService(HttpClient httpClient, IOptions<ReplicateOptions> options, ILogger<ReplicateService> logger)
    {
        _httpClient = httpClient ?? throw new ArgumentNullException(nameof(httpClient));
        _logger = logger ?? throw new ArgumentNullException(nameof(logger));
        // "options?.Value" so that a null IOptions surfaces as ArgumentNullException
        // instead of a NullReferenceException.
        _options = options?.Value ?? throw new ArgumentNullException(nameof(options));

        _logger.LogInformation("ReplicateService initialized with default model: {Model}, timeout: {Timeout}s",
            _options.DefaultModel, _options.TimeoutSeconds);

        _httpClient.Timeout = TimeSpan.FromSeconds(_options.TimeoutSeconds);
        _httpClient.DefaultRequestHeaders.Add("Authorization", $"Token {_options.ApiKey}");
        _httpClient.DefaultRequestHeaders.Add("User-Agent", "DramaLing/1.0");
        _httpClient.DefaultRequestHeaders.Add("Prefer", "wait"); // extra header used by the known-working request
    }

    /// <summary>
    /// Starts a prediction for <paramref name="prompt"/> and waits for it to finish.
    /// </summary>
    /// <param name="prompt">Text prompt describing the image.</param>
    /// <param name="model">Model name; matched case-insensitively against the supported models.</param>
    /// <param name="options">Generation options; <c>TimeoutMinutes</c> bounds the polling phase.</param>
    /// <returns>
    /// The generation result. Exceptions are not propagated: any failure is logged and reported as a
    /// result with <c>Success = false</c> and the exception message in <c>Error</c>.
    /// </returns>
    public async Task<ReplicateImageResult> GenerateImageAsync(string prompt, string model, ReplicateGenerationOptions options)
    {
        var stopwatch = Stopwatch.StartNew();

        try
        {
            _logger.LogInformation("Starting Replicate image generation with model {Model}", model);

            // Create the prediction on Replicate.
            var prediction = await StartPredictionAsync(prompt, model, options);

            // Poll until the prediction reaches a terminal state or the timeout elapses.
            var result = await WaitForCompletionAsync(prediction.Id, options.TimeoutMinutes, model);

            result.ProcessingTimeMs = (int)stopwatch.ElapsedMilliseconds;

            _logger.LogInformation("Replicate image generation completed in {ElapsedMs}ms", stopwatch.ElapsedMilliseconds);

            return result;
        }
        catch (Exception ex)
        {
            stopwatch.Stop();
            _logger.LogError(ex, "Replicate image generation failed");

            return new ReplicateImageResult
            {
                Success = false,
                Error = ex.Message,
                ProcessingTimeMs = (int)stopwatch.ElapsedMilliseconds
            };
        }
    }

    /// <summary>
    /// Fetches the current state of a prediction from <c>{BaseUrl}/predictions/{id}</c>.
    /// </summary>
    /// <param name="predictionId">Identifier of the prediction; must not be null or blank.</param>
    /// <returns>The deserialized status; <c>Status</c> is "unknown" when the response carries none.</returns>
    /// <exception cref="ArgumentException"><paramref name="predictionId"/> is null or blank.</exception>
    /// <remarks>HTTP and deserialization errors are logged and rethrown unchanged.</remarks>
    public async Task<ReplicatePredictionStatus> GetPredictionStatusAsync(string predictionId)
    {
        // A blank id would silently turn the request into a call against ".../predictions/".
        if (string.IsNullOrWhiteSpace(predictionId))
        {
            throw new ArgumentException("Prediction id must not be null or empty.", nameof(predictionId));
        }

        try
        {
            // Escape the id so that it can never alter the request path.
            var requestUrl = $"{_options.BaseUrl}/predictions/{Uri.EscapeDataString(predictionId)}";

            using var response = await _httpClient.GetAsync(requestUrl);
            response.EnsureSuccessStatusCode();

            var json = await response.Content.ReadAsStringAsync();

            // Log (a prefix of) the raw JSON actually received, for debugging.
            _logger.LogDebug("Replicate API response for prediction {PredictionId}: {Response}",
                predictionId, json[..Math.Min(MaxLoggedResponseLength, json.Length)]);

            var prediction = JsonSerializer.Deserialize<ReplicatePrediction>(json);

            return new ReplicatePredictionStatus
            {
                Status = prediction?.Status ?? "unknown",
                Output = prediction?.Output,
                Error = prediction?.Error,
                Version = prediction?.Version,
                Metrics = prediction?.Metrics,
                CompletedAt = prediction?.CompletedAt
            };
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to get prediction status for {PredictionId}", predictionId);
            throw;
        }
    }

    /// <summary>
    /// Posts the model-specific request body to the model's endpoint and returns the created prediction.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// The response could not be parsed into a prediction, or the prediction has no id.
    /// </exception>
    /// <exception cref="NotSupportedException">The model is not supported.</exception>
    private async Task<ReplicatePrediction> StartPredictionAsync(string prompt, string model, ReplicateGenerationOptions options)
    {
        var requestBody = BuildModelRequest(prompt, model, options);
        var apiUrl = GetModelApiUrl(model);

        var json = JsonSerializer.Serialize(requestBody);
        using var content = new StringContent(json, Encoding.UTF8, "application/json");

        _logger.LogDebug("Replicate API request to {ApiUrl}", apiUrl);

        using var response = await _httpClient.PostAsync(apiUrl, content);
        response.EnsureSuccessStatusCode();

        var responseJson = await response.Content.ReadAsStringAsync();
        var prediction = JsonSerializer.Deserialize<ReplicatePrediction>(responseJson);

        if (prediction == null)
        {
            throw new InvalidOperationException("Failed to parse Replicate prediction response");
        }

        // Without an id the prediction cannot be polled; fail here with a clear message.
        if (string.IsNullOrEmpty(prediction.Id))
        {
            throw new InvalidOperationException("Replicate prediction response did not contain a prediction id");
        }

        return prediction;
    }

    /// <summary>
    /// Returns the endpoint used to create a prediction for <paramref name="model"/>.
    /// </summary>
    private string GetModelApiUrl(string model)
    {
        // Invariant lowering: model names are identifiers, not linguistic text.
        // NOTE(review): every model other than ideogram-v2a-turbo goes to the generic
        // "{BaseUrl}/predictions" endpoint - confirm the request body built for those models
        // carries everything that endpoint requires.
        return model.ToLowerInvariant() switch
        {
            "ideogram-v2a-turbo" => "https://api.replicate.com/v1/models/ideogram-ai/ideogram-v2a-turbo/predictions",
            _ => $"{_options.BaseUrl}/predictions"
        };
    }

    /// <summary>
    /// Builds the JSON-serializable request body for <paramref name="model"/>.
    /// </summary>
    /// <exception cref="NotSupportedException">The model has no known request format.</exception>
    private object BuildModelRequest(string prompt, string model, ReplicateGenerationOptions options)
    {
        // Uses the simplified request formats that were confirmed to work.
        return model.ToLowerInvariant() switch
        {
            "ideogram-v2a-turbo" => new
            {
                input = new
                {
                    prompt = prompt,
                    aspect_ratio = "1:1" // simplified to the confirmed-working format
                }
            },
            "flux-1-dev" => new
            {
                input = new
                {
                    prompt = prompt,
                    width = 512,
                    height = 512,
                    num_outputs = 1,
                    guidance_scale = 3.5,
                    num_inference_steps = 28,
                    seed = options.Seed ?? Random.Shared.Next()
                }
            },
            _ => throw new NotSupportedException($"Model {model} not supported")
        };
    }

    /// <summary>
    /// Polls the prediction every <see cref="PollInterval"/> until it reaches a terminal state
    /// ("succeeded", "failed" or "canceled") or <paramref name="timeoutMinutes"/> has elapsed.
    /// </summary>
    /// <param name="predictionId">Identifier of the prediction to poll.</param>
    /// <param name="timeoutMinutes">Maximum polling time; a non-positive value times out immediately.</param>
    /// <param name="model">Model that was requested; used to look up the generation cost.</param>
    private async Task<ReplicateImageResult> WaitForCompletionAsync(string predictionId, int timeoutMinutes, string model)
    {
        var timeout = TimeSpan.FromMinutes(timeoutMinutes);
        // Stopwatch is monotonic, so wall-clock adjustments cannot stretch or shorten the timeout.
        var elapsed = Stopwatch.StartNew();

        while (elapsed.Elapsed < timeout)
        {
            var status = await GetPredictionStatusAsync(predictionId);

            switch (status.Status.ToLowerInvariant())
            {
                case "succeeded":
                    return new ReplicateImageResult
                    {
                        Success = true,
                        ImageUrl = ExtractImageUrl(status.Output),
                        Cost = CalculateReplicateCost(model, status.Metrics),
                        ModelVersion = status.Version,
                        Metadata = status.Metrics
                    };

                case "failed":
                    return new ReplicateImageResult
                    {
                        Success = false,
                        Error = status.Error ?? "Generation failed with unknown error"
                    };

                case "canceled":
                    // Terminal state: the prediction will never complete, so stop polling now
                    // instead of waiting for the timeout.
                    return new ReplicateImageResult
                    {
                        Success = false,
                        Error = status.Error ?? "Generation was canceled"
                    };

                case "processing":
                case "starting":
                    _logger.LogDebug("Replicate prediction {PredictionId} still processing", predictionId);
                    await Task.Delay(PollInterval);
                    break;

                default:
                    _logger.LogWarning("Unknown prediction status: {Status}", status.Status);
                    await Task.Delay(PollInterval);
                    break;
            }
        }

        return new ReplicateImageResult
        {
            Success = false,
            Error = "Generation timeout exceeded"
        };
    }

    /// <summary>
    /// Returns the configured cost per generation for <paramref name="model"/>, falling back to the
    /// default model's cost and finally to <see cref="FallbackCostPerGeneration"/>.
    /// </summary>
    /// <param name="model">Model that produced the image.</param>
    /// <param name="metrics">Metrics reported for the prediction; currently not used in the calculation.</param>
    private decimal CalculateReplicateCost(string model, Dictionary<string, object>? metrics)
    {
        // Prefer the cost of the model that actually ran.
        if (_options.Models.TryGetValue(model, out var requestedModelConfig))
        {
            return requestedModelConfig.CostPerGeneration;
        }

        // Otherwise use the configured default model's cost.
        if (_options.Models.TryGetValue(_options.DefaultModel, out var defaultModelConfig))
        {
            return defaultModelConfig.CostPerGeneration;
        }

        return FallbackCostPerGeneration; // default Ideogram cost
    }

    /// <summary>
    /// Extracts the image URL from a prediction's <c>output</c> value.
    /// </summary>
    /// <remarks>
    /// Supported shapes: an array whose first element is the URL, a plain string, or an object with
    /// a <c>url</c> or <c>image</c> property. Any other shape, or any extraction error, is logged
    /// and yields <c>null</c>.
    /// </remarks>
    private string? ExtractImageUrl(JsonElement? output)
    {
        if (!output.HasValue || output.Value.ValueKind == JsonValueKind.Null)
        {
            return null;
        }

        try
        {
            var element = output.Value;

            switch (element.ValueKind)
            {
                // Array format: ["http://..."]
                case JsonValueKind.Array when element.GetArrayLength() > 0:
                    return element[0].GetString();

                // String format: "http://..."
                case JsonValueKind.String:
                    return element.GetString();

                // Object format: { "url": "http://..." }
                case JsonValueKind.Object:
                    if (element.TryGetProperty("url", out var urlElement))
                    {
                        return urlElement.GetString();
                    }

                    // ...or another possible property name.
                    if (element.TryGetProperty("image", out var imageElement))
                    {
                        return imageElement.GetString();
                    }

                    break;
            }

            _logger.LogWarning("Unknown output format: {OutputKind}", element.ValueKind);
            return null;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to extract image URL from output");
            return null;
        }
    }
}
|
|
|
|
// Response models for ReplicateService
|
|
/// <summary>
/// Outcome of an image generation request.
/// </summary>
public class ReplicateImageResult
{
    /// <summary>Whether the generation succeeded.</summary>
    public bool Success { get; set; }

    /// <summary>URL of the generated image, when one is available.</summary>
    public string? ImageUrl { get; set; }

    /// <summary>Cost attributed to the generation.</summary>
    public decimal Cost { get; set; }

    /// <summary>Processing time in milliseconds.</summary>
    public int ProcessingTimeMs { get; set; }

    /// <summary>Version of the model, when known.</summary>
    public string? ModelVersion { get; set; }

    /// <summary>Error message, when the generation did not succeed.</summary>
    public string? Error { get; set; }

    /// <summary>Additional key/value data attached to the result.</summary>
    public Dictionary<string, object>? Metadata { get; set; }
}
|
|
|
|
/// <summary>
/// Per-request options for an image generation.
/// </summary>
public class ReplicateGenerationOptions
{
    /// <summary>Requested image width, if any.</summary>
    public int? Width { get; set; }

    /// <summary>Requested image height, if any.</summary>
    public int? Height { get; set; }

    /// <summary>Seed for the generation, if any.</summary>
    public int? Seed { get; set; }

    /// <summary>Timeout in minutes; defaults to 5.</summary>
    public int TimeoutMinutes { get; set; } = 5;
}