diff --git a/OrtForge.AI.Agent.TestApp/OrtForge.AI.Agent.TestApp.csproj b/OrtForge.AI.Agent.TestApp/OrtForge.AI.Agent.TestApp.csproj new file mode 100644 index 0000000..e94befd --- /dev/null +++ b/OrtForge.AI.Agent.TestApp/OrtForge.AI.Agent.TestApp.csproj @@ -0,0 +1,32 @@ + + + Exe + net8.0 + enable + enable + latest + + MigraphX + + + + + + + + + + + + + + + + + + + + + + + diff --git a/OrtForge.AI.Agent.TestApp/PerformanceTestRunner.cs b/OrtForge.AI.Agent.TestApp/PerformanceTestRunner.cs new file mode 100644 index 0000000..7401e75 --- /dev/null +++ b/OrtForge.AI.Agent.TestApp/PerformanceTestRunner.cs @@ -0,0 +1,447 @@ +using System.Diagnostics; +using System.Text.Json; +using OrtForge.AI.Agent.Agents; +using OrtForge.AI.Agent.Generation; +using OrtForge.AI.Agent.LLM; +using OrtForge.AI.Agent.Tokenization; + +namespace OrtForge.AI.Agent.TestApp; + +public sealed class PerformanceTestRunner +{ + public record TestResult( + string Prompt, + string Category, + double TimeToFirstTokenMs, + double TokensPerSecond, + int TotalTokens, + double TotalTimeMs, + string Response, + bool HitMaxTokens, + bool StoppedNaturally); + + public record BenchmarkSummary( + string ModelPath, + string ConfigName, + double AverageTimeToFirstTokenMs, + double AverageTokensPerSecond, + int TotalPrompts, + double TotalDurationMs, + List Results); + + private static readonly Dictionary TestPrompts = new() + { + ["Factual"] = + [ + "What is the capital of France?", + "What is 2 + 2?", + "How many days are in a week?", + "What color is the sky?", + "Who wrote Romeo and Juliet?" + ], + ["Math"] = + [ + "What is 15 multiplied by 7?", + "If I have 3 apples and buy 5 more, how many do I have?", + "What is the next number: 2, 4, 6, 8, ?", + "Is 17 a prime number? Answer yes or no." + ], + ["Coding"] = + [ + "Write a Python function that adds two numbers.", + "Write hello world in JavaScript.", + "What does 'print' do in Python?" + ], + ["Creative"] = + [ + "Write one sentence about the ocean.", + "Name three colors.", + "Complete: The quick brown fox..." + ] + }; + + private static readonly string[][] MultiTurnConversation = + [ + ["My name is Alice.", "What is my name?", "Tell me a joke."] + ]; + + private readonly LlamaSession _llm; + private readonly TokenizerService _tokenizer; + private readonly string _modelPath; + + public PerformanceTestRunner(LlamaSession llm, TokenizerService tokenizer, string modelPath) + { + _llm = llm; + _tokenizer = tokenizer; + _modelPath = modelPath; + } + + public async Task RunBenchmarksAsync( + InferenceConfig config, + string configName = "Default", + CancellationToken cancellationToken = default) + { + var results = new List(); + var overallStopwatch = Stopwatch.StartNew(); + + Console.WriteLine(); + Console.WriteLine("╔══════════════════════════════════════════════════════════════╗"); + Console.WriteLine("║ OrtForge.AI Inference Benchmark ║"); + Console.WriteLine("╚══════════════════════════════════════════════════════════════╝"); + Console.WriteLine(); + Console.WriteLine($" Model: {Path.GetFileName(_modelPath)}"); + Console.WriteLine($" Config: {configName} (Temp={config.Temperature}, TopK={config.TopK}, TopP={config.TopP})"); + Console.WriteLine($" Model Type: {_llm.ModelType}"); + Console.WriteLine($" Max Tokens: {config.MaxTokens}"); + Console.WriteLine($" Stop Token IDs: [{string.Join(", ", config.StopTokenIds)}]"); + Console.WriteLine(); + + // Run single-turn tests + foreach (var (category, prompts) in TestPrompts) + { + Console.WriteLine($"┌─ Category: {category} ─────────────────────────────────────────┐"); + + foreach (var prompt in prompts) + { + if (cancellationToken.IsCancellationRequested) + break; + + var result = await RunSinglePromptAsync(prompt, category, config, cancellationToken); + results.Add(result); + PrintResult(result); + } + + Console.WriteLine("└──────────────────────────────────────────────────────────────┘"); + Console.WriteLine(); + } + + // Run multi-turn conversation test + Console.WriteLine("┌─ Category: Multi-turn ────────────────────────────────────────┐"); + var multiTurnResults = await RunMultiTurnTestAsync(MultiTurnConversation[0], config, cancellationToken); + foreach (var result in multiTurnResults) + { + results.Add(result); + PrintResult(result); + } + Console.WriteLine("└──────────────────────────────────────────────────────────────┘"); + Console.WriteLine(); + + overallStopwatch.Stop(); + + var summary = new BenchmarkSummary( + ModelPath: _modelPath, + ConfigName: configName, + AverageTimeToFirstTokenMs: results.Count > 0 ? results.Average(r => r.TimeToFirstTokenMs) : 0, + AverageTokensPerSecond: results.Count > 0 ? results.Average(r => r.TokensPerSecond) : 0, + TotalPrompts: results.Count, + TotalDurationMs: overallStopwatch.Elapsed.TotalMilliseconds, + Results: results); + + PrintSummary(summary); + + return summary; + } + + private async Task RunSinglePromptAsync( + string prompt, + string category, + InferenceConfig config, + CancellationToken cancellationToken) + { + using var session = new ConversationSession(_llm, _tokenizer, config); + var agent = new AgentOrchestrator(); + + var stopwatch = Stopwatch.StartNew(); + var firstTokenTime = TimeSpan.Zero; + var tokenCount = 0; + var response = new System.Text.StringBuilder(); + var isFirstToken = true; + + await foreach (var token in agent.ChatTurnAsync(session, prompt, cancellationToken: cancellationToken)) + { + if (isFirstToken) + { + firstTokenTime = stopwatch.Elapsed; + isFirstToken = false; + } + tokenCount++; + response.Append(token); + } + + stopwatch.Stop(); + + var totalTimeMs = stopwatch.Elapsed.TotalMilliseconds; + var generationTimeMs = totalTimeMs - firstTokenTime.TotalMilliseconds; + var tokensPerSecond = generationTimeMs > 0 && tokenCount > 1 + ? (tokenCount - 1) / (generationTimeMs / 1000.0) + : 0; + + var hitMaxTokens = tokenCount >= config.MaxTokens; + var stoppedNaturally = !hitMaxTokens && tokenCount > 0; + + return new TestResult( + Prompt: prompt, + Category: category, + TimeToFirstTokenMs: firstTokenTime.TotalMilliseconds, + TokensPerSecond: tokensPerSecond, + TotalTokens: tokenCount, + TotalTimeMs: totalTimeMs, + Response: response.ToString().Trim(), + HitMaxTokens: hitMaxTokens, + StoppedNaturally: stoppedNaturally); + } + + private async Task> RunMultiTurnTestAsync( + string[] turns, + InferenceConfig config, + CancellationToken cancellationToken) + { + var results = new List(); + using var session = new ConversationSession(_llm, _tokenizer, config); + var agent = new AgentOrchestrator(); + + for (int i = 0; i < turns.Length; i++) + { + var prompt = turns[i]; + var stopwatch = Stopwatch.StartNew(); + var firstTokenTime = TimeSpan.Zero; + var tokenCount = 0; + var response = new System.Text.StringBuilder(); + var isFirstToken = true; + + await foreach (var token in agent.ChatTurnAsync(session, prompt, cancellationToken: cancellationToken)) + { + if (isFirstToken) + { + firstTokenTime = stopwatch.Elapsed; + isFirstToken = false; + } + tokenCount++; + response.Append(token); + } + + stopwatch.Stop(); + + var totalTimeMs = stopwatch.Elapsed.TotalMilliseconds; + var generationTimeMs = totalTimeMs - firstTokenTime.TotalMilliseconds; + var tokensPerSecond = generationTimeMs > 0 && tokenCount > 1 + ? (tokenCount - 1) / (generationTimeMs / 1000.0) + : 0; + + var hitMaxTokens = tokenCount >= config.MaxTokens; + var stoppedNaturally = !hitMaxTokens && tokenCount > 0; + + results.Add(new TestResult( + Prompt: $"[Turn {i + 1}] {prompt}", + Category: "Multi-turn", + TimeToFirstTokenMs: firstTokenTime.TotalMilliseconds, + TokensPerSecond: tokensPerSecond, + TotalTokens: tokenCount, + TotalTimeMs: totalTimeMs, + Response: response.ToString().Trim(), + HitMaxTokens: hitMaxTokens, + StoppedNaturally: stoppedNaturally)); + } + + return results; + } + + private static void PrintResult(TestResult result) + { + var promptDisplay = result.Prompt.Length > 45 + ? result.Prompt[..42] + "..." + : result.Prompt; + + var stopStatus = result.StoppedNaturally ? "EOS" : (result.HitMaxTokens ? "MAX" : "???"); + + Console.WriteLine($"│ \"{promptDisplay}\""); + Console.WriteLine($"│ TTFT: {result.TimeToFirstTokenMs,7:F1}ms | TPS: {result.TokensPerSecond,6:F1} | Tokens: {result.TotalTokens,4} | Stop: {stopStatus} | Total: {result.TotalTimeMs,7:F0}ms"); + + // Show full response (sanitized) + var fullResponse = SanitizeForDisplay(result.Response); + + // Word wrap at ~70 chars for readability + var lines = WordWrap(fullResponse, 68); + Console.WriteLine($"│ Response:"); + foreach (var line in lines) + { + Console.WriteLine($"│ {line}"); + } + Console.WriteLine("│"); + } + + private static List WordWrap(string text, int maxWidth) + { + var lines = new List(); + if (string.IsNullOrEmpty(text)) + { + lines.Add("(empty)"); + return lines; + } + + var words = text.Split(' ', StringSplitOptions.RemoveEmptyEntries); + var currentLine = new System.Text.StringBuilder(); + + foreach (var word in words) + { + if (currentLine.Length + word.Length + 1 > maxWidth) + { + if (currentLine.Length > 0) + { + lines.Add(currentLine.ToString()); + currentLine.Clear(); + } + } + + if (currentLine.Length > 0) + currentLine.Append(' '); + currentLine.Append(word); + } + + if (currentLine.Length > 0) + lines.Add(currentLine.ToString()); + + return lines; + } + + private static string SanitizeForDisplay(string text) + { + if (string.IsNullOrEmpty(text)) + return "(empty)"; + + // Remove common special tokens + var sanitized = text + .Replace("<|begin_of_text|>", "") + .Replace("<|end_of_text|>", "") + .Replace("<|start_header_id|>", "") + .Replace("<|end_header_id|>", "") + .Replace("<|eot_id|>", "") + .Replace("<|im_start|>", "") + .Replace("<|im_end|>", ""); + + // Remove control characters and normalize whitespace + var chars = new System.Text.StringBuilder(sanitized.Length); + foreach (var c in sanitized) + { + if (char.IsControl(c) || c == '\r' || c == '\n' || c == '\t') + { + chars.Append(' '); + } + else if (char.IsHighSurrogate(c) || char.IsLowSurrogate(c)) + { + // Skip unpaired surrogates that might render as garbage + continue; + } + else if (c >= 0x4E00 && c <= 0x9FFF) + { + // Skip CJK characters that are likely tokenizer artifacts (like 醴) + continue; + } + else + { + chars.Append(c); + } + } + + // Collapse multiple spaces + var result = System.Text.RegularExpressions.Regex.Replace(chars.ToString().Trim(), @"\s+", " "); + return string.IsNullOrWhiteSpace(result) ? "(special tokens only)" : result; + } + + private static void PrintSummary(BenchmarkSummary summary) + { + var stoppedNaturally = summary.Results.Count(r => r.StoppedNaturally); + var hitMaxTokens = summary.Results.Count(r => r.HitMaxTokens); + + Console.WriteLine("╔══════════════════════════════════════════════════════════════╗"); + Console.WriteLine("║ SUMMARY ║"); + Console.WriteLine("╠══════════════════════════════════════════════════════════════╣"); + Console.WriteLine($"║ Average TTFT: {summary.AverageTimeToFirstTokenMs,8:F1} ms ║"); + Console.WriteLine($"║ Average TPS: {summary.AverageTokensPerSecond,8:F1} tokens/sec ║"); + Console.WriteLine($"║ Total Prompts: {summary.TotalPrompts,8} ║"); + Console.WriteLine($"║ Stopped (EOS): {stoppedNaturally,8} ║"); + Console.WriteLine($"║ Hit Max Tokens: {hitMaxTokens,8} ║"); + Console.WriteLine($"║ Total Duration: {summary.TotalDurationMs / 1000.0,8:F2} sec ║"); + Console.WriteLine("╚══════════════════════════════════════════════════════════════╝"); + + if (hitMaxTokens > stoppedNaturally) + { + Console.WriteLine(); + Console.WriteLine("⚠️ WARNING: Most responses hit max token limit without natural stop."); + Console.WriteLine(" This may indicate:"); + Console.WriteLine(" 1. Model is a BASE model (not instruction-tuned)"); + Console.WriteLine(" 2. Stop token IDs are incorrect for this model"); + Console.WriteLine(" 3. Chat template doesn't match model's training format"); + Console.WriteLine(); + Console.WriteLine(" Verify your model is 'Meta-Llama-3.1-8B-Instruct' (not base)"); + } + } + + public static void ExportToJson(BenchmarkSummary summary, string filePath) + { + var options = new JsonSerializerOptions + { + WriteIndented = true, + PropertyNamingPolicy = JsonNamingPolicy.CamelCase + }; + var json = JsonSerializer.Serialize(summary, options); + File.WriteAllText(filePath, json); + Console.WriteLine($"\nResults exported to: {filePath}"); + } + + public static async Task RunComparisonBenchmarkAsync( + LlamaSession llm, + TokenizerService tokenizer, + string modelPath, + string? exportPath = null, + int? maxTokens = null, + CancellationToken cancellationToken = default) + { + var runner = new PerformanceTestRunner(llm, tokenizer, modelPath); + + var configs = new Dictionary + { + ["Greedy"] = InferenceConfig.Greedy, + ["Default"] = InferenceConfig.Default, + ["Precise"] = InferenceConfig.Precise + }; + + var allSummaries = new List(); + + foreach (var (name, config) in configs) + { + if (cancellationToken.IsCancellationRequested) + break; + + // Merge with model-specific optimal config + var mergedConfig = LlamaOptimizations.GetOptimalConfigForModel(llm.ModelType, config); + if (maxTokens.HasValue) + { + mergedConfig = mergedConfig with { MaxTokens = maxTokens.Value }; + } + var summary = await runner.RunBenchmarksAsync(mergedConfig, name, cancellationToken); + allSummaries.Add(summary); + + Console.WriteLine("\nPress any key to continue to next config (or Ctrl+C to stop)...\n"); + if (Console.KeyAvailable) + Console.ReadKey(true); + } + + if (!string.IsNullOrEmpty(exportPath)) + { + var combinedPath = Path.Combine( + Path.GetDirectoryName(exportPath) ?? ".", + $"benchmark_comparison_{DateTime.Now:yyyyMMdd_HHmmss}.json"); + + var options = new JsonSerializerOptions + { + WriteIndented = true, + PropertyNamingPolicy = JsonNamingPolicy.CamelCase + }; + var json = JsonSerializer.Serialize(allSummaries, options); + File.WriteAllText(combinedPath, json); + Console.WriteLine($"\nComparison results exported to: {combinedPath}"); + } + } +} + diff --git a/OrtForge.AI.Agent.TestApp/Program.cs b/OrtForge.AI.Agent.TestApp/Program.cs new file mode 100644 index 0000000..fdcc9e6 --- /dev/null +++ b/OrtForge.AI.Agent.TestApp/Program.cs @@ -0,0 +1,245 @@ +using OrtForge.AI.Agent.Agents; +using OrtForge.AI.Agent.Generation; +using OrtForge.AI.Agent.LLM; +using OrtForge.AI.Agent.Runtime; +using OrtForge.AI.Agent.Tokenization; + +namespace OrtForge.AI.Agent.TestApp; + +internal static class Program +{ + private static async Task Main(string[] args) + { + // Check for --help flag + if (args.Length == 0 || args.Contains("--help") || args.Contains("-h")) + { + PrintUsage(); + return; + } + + // Parse arguments + var benchmarkMode = args.Contains("--benchmark"); + var compareMode = args.Contains("--compare"); + var jsonExport = args.FirstOrDefault(a => a.StartsWith("--export="))?.Replace("--export=", ""); + var configArg = args.FirstOrDefault(a => a.StartsWith("--config="))?.Replace("--config=", ""); + var maxTokensArg = args.FirstOrDefault(a => a.StartsWith("--max-tokens="))?.Replace("--max-tokens=", ""); + int? maxTokens = int.TryParse(maxTokensArg, out var mt) ? mt : null; // null = use config default (no override) + var debugPrompts = args.Contains("--debug-prompts"); + + // Filter out flags to get positional arguments + var positionalArgs = args.Where(a => !a.StartsWith("--") && !a.StartsWith("-")).ToArray(); + + if (positionalArgs.Length < 2) + { + Console.WriteLine("Error: Missing required arguments."); + PrintUsage(); + return; + } + + var llmPath = positionalArgs[0].Trim(); + var tokenizerPath = positionalArgs[1].Trim(); + + Console.WriteLine($"LLM: {llmPath}"); + Console.WriteLine($"Tokenizer: {tokenizerPath}"); + + using var llmSession = OrtRuntimeFactory.CreateSession(llmPath); + var modelType = ModelTypeExtensions.ParseFromString(llmPath); + Console.WriteLine($"Detected model type: {modelType}"); + + // Debug: Print model inputs/outputs + if (args.Contains("--debug-model")) + { + Console.WriteLine("\n=== MODEL METADATA ==="); + Console.WriteLine("Inputs:"); + foreach (var input in llmSession.InputMetadata) + { + var dims = string.Join(", ", input.Value.Dimensions); + Console.WriteLine($" {input.Key}: [{dims}] ({input.Value.ElementDataType})"); + } + Console.WriteLine("\nOutputs:"); + foreach (var output in llmSession.OutputMetadata) + { + var dims = string.Join(", ", output.Value.Dimensions); + Console.WriteLine($" {output.Key}: [{dims}] ({output.Value.ElementDataType})"); + } + Console.WriteLine("=== END MODEL METADATA ===\n"); + } + + using var llama = new LlamaSession(llmSession, modelType); + var tok = TokenizerService.FromHuggingFace(tokenizerPath); + + // Run benchmark mode + if (benchmarkMode || compareMode) + { + await RunBenchmarkMode(llama, tok, llmPath, compareMode, jsonExport, configArg, maxTokens, debugPrompts); + return; + } + + // Interactive chat mode + await RunInteractiveMode(llama, tok); + } + + private static void PrintUsage() + { + Console.WriteLine(@" +OrtForge.AI TestApp - LLM Inference Testing + +Usage: + OrtForge.AI.Agent.TestApp [options] + +Arguments: + llm.onnx Path to the ONNX model file + tokenizer.json Path to the tokenizer file (HuggingFace JSON or SentencePiece BPE) + +Options: + --benchmark Run performance benchmark with predefined prompts + --compare Run benchmarks with multiple inference configs (Greedy, Default, Precise) + --config= Specify config for benchmark: Greedy, Default, Precise, Creative + --max-tokens= Maximum tokens to generate per response (default: 128 for benchmarks) + --debug-prompts Show the chat template format being used + --export= Export benchmark results to JSON file + --help, -h Show this help message + +Examples: + # Interactive chat mode + OrtForge.AI.Agent.TestApp model.onnx tokenizer.json + + # Run benchmark with default config + OrtForge.AI.Agent.TestApp model.onnx tokenizer.json --benchmark + + # Run benchmark with limited tokens for quick testing + OrtForge.AI.Agent.TestApp model.onnx tokenizer.json --benchmark --max-tokens=64 + + # Run benchmark with Greedy config and export results + OrtForge.AI.Agent.TestApp model.onnx tokenizer.json --benchmark --config=Greedy --export=results.json + + # Compare all configs + OrtForge.AI.Agent.TestApp model.onnx tokenizer.json --compare +"); + } + + private static async Task RunBenchmarkMode( + LlamaSession llama, + TokenizerService tok, + string llmPath, + bool compareMode, + string? jsonExport, + string? configArg, + int? maxTokens, + bool debugPrompts) + { + if (compareMode) + { + await PerformanceTestRunner.RunComparisonBenchmarkAsync(llama, tok, llmPath, jsonExport, maxTokens); + return; + } + + // Single config benchmark + var config = GetConfigByName(configArg ?? "Default"); + var configName = configArg ?? "Default"; + + // Merge with model-specific optimal config + var mergedConfig = LlamaOptimizations.GetOptimalConfigForModel(llama.ModelType, config); + if (maxTokens.HasValue) + { + mergedConfig = mergedConfig with { MaxTokens = maxTokens.Value }; + } + + if (debugPrompts) + { + // Show a sample prompt for debugging + var samplePrompt = AgentOrchestrator.BuildSystemPrompt([], "What is 2+2?"); + Console.WriteLine(); + Console.WriteLine("=== DEBUG: Sample Prompt Format ==="); + Console.WriteLine(samplePrompt.Replace("\n", "\\n\n")); + Console.WriteLine("=== END DEBUG ==="); + Console.WriteLine(); + + // Show actual token IDs + var tokenIds = tok.EncodeToIds(samplePrompt); + Console.WriteLine("=== DEBUG: Token IDs (first 50) ==="); + Console.WriteLine($"Total tokens: {tokenIds.Length}"); + var first50 = tokenIds.Take(50).ToArray(); + Console.WriteLine($"IDs: [{string.Join(", ", first50)}]"); + + // Check if special tokens are recognized + var specialTokenTest = tok.EncodeToIds("<|begin_of_text|>"); + Console.WriteLine($"\n<|begin_of_text|> encodes to: [{string.Join(", ", specialTokenTest)}]"); + + var eotTest = tok.EncodeToIds("<|eot_id|>"); + Console.WriteLine($"<|eot_id|> encodes to: [{string.Join(", ", eotTest)}]"); + + var headerTest = tok.EncodeToIds("<|start_header_id|>system<|end_header_id|>"); + Console.WriteLine($"<|start_header_id|>system<|end_header_id|> encodes to: [{string.Join(", ", headerTest)}]"); + Console.WriteLine("=== END TOKEN DEBUG ==="); + Console.WriteLine(); + } + + var runner = new PerformanceTestRunner(llama, tok, llmPath); + var summary = await runner.RunBenchmarksAsync(mergedConfig, configName); + + if (!string.IsNullOrEmpty(jsonExport)) + { + PerformanceTestRunner.ExportToJson(summary, jsonExport); + } + } + + private static InferenceConfig GetConfigByName(string name) + { + return name.ToLowerInvariant() switch + { + "greedy" => InferenceConfig.Greedy, + "precise" => InferenceConfig.Precise, + "creative" => InferenceConfig.Creative, + _ => InferenceConfig.Default + }; + } + + private static async Task RunInteractiveMode(LlamaSession llama, TokenizerService tok) + { + var agent = new AgentOrchestrator(); + using var session = new ConversationSession(llama, tok, llama.OptimalConfig); + + Console.WriteLine("🤖 OrtForge.AI Chat"); + Console.WriteLine("💬 Enter your message (empty line to quit):"); + Console.WriteLine(); + + while (true) + { + Console.Write("🧑 > "); + var user = Console.ReadLine(); + if (string.IsNullOrWhiteSpace(user)) + { + Console.WriteLine("👋 Goodbye!"); + break; + } + + Console.WriteLine(); + Console.Write("🤖 Assistant: "); + + try + { + await foreach (var token in agent.ChatTurnAsync(session, user!)) + { + Console.Write(token); + } + } + catch (Exception ex) + { + Console.WriteLine(); + Console.WriteLine($"❌ Error: {ex.Message}"); + Console.WriteLine($"❌ Stack trace: {ex.StackTrace}"); + } + + Console.WriteLine(); + } + + Console.WriteLine("===============CHAT HISTORY================"); + Console.WriteLine(session.EntireConversation.ToString()); + Console.WriteLine("==========================================="); + Console.WriteLine("Press any key to exit..."); + Console.ReadKey(); + } +} + + diff --git a/OrtForge.AI.Agent/Agents/AgentOrchestrator.cs b/OrtForge.AI.Agent/Agents/AgentOrchestrator.cs new file mode 100644 index 0000000..5691b1a --- /dev/null +++ b/OrtForge.AI.Agent/Agents/AgentOrchestrator.cs @@ -0,0 +1,174 @@ +using System.Runtime.CompilerServices; +using System.Text; +using OrtForge.AI.Agent.Generation; +using OrtForge.AI.Agent.Rag; +using OrtForge.AI.Models.Models; + +namespace OrtForge.AI.Agent.Agents; + +public sealed class AgentOrchestrator +{ + private readonly BgeM3Model? _embeddings; + private readonly BgeRerankerM3? _reranker; + private readonly InMemoryVectorStore? _vec; + + public AgentOrchestrator(BgeM3Model? embeddings = null, InMemoryVectorStore? vec = null, BgeRerankerM3? reranker = null) + { + _embeddings = embeddings; + _reranker = reranker; + _vec = vec; + } + + public async IAsyncEnumerable ChatTurnAsync(ConversationSession session, string user, + Func? toolExecutor = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + List retrieved; + + if (_embeddings == null || _vec == null) + { + retrieved = []; + } + else + { + var queryVec = await _embeddings.CreateEmbeddingAsync(user, cancellationToken: cancellationToken); + var candidateResults = _vec.TopK(queryVec, 10).ToList(); + + retrieved = candidateResults.Select(x => x.Text).ToList(); + + if (_reranker != null && candidateResults.Count > 1) + { + var rerankedResults = new List<(float score, string text)>(); + foreach (var candidate in candidateResults) + { + var score = await _reranker.GetRerankingScoreAsync(user, candidate.Text, cancellationToken: cancellationToken); + rerankedResults.Add((score: score, text: candidate.Text)); + } + + retrieved = rerankedResults + .OrderByDescending(x => x.score) + .Take(5) + .Select(x => x.text) + .ToList(); + } + else + { + retrieved = retrieved.Take(5).ToList(); + } + } + + var prompt = !session.IsInitialized + ? BuildSystemPrompt(retrieved, user, toolExecutor != null) + : BuildChatTurnPrompt(retrieved, user, toolExecutor != null); + + await foreach (var token in session.GenerateNextResponseAsync(prompt, toolExecutor, cancellationToken)) + { + yield return token; + } + } + + /// + /// Efficiently merge user config with pre-computed optimal config + /// + private static InferenceConfig MergeConfigs(InferenceConfig optimalConfig, InferenceConfig userConfig) + { + return optimalConfig with + { + Temperature = userConfig.Temperature, + TopK = userConfig.TopK, + TopP = userConfig.TopP, + RepetitionPenalty = userConfig.RepetitionPenalty, + FrequencyPenalty = userConfig.FrequencyPenalty, + PresencePenalty = userConfig.PresencePenalty, + MaxTokens = userConfig.MaxTokens, + Seed = userConfig.Seed, + UseGreedy = userConfig.UseGreedy, + MinP = userConfig.MinP, + TfsZ = userConfig.TfsZ, + TypicalP = userConfig.TypicalP, + StopTokenIds = optimalConfig.StopTokenIds.Concat(userConfig.StopTokenIds).ToHashSet(), + StopSequences = optimalConfig.StopSequences.Concat(userConfig.StopSequences).ToArray() + }; + } + + public static bool IsStopToken(int tokenId, InferenceConfig config) => config.StopTokenIds.Contains(tokenId); + + public static bool IsStopSequence(string text, InferenceConfig config) + { + return config.StopSequences.Any(seq => text.Contains(seq)); + } + + public static string BuildSystemPrompt(IReadOnlyList retrieved, string firstUserMessage, bool enableTools = false) + { + var sb = new StringBuilder(); + + // Llama 3.1 chat template: blank line after header, eot_id on same line as content + sb.Append("<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"); + sb.Append("You are a helpful AI assistant. Answer questions accurately and concisely."); + sb.Append("<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"); + sb.Append(firstUserMessage); + + if (retrieved.Count > 0) + { + sb.AppendLine(); + sb.AppendLine("## Available Context:"); + for (int i = 0; i < retrieved.Count; i++) + { + sb.AppendLine($"**Source {i + 1}:** {retrieved[i]}"); + } + } + + if (enableTools) + { + sb.AppendLine(); + sb.AppendLine("## Tool Usage:"); + sb.AppendLine("When you need to use a tool, format it as:"); + sb.AppendLine("```"); + sb.AppendLine("TOOL_CALL"); + sb.AppendLine("name: tool_name"); + sb.AppendLine("args: tool_arguments"); + sb.AppendLine("END_TOOL_CALL"); + sb.AppendLine("```"); + } + + sb.Append("<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"); + + return sb.ToString(); + } + + public static string BuildChatTurnPrompt(IReadOnlyList retrieved, string user, bool enableTools = false) + { + var sb = new StringBuilder(); + + // Llama 3.1 chat template: blank line after header, eot_id on same line as content + sb.Append("<|start_header_id|>user<|end_header_id|>\n\n"); + sb.Append(user); + + if (retrieved.Count > 0) + { + sb.AppendLine(); + sb.AppendLine("## Available Context:"); + for (int i = 0; i < retrieved.Count; i++) + { + sb.AppendLine($"**Source {i + 1}:** {retrieved[i]}"); + } + } + + if (enableTools) + { + sb.AppendLine(); + sb.AppendLine("## Tool Usage:"); + sb.AppendLine("When you need to use a tool, format it as:"); + sb.AppendLine("```"); + sb.AppendLine("TOOL_CALL"); + sb.AppendLine("name: tool_name"); + sb.AppendLine("args: tool_arguments"); + sb.AppendLine("END_TOOL_CALL"); + sb.AppendLine("```"); + } + + sb.Append("<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"); + return sb.ToString(); + } +} + + diff --git a/OrtForge.AI.Agent/Agents/ConversationSession.cs b/OrtForge.AI.Agent/Agents/ConversationSession.cs new file mode 100644 index 0000000..9e89d48 --- /dev/null +++ b/OrtForge.AI.Agent/Agents/ConversationSession.cs @@ -0,0 +1,124 @@ +using System.Runtime.CompilerServices; +using System.Text; +using Microsoft.ML.Tokenizers; +using OrtForge.AI.Agent.Generation; +using OrtForge.AI.Agent.LLM; +using OrtForge.AI.Agent.Tokenization; + +namespace OrtForge.AI.Agent.Agents; + +public sealed class ConversationSession : IDisposable +{ + private readonly TokenizerService _tokenizer; + private readonly LlamaSession _llm; + private readonly InferenceConfig _inferenceConfig; + private KvState _kvState; + private bool _isSystemPromptProcessed; + private readonly TokenHistory _tokenHistory; + public StringBuilder EntireConversation { get; } = new(); + + public ConversationSession(LlamaSession llm, TokenizerService tokenizer, InferenceConfig inferenceConfig, int repetitionPenaltyWindowSize = 128) + { + _llm = llm; + _inferenceConfig = inferenceConfig; + _tokenizer = tokenizer; + _kvState = new KvState([]); + _tokenHistory = new TokenHistory(repetitionPenaltyWindowSize); + } + + public string SessionId { get; } = Guid.NewGuid().ToString("N")[..8]; + public bool IsInitialized => _isSystemPromptProcessed; + + public void Dispose() + { + _kvState.Dispose(); + } + + public async IAsyncEnumerable GenerateNextResponseAsync(string prompt, + Func? toolExecutor = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + EntireConversation.Append(prompt); + var toolState = new ToolCallState(); + var inputIds = _tokenizer.EncodeToIds(prompt).Select(x => (long)x).ToArray(); + var isFirstToken = true; + + for (int token = 0; token < _inferenceConfig.MaxTokens; token++) + { + using var outputs = + await _llm.RunOptimizedStepAsync(inputIds, _kvState, _kvState.AccumulatedSequenceLength + inputIds.Length, + cancellationToken); + + // Dispose previous KV state to prevent memory leak + var oldKvState = _kvState; + _kvState = outputs.KvCache; + oldKvState.Dispose(); + + // Use sliding window token history for repetition penalty + var nextToken = GetNextTokenSample(outputs, _tokenHistory.GetTokens()); + var tokenText = _tokenizer.DecodeFromIds([nextToken]); + EntireConversation.Append(tokenText); + + if (IsStopToken(nextToken)) + { + _isSystemPromptProcessed = true; + // Append the stop token text to conversation for proper multi-turn format + EntireConversation.Append("<|eot_id|>"); + yield break; + } + + // Add to sliding window for cross-turn repetition penalty + _tokenHistory.AddToken(nextToken); + + //inject current token into next inference step + inputIds = [nextToken]; + + if (toolExecutor != null) + { + toolState.AppendToken(tokenText); + var pendingCall = toolState.GetNextPendingCall(); + if (pendingCall != null) + { + //TODO + } + } + + // Mark session as initialized after first token generated + if (isFirstToken) + { + _isSystemPromptProcessed = true; + isFirstToken = false; + } + + yield return tokenText; + } + } + + private bool IsStopToken(int tokenId) => _inferenceConfig.StopTokenIds.Contains(tokenId); + private int GetNextTokenSample(LlamaSession.StepOutputs outputs, List previousTurnTokens) + { + var span = outputs.GetLogitsSpan(); + var logitsShape = outputs.Logits.GetTensorTypeAndShape().Shape; + Span logitsForSampling; + if (logitsShape.Length == 3) // [batch, seq_len, vocab] + { + var seqLen = (int)logitsShape[1]; + var vocabSize = (int)logitsShape[2]; + + var lastTokenStart = (seqLen - 1) * vocabSize; + logitsForSampling = span.Slice(lastTokenStart, vocabSize); + } + else if (logitsShape.Length == 2) // [batch, vocab] - generation step + { + var vocabSize = (int)logitsShape[1]; + + logitsForSampling = span.Slice(0, vocabSize); + } + else + { + throw new InvalidOperationException("Unexpected logits shape."); + } + + return Sampling.Sample(logitsForSampling, _inferenceConfig, previousTurnTokens); + } +} diff --git a/OrtForge.AI.Agent/Agents/ToolCall.cs b/OrtForge.AI.Agent/Agents/ToolCall.cs new file mode 100644 index 0000000..ad130ee --- /dev/null +++ b/OrtForge.AI.Agent/Agents/ToolCall.cs @@ -0,0 +1,162 @@ +namespace OrtForge.AI.Agent.Agents; + +public sealed record ToolCall( + string Name, + string Arguments, + string Id = "", + string? Result = null, + ToolCallStatus Status = ToolCallStatus.Pending, + string? Error = null +); + +public enum ToolCallStatus +{ + Pending, + Parsing, + Executing, + Completed, + Failed +} + +public sealed class ToolCallState +{ + private const int MaxBufferSize = 8192; // Limit buffer to prevent unbounded growth + + private readonly List _calls = []; + private string _currentBuffer = string.Empty; + private bool _inToolCall = false; + private int _toolCallStart = -1; + + public IReadOnlyList Calls => _calls; + public bool InToolCall => _inToolCall; + public bool HasPendingCalls => _calls.Exists(c => c.Status == ToolCallStatus.Pending); + + public void AppendToken(string token) + { + _currentBuffer += token; + TrimBufferIfNeeded(); + CheckForToolCallPatterns(); + } + + public void AppendText(string text) + { + _currentBuffer += text; + TrimBufferIfNeeded(); + CheckForToolCallPatterns(); + } + + private void TrimBufferIfNeeded() + { + // If buffer exceeds max size and we're not in a tool call, keep only the tail + if (_currentBuffer.Length > MaxBufferSize && !_inToolCall) + { + // Keep the last portion that might contain a partial TOOL_CALL marker + var keepSize = Math.Min(MaxBufferSize / 2, _currentBuffer.Length); + _currentBuffer = _currentBuffer.Substring(_currentBuffer.Length - keepSize); + } + } + + public ToolCall? GetNextPendingCall() + { + return _calls.Find(c => c.Status == ToolCallStatus.Pending); + } + + public void UpdateCallStatus(ToolCall call, ToolCallStatus status, string? result = null, string? error = null) + { + var index = _calls.FindIndex(c => c.Id == call.Id); + if (index >= 0) + { + _calls[index] = call with { Status = status, Result = result, Error = error }; + } + } + + public void Reset() + { + _calls.Clear(); + _currentBuffer = string.Empty; + _inToolCall = false; + _toolCallStart = -1; + } + + private const string StartMarker = "TOOL_CALL"; + private const string EndMarker = "END_TOOL_CALL"; + + private void CheckForToolCallPatterns() + { + // Keep checking for tool calls until no more complete ones are found + while (true) + { + if (!_inToolCall) + { + var startIndex = _currentBuffer.IndexOf(StartMarker, StringComparison.Ordinal); + if (startIndex >= 0) + { + _inToolCall = true; + _toolCallStart = startIndex; + } + else + { + break; // No more tool call starts found + } + } + + if (_inToolCall) + { + var endIndex = _currentBuffer.IndexOf(EndMarker, _toolCallStart, StringComparison.Ordinal); + if (endIndex >= 0) + { + // Extract content between TOOL_CALL and END_TOOL_CALL + var contentStart = _toolCallStart + StartMarker.Length; + var callContent = _currentBuffer.Substring(contentStart, endIndex - contentStart); + var toolCall = ParseToolCallContent(callContent); + if (toolCall != null) + { + _calls.Add(toolCall); + } + + // Remove processed content from buffer to allow finding next tool call + _currentBuffer = _currentBuffer.Substring(endIndex + EndMarker.Length); + _inToolCall = false; + _toolCallStart = -1; + } + else + { + break; // Incomplete tool call, wait for more tokens + } + } + } + } + + private static ToolCall? ParseToolCallContent(string content) + { + try + { + var lines = content.Trim().Split('\n', StringSplitOptions.RemoveEmptyEntries); + string? name = null; + string? args = null; + + foreach (var line in lines) + { + var trimmed = line.Trim(); + if (trimmed.StartsWith("name:", StringComparison.OrdinalIgnoreCase)) + { + name = trimmed.Substring(5).Trim(); + } + else if (trimmed.StartsWith("args:", StringComparison.OrdinalIgnoreCase)) + { + args = trimmed.Substring(5).Trim(); + } + } + + if (!string.IsNullOrEmpty(name)) + { + return new ToolCall(name, args ?? string.Empty, Guid.NewGuid().ToString()); + } + } + catch + { + } + + return null; + } +} diff --git a/OrtForge.AI.Agent/Docs/KV_Cache_best_practices.md b/OrtForge.AI.Agent/Docs/KV_Cache_best_practices.md new file mode 100644 index 0000000..2c14b87 --- /dev/null +++ b/OrtForge.AI.Agent/Docs/KV_Cache_best_practices.md @@ -0,0 +1,101 @@ +### What a KV cache is (for LLMs)~~~~ + +- In decoder-only Transformers (e.g., LLaMA, GPT), each attention layer computes Keys (K) and Values (V) for every processed token. +- During autoregressive generation, you produce tokens one by one. Without caching, each new token would force recomputation of K and V for the entire prefix at every step, which is expensive. +- KV cache stores the K and V tensors for already-processed tokens so the model only computes K/V for the new token and attends to the cached K/V for the past. This dramatically reduces per-step compute. + +In short: KV cache is the model’s per-layer memory of past attention states, enabling fast incremental decoding. + +### Why we need the KV cache + +- Performance: Avoid quadratic recomputation over the growing prefix at each step. +- Cost efficiency: Per-step cost becomes roughly linear in sequence length (or near-constant with paged attention implementations). +- UX: Enables responsive token streaming during generation. + +### How to interact with an LLM using KV cache (token-by-token) + +1. Tokenize the prompt to input_ids. +2. First pass (prefill): + - Inputs: input_ids = the prompt (length > 1), optional attention_mask and position_ids. + - No past_key_values yet. + - Outputs: logits (for next token) and present_key_values (the KV cache for the entire processed sequence). +3. Choose next token from logits (argmax/sampling/temperature/top-k/p). +4. Next step (incremental decoding): + - Inputs: input_ids = [the single new token], and past_key_values = the cache from the previous step; also attention_mask/position_ids if required. + - Outputs: new logits and updated present_key_values (prefix + new token). +5. Repeat step 3–4 until stopping (EOS token, length limit, etc.). + +This pattern allows you to “serve” the KV cache by feeding each step’s present_key_values back as the next step’s past_key_values for the same sequence. + +### Naming conventions you’ll see (LLaMA/ONNX) + +- Inputs often expect: input_ids, optional attention_mask, position_ids, and past_key_values.* (or past_* per layer and K/V). +- Outputs often provide: logits and present_key_values.* (or present_* variants). +- Between steps you map: present_* → past_*. +- Exporters vary (e.g., present_key_values.X.key vs present.X.k). A small name-normalization layer is common and recommended. + +### Typical tensor shapes (may vary by export) + +- Input IDs: [batch, cur_len] (cur_len is often 1 during decoding). +- Keys/Values per layer: + - Key: [batch, num_kv_heads, kv_len, head_dim] + - Value: [batch, num_kv_heads, kv_len, head_dim] (sometimes the last two dims are swapped) +- kv_len increases with the number of processed tokens. +- With grouped-query attention (GQA), num_kv_heads < num_attention_heads; queries fan out over fewer KV heads. +- Attention mask: could be [batch, total_len] or a 4D causal mask; confirm the export. +- Position IDs: usually [batch, cur_len], incrementing with the sequence; sometimes implicit. + +Always check your model’s input/output metadata to confirm exact shapes and names. + +### Memory considerations (order-of-magnitude) + +KV memory (fp16) ≈ 2 (K and V) × layers × batch × num_kv_heads × head_dim × seq_len × 2 bytes. +- Example: 32 layers, batch 1, 8 KV heads, head_dim 128, seq_len 4096 → ~537 MB. +- Multiply by concurrent sequences to estimate server memory. +- Practical strategies: + - Use fp16/bf16; consider 8-bit KV cache if supported. + - Use paged attention to allocate KV in fixed-size pages, enabling efficient batching and prefix sharing. + - Implement eviction (LRU/TTL) and caps per tenant. + +### Serving patterns + +- Single-process decoding loop (stateful): + - Keep present_key_values from step t; feed them as past_key_values at step t+1. + - Maintain this per active generation (session/conversation/beam). + +- Multi-user server: + - Maintain a KV cache handle per active sequence. Associate each client’s “continue” request with its handle. + - Keep the cache on the same device as the model (GPU/CPU). Avoid serializing to disk; device-specific and large. + - Use a scheduler to batch multiple sequences at the same decoding step; manage variable lengths with masks. + - Reclaim KV memory when a sequence ends or times out. + - For beam search: either duplicate caches per beam or use copy-on-write/page sharing for common prefixes. + +- Stateless API shape: + - The service returns an opaque handle after prefill. Clients send handle + new text/tokens to continue. The server resolves the handle to in-memory KV blocks. + +### Pseudocode for generation with KV cache + +- Prefill: + - inputs: input_ids = prompt; outputs: logits, present_kv + - pick next_token from logits +- Loop: + - inputs: input_ids = [next_token], past_kv = present_kv; outputs: logits, present_kv + - pick next_token; repeat + +### Common pitfalls and how to avoid them + +- Name mismatches (present_* vs past_*): add a mapping layer to normalize. +- Value tensor layout mismatch (kv_len and head_dim swapped in V): verify and transpose if needed. +- Incorrect/omitted position_ids or attention_mask: follow the export’s expectations. +- Moving KV across devices/processes: impractical; keep it co-located with the model runtime. +- Memory blow-ups: cap max concurrent sequences, use paging, and evict aggressively. + +### Quick checklist + +- At t=0: run prompt without past_kv; capture present_kv. +- At t>0: run with input_ids=[last token], past_kv=previous present_kv. +- Keep KV per session on the model device. +- Normalize naming present_* → past_*. +- Mind shapes/masks/positions and memory limits. + +By following this pattern, you “serve” the KV cache correctly and get fast, responsive generation by reusing attention state rather than recomputing it each step. \ No newline at end of file diff --git a/OrtForge.AI.Agent/Generation/InferenceConfig.cs b/OrtForge.AI.Agent/Generation/InferenceConfig.cs new file mode 100644 index 0000000..6ba37ff --- /dev/null +++ b/OrtForge.AI.Agent/Generation/InferenceConfig.cs @@ -0,0 +1,56 @@ +namespace OrtForge.AI.Agent.Generation; + +public sealed record InferenceConfig +{ + public double Temperature { get; init; } = 0.7; + public int TopK { get; init; } = 40; + public double TopP { get; init; } = 0.95; + public double RepetitionPenalty { get; init; } = 1.0; + public double FrequencyPenalty { get; init; } = 0.0; + public double PresencePenalty { get; init; } = 0.0; + public int MaxTokens { get; init; } = 2048; + public int? Seed { get; init; } + public bool UseGreedy { get; init; } = false; + public double MinP { get; init; } = 0.0; + public double TfsZ { get; init; } = 1.0; + public double TypicalP { get; init; } = 1.0; + public HashSet StopTokenIds { get; init; } = []; // Model-specific, set by LlamaOptimizations + public string[] StopSequences { get; init; } = []; + + public static InferenceConfig Default => new() + { + Temperature = 0.5, + TopK = 40, + TopP = 0.95, + RepetitionPenalty = 1.1, + FrequencyPenalty = 0.1, + PresencePenalty = 0.1 + }; + + public static InferenceConfig Greedy => new() + { + UseGreedy = true, + Temperature = 0.0, + RepetitionPenalty = 1.05 // Even for greedy, prevent repetition + }; + + public static InferenceConfig Creative => new() + { + Temperature = 0.8, + TopK = 50, + TopP = 0.9, + RepetitionPenalty = 1.15, + FrequencyPenalty = 0.2, + PresencePenalty = 0.2 + }; + + public static InferenceConfig Precise => new() + { + Temperature = 0.3, + TopK = 20, + TopP = 0.8, + RepetitionPenalty = 1.1, + FrequencyPenalty = 0.15, + PresencePenalty = 0.1 + }; +} diff --git a/OrtForge.AI.Agent/Generation/Sampling.cs b/OrtForge.AI.Agent/Generation/Sampling.cs new file mode 100644 index 0000000..490355a --- /dev/null +++ b/OrtForge.AI.Agent/Generation/Sampling.cs @@ -0,0 +1,345 @@ +namespace OrtForge.AI.Agent.Generation; + +public static class Sampling +{ + public static int Sample(ReadOnlySpan logits, InferenceConfig config, List? previousTokens = null, Random? rng = null) + { + rng ??= config.Seed.HasValue ? new Random(config.Seed.Value) : Random.Shared; + + if (config.UseGreedy || config.Temperature <= 1e-6) + { + return Greedy(logits); + } + + var logitsArray = logits.ToArray(); + + if (previousTokens is { Count: > 0 }) + { + if (config.RepetitionPenalty > 1.0) + { + ApplyRepetitionPenalty(logitsArray, previousTokens, config.RepetitionPenalty); + } + + if (config.FrequencyPenalty > 0.0) + { + ApplyFrequencyPenalty(logitsArray, previousTokens, config.FrequencyPenalty); + } + + if (config.PresencePenalty > 0.0) + { + ApplyPresencePenalty(logitsArray, previousTokens, config.PresencePenalty); + } + } + + var probs = Softmax(logitsArray, config.Temperature); + + if (config.MinP > 0.0) + { + ApplyMinP(probs, config.MinP); + } + + if (config.TopK > 0) + { + ApplyTopK(probs, config.TopK); + } + + if (config.TopP < 1.0) + { + ApplyTopP(probs, config.TopP); + } + + if (config.TfsZ < 1.0) + { + ApplyTailFreeSampling(probs, config.TfsZ); + } + + if (config.TypicalP < 1.0) + { + ApplyTypicalSampling(probs, config.TypicalP); + } + + return SampleCategorical(probs, rng); + } + + public static int Greedy(ReadOnlySpan logits) + { + var maxIdx = 0; + var maxVal = float.NegativeInfinity; + for (int i = 0; i < logits.Length; i++) + { + if (logits[i] > maxVal) { maxVal = logits[i]; maxIdx = i; } + } + return maxIdx; + } + + private static double[] Softmax(float[] logits, double temperature) + { + var probs = new double[logits.Length]; + var maxLogit = logits.Max(); + double sum = 0; + + for (int i = 0; i < logits.Length; i++) + { + var scaled = (logits[i] - maxLogit) / Math.Max(1e-6, temperature); + probs[i] = Math.Exp(scaled); + sum += probs[i]; + } + + for (int i = 0; i < probs.Length; i++) + { + probs[i] /= sum; + } + + return probs; + } + + private static void ApplyRepetitionPenalty(float[] logits, List previousTokens, double penalty) + { + // Skip if penalty is 1.0 or less (1.0 is no-op, <= 0 is invalid) + if (penalty <= 1.0) + return; + + var tokenCounts = new Dictionary(); + foreach (var token in previousTokens) + { + tokenCounts[token] = tokenCounts.GetValueOrDefault(token, 0) + 1; + } + + foreach (var (token, count) in tokenCounts) + { + if (token >= 0 && token < logits.Length) + { + var penaltyFactor = Math.Pow(penalty, count); + if (logits[token] > 0) + { + logits[token] /= (float)penaltyFactor; + } + else + { + logits[token] *= (float)penaltyFactor; + } + } + } + } + + private static void ApplyFrequencyPenalty(float[] logits, List previousTokens, double penalty) + { + if (penalty <= 0) + return; + + var tokenCounts = new Dictionary(); + foreach (var token in previousTokens) + { + tokenCounts[token] = tokenCounts.GetValueOrDefault(token, 0) + 1; + } + + foreach (var (token, count) in tokenCounts) + { + if (token >= 0 && token < logits.Length) + { + logits[token] -= (float)(count * penalty); + } + } + } + + private static void ApplyPresencePenalty(float[] logits, List previousTokens, double penalty) + { + if (penalty <= 0) + return; + + var presentTokens = new HashSet(); + foreach (var token in previousTokens) + { + presentTokens.Add(token); + } + + foreach (var token in presentTokens) + { + if (token >= 0 && token < logits.Length) + { + logits[token] -= (float)penalty; + } + } + } + + private static void ApplyMinP(double[] probs, double minP) + { + var maxProb = probs.Max(); + var threshold = maxProb * minP; + + for (int i = 0; i < probs.Length; i++) + { + if (probs[i] < threshold) + { + probs[i] = 0.0; + } + } + + var sum = probs.Sum(); + if (sum > 0) + { + for (int i = 0; i < probs.Length; i++) + { + probs[i] /= sum; + } + } + } + + private static void ApplyTopK(double[] probs, int k) + { + if (k <= 0 || k >= probs.Length) return; + + var indices = Enumerable.Range(0, probs.Length).ToArray(); + Array.Sort(indices, (a, b) => probs[b].CompareTo(probs[a])); + + for (int i = k; i < indices.Length; i++) + { + probs[indices[i]] = 0.0; + } + + var sum = probs.Sum(); + if (sum > 0) + { + for (int i = 0; i < probs.Length; i++) + { + probs[i] /= sum; + } + } + } + + private static void ApplyTopP(double[] probs, double p) + { + if (p >= 1.0) return; + + var indices = Enumerable.Range(0, probs.Length).ToArray(); + Array.Sort(indices, (a, b) => probs[b].CompareTo(probs[a])); + + double cumulative = 0.0; + int cutoff = probs.Length; + + for (int i = 0; i < indices.Length; i++) + { + cumulative += probs[indices[i]]; + if (cumulative >= p) + { + cutoff = i + 1; + break; + } + } + + for (int i = cutoff; i < indices.Length; i++) + { + probs[indices[i]] = 0.0; + } + + var sum = probs.Sum(); + if (sum > 0) + { + for (int i = 0; i < probs.Length; i++) + { + probs[i] /= sum; + } + } + } + + private static void ApplyTailFreeSampling(double[] probs, double z) + { + if (z >= 1.0) return; + + var indices = Enumerable.Range(0, probs.Length).ToArray(); + Array.Sort(indices, (a, b) => probs[b].CompareTo(probs[a])); + + var derivatives = new double[probs.Length - 1]; + for (int i = 0; i < derivatives.Length; i++) + { + derivatives[i] = Math.Abs(probs[indices[i]] - probs[indices[i + 1]]); + } + + var normDerivatives = derivatives.Select(d => d / derivatives.Sum()).ToArray(); + + double cumulative = 0.0; + int cutoff = probs.Length; + + for (int i = 0; i < normDerivatives.Length; i++) + { + cumulative += normDerivatives[i]; + if (cumulative >= z) + { + cutoff = i + 1; + break; + } + } + + for (int i = cutoff; i < indices.Length; i++) + { + probs[indices[i]] = 0.0; + } + + var sum = probs.Sum(); + if (sum > 0) + { + for (int i = 0; i < probs.Length; i++) + { + probs[i] /= sum; + } + } + } + + private static void ApplyTypicalSampling(double[] probs, double p) + { + if (p >= 1.0) return; + + var entropy = -probs.Where(x => x > 0).Sum(x => x * Math.Log(x)); + var surprisals = probs.Select(x => x > 0 ? -Math.Log(x) : double.PositiveInfinity).ToArray(); + var deviations = surprisals.Select(s => Math.Abs(s - entropy)).ToArray(); + + var indices = Enumerable.Range(0, probs.Length).ToArray(); + Array.Sort(indices, (a, b) => deviations[a].CompareTo(deviations[b])); + + double cumulative = 0.0; + int cutoff = 0; + + for (int i = 0; i < indices.Length; i++) + { + if (probs[indices[i]] > 0) + { + cumulative += probs[indices[i]]; + if (cumulative >= p) + { + cutoff = i + 1; + break; + } + } + } + + for (int i = cutoff; i < indices.Length; i++) + { + probs[indices[i]] = 0.0; + } + + var sum = probs.Sum(); + if (sum > 0) + { + for (int i = 0; i < probs.Length; i++) + { + probs[i] /= sum; + } + } + } + + private static int SampleCategorical(double[] probs, Random rng) + { + var r = rng.NextDouble(); + double cumulative = 0.0; + + for (int i = 0; i < probs.Length; i++) + { + cumulative += probs[i]; + if (r <= cumulative) return i; + } + + return probs.Length - 1; + } +} + + diff --git a/OrtForge.AI.Agent/Generation/TokenHistory.cs b/OrtForge.AI.Agent/Generation/TokenHistory.cs new file mode 100644 index 0000000..6449bfd --- /dev/null +++ b/OrtForge.AI.Agent/Generation/TokenHistory.cs @@ -0,0 +1,48 @@ +namespace OrtForge.AI.Agent.Generation; + +/// +/// Maintains a sliding window of recent tokens for repetition penalty purposes. +/// This allows repetition penalties to be applied across conversation turns. +/// +public sealed class TokenHistory +{ + private readonly Queue _tokens = new(); + + public TokenHistory(int maxSize = 128) + { + if (maxSize <= 0) + throw new ArgumentException("Max size must be positive", nameof(maxSize)); + MaxSize = maxSize; + } + + public int MaxSize { get; } + public int Count => _tokens.Count; + + public void AddToken(int token) + { + _tokens.Enqueue(token); + while (_tokens.Count > MaxSize) + { + _tokens.Dequeue(); + } + } + + public void AddTokens(IEnumerable tokens) + { + foreach (var token in tokens) + { + AddToken(token); + } + } + + public List GetTokens() + { + return _tokens.ToList(); + } + + public void Clear() + { + _tokens.Clear(); + } +} + diff --git a/OrtForge.AI.Agent/LLM/KvState.cs b/OrtForge.AI.Agent/LLM/KvState.cs new file mode 100644 index 0000000..2eb97de --- /dev/null +++ b/OrtForge.AI.Agent/LLM/KvState.cs @@ -0,0 +1,54 @@ +namespace OrtForge.AI.Agent.LLM; + +/// +/// Centralized KV cache state with authoritative sequence length management. +/// This is the single source of truth for sequence length tracking. +/// +public sealed class KvState : IDisposable +{ + public List Tensors { get; } + private int _accumulatedSequenceLength; + + /// + /// The authoritative sequence length - total tokens processed so far. + /// This is the single source of truth for all sequence length calculations. + /// + public int AccumulatedSequenceLength + { + get => _accumulatedSequenceLength; + private set + { + if (value < 0) + throw new ArgumentException("Sequence length cannot be negative", nameof(value)); + _accumulatedSequenceLength = value; + } + } + + public KvState(List mappedOutputs, int initialSequenceLength = 0) + { + AccumulatedSequenceLength = initialSequenceLength; + Tensors = mappedOutputs; + } + + /// + /// Calculate the total sequence length after adding new tokens + /// + /// Number of new tokens to add + /// The total sequence length after adding new tokens + public int CalculateTotalLengthAfterTokens(int newTokenCount) + { + if (newTokenCount < 0) + throw new ArgumentException("New token count cannot be negative", nameof(newTokenCount)); + return AccumulatedSequenceLength + newTokenCount; + } + + public void Dispose() + { + foreach (var tensor in Tensors) + { + tensor.Tensor.Dispose(); + } + + Tensors.Clear(); + } +} \ No newline at end of file diff --git a/OrtForge.AI.Agent/LLM/KvTensorMappingStrategy.cs b/OrtForge.AI.Agent/LLM/KvTensorMappingStrategy.cs new file mode 100644 index 0000000..2944a3d --- /dev/null +++ b/OrtForge.AI.Agent/LLM/KvTensorMappingStrategy.cs @@ -0,0 +1,62 @@ +using System.Text.RegularExpressions; + +namespace OrtForge.AI.Agent.LLM; + +public class KvTensorMappingStrategy +{ + private static readonly Regex InputRegex = new("^past.*?([0-9]+)(.*)$", RegexOptions.Compiled); + private static readonly Regex OutputRegex = new("^present.*?([0-9]+)(.*)$", RegexOptions.Compiled); + + private readonly Dictionary _inputMappingCache = new(); + private readonly Dictionary _outpuMappingCache = new(); + + public bool IsKvInput(string name) + { + return _inputMappingCache.ContainsKey(name); + } + + public bool IsKvOutput(string name) + { + return _outpuMappingCache.ContainsKey(name); + } + + public static KvTensorMappingStrategy Create(IEnumerable inputMetadata, IEnumerable outputMetadata) + { + var outputSet = outputMetadata.ToHashSet(); + + var result = new KvTensorMappingStrategy(); + + var inputs = new Dictionary<(int, string), string>(); + + foreach (var input in inputMetadata) + { + var match = InputRegex.Match(input); + if (match.Success) + { + inputs[(int.Parse(match.Groups[1].Value), match.Groups[2].Value)] = input; + } + } + + foreach (var output in outputSet) + { + var match = OutputRegex.Match(output); + if (match.Success) + { + var outputIndex = int.Parse(match.Groups[1].Value); + var outputName = match.Groups[2].Value; + if (inputs.TryGetValue((outputIndex, outputName), out var input)) + { + result._inputMappingCache[input] = output; + result._outpuMappingCache[output] = input; + } + } + } + + return result; + } + + public string MapOutputToInput(string output) + { + return _outpuMappingCache.GetValueOrDefault(output) ?? throw new InvalidOperationException($"Cannot map output tensor '{output}'");; + } +} diff --git a/OrtForge.AI.Agent/LLM/LlamaOptimizations.cs b/OrtForge.AI.Agent/LLM/LlamaOptimizations.cs new file mode 100644 index 0000000..7c511bb --- /dev/null +++ b/OrtForge.AI.Agent/LLM/LlamaOptimizations.cs @@ -0,0 +1,77 @@ +using OrtForge.AI.Agent.Generation; + +namespace OrtForge.AI.Agent.LLM; + +public static class LlamaOptimizations +{ + public static readonly Dictionary ModelStopTokens = new() + { + [ModelType.Llama3_1] = [128001, 128009], + [ModelType.Llama3_2] = [128001, 128009], + [ModelType.Llama3] = [128001, 128009], + [ModelType.Llama2] = [2], + [ModelType.Default] = [0, 2] + }; + + public static readonly Dictionary ModelStopSequences = new() + { + [ModelType.Llama3_1] = ["<|eot_id|>", "<|end_of_text|>"], + [ModelType.Llama3_2] = ["<|eot_id|>", "<|end_of_text|>"], + [ModelType.Llama3] = ["<|eot_id|>", "<|end_of_text|>"], + [ModelType.Llama2] = [""], + [ModelType.Default] = [] + }; + + public static InferenceConfig GetOptimalConfigForModel(ModelType modelType, InferenceConfig? baseConfig = null) + { + baseConfig ??= InferenceConfig.Default; + + var stopTokenIds = ModelStopTokens.GetValueOrDefault(modelType, ModelStopTokens[ModelType.Default]); + var stopSequences = ModelStopSequences.GetValueOrDefault(modelType, ModelStopSequences[ModelType.Default]); + + // Use model-specific stop tokens, only add base config tokens if they're non-empty and valid + var mergedStopTokens = new HashSet(stopTokenIds); + foreach (var token in baseConfig.StopTokenIds) + { + // Only add tokens > 127999 (special tokens range for Llama 3) or explicitly set + if (token >= 128000) + { + mergedStopTokens.Add(token); + } + } + + return baseConfig with + { + StopTokenIds = mergedStopTokens, + StopSequences = stopSequences.Concat(baseConfig.StopSequences).Distinct().ToArray(), + Temperature = modelType.IsLlama3Family() ? Math.Max(0.1, baseConfig.Temperature) : baseConfig.Temperature, + TopP = modelType.IsLlama3Family() ? Math.Min(0.95, baseConfig.TopP) : baseConfig.TopP + }; + } + + /// + /// Creates position IDs for the current inference step. + /// + /// Total sequence length after adding new tokens + /// Number of new tokens being added + /// Position IDs array of length newTokenCount + public static long[] CreateOptimalPositionIds(int totalSequenceLength, int newTokenCount) + { + // Position IDs should be [startPos, startPos+1, ..., startPos+newTokenCount-1] + // where startPos = totalSequenceLength - newTokenCount + var startPosition = totalSequenceLength - newTokenCount; + var positionIds = new long[newTokenCount]; + for (int i = 0; i < newTokenCount; i++) + { + positionIds[i] = startPosition + i; + } + return positionIds; + } + + public static long[]? CreateOptimalAttentionMask(int totalSequenceLength) + { + var attentionMask = new long[totalSequenceLength]; + Array.Fill(attentionMask, 1L); + return attentionMask; + } +} diff --git a/OrtForge.AI.Agent/LLM/LlamaSession.cs b/OrtForge.AI.Agent/LLM/LlamaSession.cs new file mode 100644 index 0000000..8b8bce7 --- /dev/null +++ b/OrtForge.AI.Agent/LLM/LlamaSession.cs @@ -0,0 +1,372 @@ +using System.Runtime.InteropServices; +using Microsoft.ML.OnnxRuntime; +using Microsoft.ML.OnnxRuntime.Tensors; +using OrtForge.AI.Agent.Generation; + +namespace OrtForge.AI.Agent.LLM; + +public sealed class LlamaSession : IDisposable +{ + private readonly InferenceSession _session; + private readonly KvTensorMappingStrategy _kvMapping; + private string[] _outputNames = []; + private string[] _inputNames = []; + private readonly Dictionary _kvOutputs = new(); + private readonly Dictionary _kvInputs = new(); + + public LlamaSession(InferenceSession session, ModelType modelType = ModelType.Default) + { + _session = session; + ModelType = modelType; + + OptimalConfig = LlamaOptimizations.GetOptimalConfigForModel(modelType); + + _kvMapping = KvTensorMappingStrategy.Create(_session.InputMetadata.Keys, _session.OutputMetadata.Keys); + + DiscoverModelInputsAndOutputs(); + } + + public ModelType ModelType { get; } + public InferenceConfig OptimalConfig { get; } + + public void MapInputs(StepInputs inputs, OrtValue[] modelInputs) + { + var inputShape = inputs.InputIds.GetTensorTypeAndShape().Shape; + var batchSize = inputShape[0]; + + // All required inputs must be provided to avoid memory leaks from untracked OrtValues + if (inputs.PositionIds == null) + throw new ArgumentException("PositionIds must be provided", nameof(inputs)); + if (inputs.AttentionMask == null) + throw new ArgumentException("AttentionMask must be provided", nameof(inputs)); + + modelInputs[0] = inputs.InputIds; + modelInputs[1] = inputs.PositionIds; + modelInputs[2] = inputs.AttentionMask; + + if (inputs.Kv.Tensors.Count > 0) + { + foreach (var kv in inputs.Kv.Tensors) + { + modelInputs[kv.Info.Offset] = kv.Tensor; + } + } + else + { + foreach (var kv in _kvInputs.Values) + { + kv.Dimensions[0] = batchSize; + kv.Dimensions[2] = 0L; + modelInputs[kv.Offset] = + OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance, kv.ElementType, kv.Dimensions); + } + } + } + + public async Task RunStepAsync(StepInputs inputs, CancellationToken cancellationToken = default) + { + var inputShape = inputs.InputIds.GetTensorTypeAndShape().Shape; + var batchSize = inputShape[0]; + var currentInputLength = inputShape[1]; + + var inputValues = new OrtValue[_inputNames.Length]; + var outputValues = new OrtValue[_outputNames.Length]; + + MapInputs(inputs, inputValues); + var stepOutputs = MapOutputs(inputs, outputValues, batchSize, currentInputLength); + + cancellationToken.ThrowIfCancellationRequested(); + + try + { + using var runOptions = new RunOptions(); + await _session.RunAsync(runOptions, _inputNames, inputValues, _outputNames, outputValues); + } + catch (Exception ex) + { + stepOutputs.Dispose(); + throw new InvalidOperationException($"Error running the model: {ex.Message}", ex); + } + + return stepOutputs; + } + + private StepOutputs MapOutputs(StepInputs inputs, + OrtValue[] outputValues, long batchSize, long currentInputLength) + { + var logitsMeta = _session.OutputMetadata["logits"]; + var vocabSize = logitsMeta.Dimensions[^1]; + var logitsTensor = OrtValue.CreateAllocatedTensorValue( + OrtAllocator.DefaultInstance, + logitsMeta.ElementDataType, + [batchSize, currentInputLength, vocabSize]); + + var totalSequenceLength = inputs.Kv.CalculateTotalLengthAfterTokens((int)currentInputLength); + List mappedKvTensors = []; + var newKv = new KvState(mappedKvTensors, totalSequenceLength); + var outputs = new StepOutputs(logitsTensor, newKv); + + outputValues[0] = logitsTensor; + foreach (var output in _kvOutputs.Values) + { + var kvDims = output.Dimensions.Select(d => (long)d).ToArray(); + kvDims[0] = batchSize; + if (inputs.Kv.Tensors.Count == 0) + { + kvDims[2] = currentInputLength; + } + else + { + kvDims[2] = totalSequenceLength; + } + + var kvTensor = OrtValue.CreateAllocatedTensorValue( + OrtAllocator.DefaultInstance, + output.ElementType, + kvDims); + outputValues[output.Offset] = kvTensor; + mappedKvTensors.Add(new OutputKvTensor + { + Tensor = kvTensor, + Info = _kvInputs[_kvMapping.MapOutputToInput(output.Name)] + }); + } + + return outputs; + } + + private void DiscoverModelInputsAndOutputs() + { + var inputMetadata = _session.InputMetadata; + var outputMetadata = _session.OutputMetadata; + + if (!inputMetadata.ContainsKey("input_ids")) + throw new InvalidOperationException("Model has to have 'input_ids'."); + + if (!inputMetadata.ContainsKey("position_ids")) + throw new InvalidOperationException("Model has to have 'position_ids'."); + + if (!inputMetadata.ContainsKey("attention_mask")) + throw new InvalidOperationException("Model has to have 'attention_mask'."); + + if (!outputMetadata.ContainsKey("logits")) + throw new InvalidOperationException("Model has to have 'logits' as its output."); + + var inputNames = new List + { + "input_ids", + "position_ids", + "attention_mask" + }; + + var inputOffset = 3; + foreach (var inputName in inputMetadata.Keys) + { + if (_kvMapping.IsKvInput(inputName)) + { + var inputMeta = inputMetadata[inputName]; + var dimensions = inputMeta.Dimensions.Select(d => (long)d).ToArray(); + _kvInputs.Add(inputName, new KvTensorInfo + { + Name = inputName, + Dimensions = dimensions, + ElementType = inputMeta.ElementDataType, + Offset = inputOffset + }); + inputOffset++; + inputNames.Add(inputName); + } + } + + _inputNames = inputNames.ToArray(); + + var outputNames = new List { "logits" }; + + var outputOffset = 1; + + foreach (var outputName in outputMetadata.Keys) + { + if (_kvMapping.IsKvOutput(outputName)) + { + var outputMeta = outputMetadata[outputName]; + var dimensions = outputMeta.Dimensions.Select(d => (long)d).ToArray(); + _kvOutputs.Add(outputName, new KvTensorInfo + { + Name = outputName, + Dimensions = dimensions, + ElementType = outputMeta.ElementDataType, + Offset = outputOffset + }); + outputOffset++; + outputNames.Add(outputName); + } + } + + _outputNames = outputNames.ToArray(); + } + + public async Task RunOptimizedStepAsync(long[] inputIds, KvState kv, int totalSequenceLength, CancellationToken cancellationToken = default) + { + var positionIds = LlamaOptimizations.CreateOptimalPositionIds(totalSequenceLength, inputIds.Length); + var attentionMask = LlamaOptimizations.CreateOptimalAttentionMask(totalSequenceLength); + + using var inputs = StepInputs.Create(inputIds, kv, positionIds, attentionMask); + return await RunStepAsync(inputs, cancellationToken); + } + + public void Dispose() + { + _session.Dispose(); + } + + private static TensorElementType GetTensorElementType(Type type) + { + if (type == typeof(float)) return TensorElementType.Float; + if (type == typeof(Half)) return TensorElementType.Float16; + if (type.Name == "Float16" || type.FullName?.Contains("OnnxRuntime.Float16") == true) + return TensorElementType.Float16; + if (type == typeof(byte)) return TensorElementType.UInt8; + if (type == typeof(sbyte)) return TensorElementType.Int8; + if (type == typeof(int)) return TensorElementType.Int32; + if (type == typeof(long)) return TensorElementType.Int64; + return TensorElementType.Float; + } + + public sealed record StepInputs( + OrtValue InputIds, + KvState Kv, + OrtValue? PositionIds, + OrtValue? AttentionMask) : IDisposable + { + public void Dispose() + { + InputIds.Dispose(); + PositionIds?.Dispose(); + AttentionMask?.Dispose(); + Kv.Dispose(); + } + + public static StepInputs Create( + long[] inputIds, + KvState kv, + long[]? positionIds = null, + long[]? attentionMask = null) + { + OrtValue? inputIdsOrt = null; + OrtValue? positionIdsOrt = null; + OrtValue? attentionMaskOrt = null; + + try + { + inputIdsOrt = OrtValue.CreateTensorValueFromMemory( + inputIds, + [1, inputIds.Length]); + + if (positionIds != null) + { + positionIdsOrt = OrtValue.CreateTensorValueFromMemory( + positionIds, + [1, positionIds.Length]); + } + + if (attentionMask != null) + { + attentionMaskOrt = OrtValue.CreateTensorValueFromMemory( + attentionMask, + [1, attentionMask.Length]); + } + + return new StepInputs(inputIdsOrt, kv, positionIdsOrt, attentionMaskOrt); + } + catch + { + // Dispose already-created OrtValues on exception to prevent memory leak + inputIdsOrt?.Dispose(); + positionIdsOrt?.Dispose(); + attentionMaskOrt?.Dispose(); + throw; + } + } + } + + public sealed record StepOutputs( + OrtValue Logits, + KvState KvCache) : IDisposable + { + public void Dispose() + { + Logits.Dispose(); + } + + public Span GetLogitsSpan() + { + var typeInfo = Logits.GetTensorTypeAndShape(); + switch (typeInfo.ElementDataType) + { + case TensorElementType.Float: + return Logits.GetTensorMutableDataAsSpan(); + + case TensorElementType.Float16: + case TensorElementType.BFloat16: + return GetLogitsArray().AsSpan(); + + default: + throw new NotSupportedException($"Unsupported tensor element type: {typeInfo.ElementDataType}"); + } + } + + public float[] GetLogitsArray() + { + var typeInfo = Logits.GetTensorTypeAndShape(); + switch (typeInfo.ElementDataType) + { + case TensorElementType.Float: + { + var span = Logits.GetTensorMutableDataAsSpan(); + var array = new float[span.Length]; + span.CopyTo(array); + return array; + } + case TensorElementType.Float16: + { + var byteSpan = Logits.GetTensorMutableDataAsSpan(); + var halfSpan = MemoryMarshal.Cast(byteSpan); + var array = GC.AllocateUninitializedArray(halfSpan.Length); + for (int i = 0; i < halfSpan.Length; i++) + { + array[i] = (float)halfSpan[i]; + } + + return array; + } + case TensorElementType.BFloat16: + { + var byteSpan = Logits.GetTensorMutableDataAsSpan(); + var bfloatSpan = MemoryMarshal.Cast(byteSpan); + var array = GC.AllocateUninitializedArray(bfloatSpan.Length); + for (int i = 0; i < bfloatSpan.Length; i++) + { + array[i] = (float)bfloatSpan[i]; + } + return array; + } + default: + throw new NotSupportedException($"Unsupported tensor element type: {typeInfo.ElementDataType}"); + } + } + } + + public sealed class OutputKvTensor + { + public required KvTensorInfo Info { get; init; } + public required OrtValue Tensor { get; set; } + } + + public sealed class KvTensorInfo + { + public required string Name { get; init; } + public TensorElementType ElementType { get; init; } + public required long[] Dimensions { get; init; } + public int Offset { get; init; } + } +} \ No newline at end of file diff --git a/OrtForge.AI.Agent/LLM/ModelType.cs b/OrtForge.AI.Agent/LLM/ModelType.cs new file mode 100644 index 0000000..33dc230 --- /dev/null +++ b/OrtForge.AI.Agent/LLM/ModelType.cs @@ -0,0 +1,80 @@ +namespace OrtForge.AI.Agent.LLM; + +/// +/// Supported LLM model types with optimized configurations +/// +public enum ModelType +{ + /// + /// Default/Unknown model type with basic configuration + /// + Default = 0, + + /// + /// Llama 2 model family + /// + Llama2 = 1, + + /// + /// Llama 3 base model + /// + Llama3 = 2, + + /// + /// Llama 3.1 model + /// + Llama3_1 = 3, + + /// + /// Llama 3.2 model + /// + Llama3_2 = 4 +} + +/// +/// Extension methods for ModelType enum +/// +public static class ModelTypeExtensions +{ + /// + /// Check if the model is part of the Llama 3 family + /// + public static bool IsLlama3Family(this ModelType modelType) + { + return modelType is ModelType.Llama3 or ModelType.Llama3_1 or ModelType.Llama3_2; + } + + /// + /// Get the string representation for backwards compatibility + /// + public static string ToModelKey(this ModelType modelType) + { + return modelType switch + { + ModelType.Llama2 => "llama-2", + ModelType.Llama3 => "llama-3", + ModelType.Llama3_1 => "llama-3.1", + ModelType.Llama3_2 => "llama-3.2", + _ => "default" + }; + } + + /// + /// Parse model type from string (for backwards compatibility and auto-detection) + /// + public static ModelType ParseFromString(string modelName) + { + var lower = modelName.ToLowerInvariant(); + + if (lower.Contains("llama-3.2") || lower.Contains("llama3.2")) + return ModelType.Llama3_2; + if (lower.Contains("llama-3.1") || lower.Contains("llama3.1")) + return ModelType.Llama3_1; + if (lower.Contains("llama-3") || lower.Contains("llama3")) + return ModelType.Llama3; + if (lower.Contains("llama-2") || lower.Contains("llama2")) + return ModelType.Llama2; + + return ModelType.Default; + } +} diff --git a/OrtForge.AI.Agent/OrtForge.AI.Agent.csproj b/OrtForge.AI.Agent/OrtForge.AI.Agent.csproj new file mode 100644 index 0000000..fc6bdf3 --- /dev/null +++ b/OrtForge.AI.Agent/OrtForge.AI.Agent.csproj @@ -0,0 +1,20 @@ + + + net8.0 + enable + enable + latest + + + + + + + + + + + + + + diff --git a/OrtForge.AI.Agent/Properties/AssemblyInfo.cs b/OrtForge.AI.Agent/Properties/AssemblyInfo.cs new file mode 100644 index 0000000..d51ddb8 --- /dev/null +++ b/OrtForge.AI.Agent/Properties/AssemblyInfo.cs @@ -0,0 +1,3 @@ +using System.Runtime.CompilerServices; + +[assembly: InternalsVisibleTo("OrtForge.AI.UnitTests")] diff --git a/OrtForge.AI.Agent/Rag/InMemoryVectorStore.cs b/OrtForge.AI.Agent/Rag/InMemoryVectorStore.cs new file mode 100644 index 0000000..6d5acab --- /dev/null +++ b/OrtForge.AI.Agent/Rag/InMemoryVectorStore.cs @@ -0,0 +1,42 @@ +namespace OrtForge.AI.Agent.Rag; + +public sealed class InMemoryVectorStore +{ + public sealed record Item(string Id, float[] Vector, string Text, IReadOnlyDictionary? Metadata); + + private readonly List _items = []; + + public void Upsert(Item item) + { + var idx = _items.FindIndex(x => x.Id == item.Id); + if (idx >= 0) _items[idx] = item; else _items.Add(item); + } + + public IReadOnlyList TopK(float[] query, int k = 5) + { + var qn = Normalize(query); + return _items + .Select(x => (item: x, score: Cosine(qn, Normalize(x.Vector)))) + .OrderByDescending(x => x.score) + .Take(k) + .Select(x => x.item) + .ToList(); + } + + private static float[] Normalize(float[] v) + { + double s = 0; for (int i = 0; i < v.Length; i++) s += (double)v[i] * v[i]; + var n = Math.Sqrt(Math.Max(s, 1e-9)); + var o = new float[v.Length]; + for (int i = 0; i < v.Length; i++) o[i] = (float)(v[i] / n); + return o; + } + + private static double Cosine(float[] a, float[] b) + { + double s = 0; for (int i = 0; i < a.Length; i++) s += (double)a[i] * b[i]; + return s; + } +} + + diff --git a/OrtForge.AI.Agent/Runtime/OrtRuntimeFactory.cs b/OrtForge.AI.Agent/Runtime/OrtRuntimeFactory.cs new file mode 100644 index 0000000..12c11de --- /dev/null +++ b/OrtForge.AI.Agent/Runtime/OrtRuntimeFactory.cs @@ -0,0 +1,29 @@ +using Microsoft.ML.OnnxRuntime; + +namespace OrtForge.AI.Agent.Runtime; + +public static class OrtRuntimeFactory +{ + private static readonly Lazy s_env = new(OrtEnv.Instance); + + public static OrtEnv Env => s_env.Value; + + public static InferenceSession CreateSession(string modelPath, SessionOptions? options = null) + { + var opts = options ?? CreateDefaultSessionOptions(); + return new InferenceSession(modelPath, opts); + } + + public static SessionOptions CreateDefaultSessionOptions() + { + var so = new SessionOptions(); + so.GraphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL; + so.ExecutionMode = ExecutionMode.ORT_SEQUENTIAL; + so.AppendExecutionProvider_MIGraphX(); + so.AppendExecutionProvider_CPU(); + so.LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING; + return so; + } +} + + diff --git a/OrtForge.AI.Agent/Tokenization/HuggingFaceTokenizerWrapper.cs b/OrtForge.AI.Agent/Tokenization/HuggingFaceTokenizerWrapper.cs new file mode 100644 index 0000000..64aa2db --- /dev/null +++ b/OrtForge.AI.Agent/Tokenization/HuggingFaceTokenizerWrapper.cs @@ -0,0 +1,92 @@ +using System.Buffers; +using Microsoft.ML.Tokenizers; +using EncodedToken = Microsoft.ML.Tokenizers.EncodedToken; + +namespace OrtForge.AI.Agent.Tokenization; + +/// +/// A wrapper that adapts Hugging Face Tokenizers.DotNet to work with Microsoft.ML.Tokenizers interface +/// +public sealed class HuggingFaceTokenizerWrapper : Tokenizer +{ + private readonly Tokenizers.DotNet.Tokenizer _hfTokenizer; + + public HuggingFaceTokenizerWrapper(Tokenizers.DotNet.Tokenizer hfTokenizer) + { + _hfTokenizer = hfTokenizer ?? throw new ArgumentNullException(nameof(hfTokenizer)); + } + + // BOS token ID for Llama 3 models - the tokenizer auto-adds this + private const uint Llama3BosTokenId = 128000; + + //TODO: replace with Span able implementation + protected override EncodeResults EncodeToTokens(string? text, ReadOnlySpan textSpan, + EncodeSettings settings) + { + try + { + uint[] tokenIds; + if (text != null) + { + tokenIds = _hfTokenizer.Encode(text); + } + else + { + tokenIds = _hfTokenizer.Encode(new string(textSpan)); + } + + // Strip the auto-added BOS token if present + // The Tokenizers.DotNet library automatically adds BOS to every encoding, + // which corrupts prompts that already include <|begin_of_text|> + if (tokenIds.Length > 0 && tokenIds[0] == Llama3BosTokenId) + { + tokenIds = tokenIds[1..]; + } + + var encodedTokens = new List(tokenIds.Length); + foreach (var tid in tokenIds) + { + encodedTokens.Add(new EncodedToken((int)tid, string.Empty, default)); + } + + return new EncodeResults + { + CharsConsumed = text?.Length ?? textSpan.Length, + NormalizedText = null, + Tokens = encodedTokens + }; + } + catch (Exception ex) + { + throw new InvalidOperationException($"Failed to encode text: {ex.Message}", ex); + } + } + + //TODO: replace with proper implementation that works with ints + public override OperationStatus Decode(IEnumerable ids, Span destination, out int idsConsumed, + out int charsWritten) + { + try + { + var idArray = ids.Select(x => (uint)x).ToArray(); + var result = _hfTokenizer.Decode(idArray); + if (result.Length > destination.Length) + { + idsConsumed = 0; + charsWritten = 0; + return OperationStatus.DestinationTooSmall; + } + + idsConsumed = idArray.Length; + charsWritten = result.Length; + result.CopyTo(destination); + return OperationStatus.Done; + } + catch + { + idsConsumed = 0; + charsWritten = 0; + return OperationStatus.InvalidData; + } + } +} \ No newline at end of file diff --git a/OrtForge.AI.Agent/Tokenization/TokenizerService.cs b/OrtForge.AI.Agent/Tokenization/TokenizerService.cs new file mode 100644 index 0000000..0f33f0b --- /dev/null +++ b/OrtForge.AI.Agent/Tokenization/TokenizerService.cs @@ -0,0 +1,97 @@ +using Microsoft.ML.Tokenizers; +using HfTokenizer = Tokenizers.DotNet.Tokenizer; + +namespace OrtForge.AI.Agent.Tokenization; + +public sealed class TokenizerService +{ + private readonly Tokenizer _tokenizer; + + public TokenizerService(Tokenizer tokenizer) + { + _tokenizer = tokenizer; + } + + public static TokenizerService FromPretrained(string pathOrDir) + { + if (Directory.Exists(pathOrDir)) + { + var spmPath = Path.Combine(pathOrDir, "sentencepiece.bpe.model"); + using var fs = File.OpenRead(spmPath); + var tk = SentencePieceTokenizer.Create(fs); + return new TokenizerService(tk); + } + else + { + if (pathOrDir.EndsWith(".model", StringComparison.OrdinalIgnoreCase)) + { + using var fs = File.OpenRead(pathOrDir); + var tk = SentencePieceTokenizer.Create(fs); + return new TokenizerService(tk); + } + throw new ArgumentException("Unsupported tokenizer format", nameof(pathOrDir)); + } + } + + /// + /// Creates a TikToken-based tokenizer from a tokenizer.json file. + /// Note: This only works with OpenAI-compatible tokenizer formats, not Hugging Face BPE formats. + /// + public static TokenizerService FromTikToken(string filePath) + { + if (File.Exists(filePath)) + { + using var fs = File.OpenRead(filePath); + var tk = TiktokenTokenizer.Create(fs, null, null); + return new TokenizerService(tk); + } + else + { + throw new ArgumentException("File not found", nameof(filePath)); + } + } + + /// + /// Creates a Hugging Face tokenizer from a tokenizer.json file. + /// This supports BPE, WordPiece, and other Hugging Face tokenizer formats. + /// + public static TokenizerService FromHuggingFace(string tokenizerJsonPath) + { + if (!File.Exists(tokenizerJsonPath)) + { + throw new ArgumentException("Tokenizer file not found", nameof(tokenizerJsonPath)); + } + + try + { + var hfTokenizer = new HfTokenizer(tokenizerJsonPath); + var wrapper = new HuggingFaceTokenizerWrapper(hfTokenizer); + return new TokenizerService(wrapper); + } + catch (Exception ex) + { + throw new InvalidOperationException($"Failed to load Hugging Face tokenizer: {ex.Message}", ex); + } + } + + public int[] EncodeToIds(string text, bool addBos = true) + { + var tokens = _tokenizer.EncodeToTokens(text, out _); + var ids = tokens.Select(t => t.Id).ToArray(); + + // Tokenizer automatically adds BOS (128000). Skip it if not wanted. + if (!addBos && ids.Length > 0 && ids[0] == 128000) + { + return ids.Skip(1).ToArray(); + } + + return ids; + } + + public string DecodeFromIds(IReadOnlyList ids) + { + return _tokenizer.Decode(ids.ToArray()); + } +} + + diff --git a/OrtForge.AI.Agent/Tools/ToolInjectionManager.cs b/OrtForge.AI.Agent/Tools/ToolInjectionManager.cs new file mode 100644 index 0000000..d94203d --- /dev/null +++ b/OrtForge.AI.Agent/Tools/ToolInjectionManager.cs @@ -0,0 +1,174 @@ +using OrtForge.AI.Agent.LLM; +using OrtForge.AI.Agent.Tokenization; +using OrtForge.AI.Agent.Agents; + +namespace OrtForge.AI.Agent.Tools; + +/// +/// Result of a tool injection operation +/// +public record ToolInjectionResult( + bool Success, + string InjectedText, + int[] InjectedTokens, + KvState UpdatedKvState, + int NewSequenceLength, + string? ErrorMessage = null); + +/// +/// Validation result for KV state consistency +/// +public record KvStateValidationResult( + bool IsValid, + IReadOnlyList Issues); + +/// +/// Manages safe tool execution and result injection with KV state validation +/// +public sealed class ToolInjectionManager +{ + private readonly TokenizerService _tokenizer; + + public ToolInjectionManager(TokenizerService tokenizer) + { + _tokenizer = tokenizer ?? throw new ArgumentNullException(nameof(tokenizer)); + } + + /// + /// Execute tool and inject result with comprehensive validation + /// + public async Task ExecuteAndInjectAsync( + ToolCall toolCall, + Func toolExecutor, + ToolCallState toolState, + LlamaSession llamaSession, + KvState currentKvState, + int currentSequenceLength) + { + try + { + var preValidation = ValidateKvState(currentKvState, currentSequenceLength); + if (!preValidation.IsValid) + { + return new ToolInjectionResult( + false, "", [], currentKvState, currentSequenceLength, + $"Pre-injection KV state validation failed: {string.Join(", ", preValidation.Issues)}"); + } + + toolState.UpdateCallStatus(toolCall, ToolCallStatus.Executing); + + string result; + try + { + result = toolExecutor.Invoke(toolCall.Arguments); + toolState.UpdateCallStatus(toolCall, ToolCallStatus.Completed, result); + } + catch (Exception ex) + { + var errorMessage = $"Tool execution failed: {ex.Message}"; + toolState.UpdateCallStatus(toolCall, ToolCallStatus.Failed, error: errorMessage); + result = $"Error: {errorMessage}"; + } + + var injectedText = $"\n<|tool_result|>\n{result}\n<|/tool_result|>\n"; + var injectedTokens = _tokenizer.EncodeToIds(injectedText); + + var newSequenceLength = currentSequenceLength + injectedTokens.Length; + + var kvStateSnapshot = CreateKvStateSnapshot(currentKvState); + + var injectArray = injectedTokens.Select(token => (long)token).ToArray(); + var injectOutputs = await llamaSession.RunOptimizedStepAsync( + injectArray, currentKvState, newSequenceLength); + + var updatedKvState = injectOutputs.KvCache; + var postValidation = ValidateKvState(updatedKvState, newSequenceLength); + + if (!postValidation.IsValid) + { + injectOutputs.Dispose(); + Console.WriteLine("⚠️ Post-injection validation failed, attempting rollback"); + + return new ToolInjectionResult( + false, "", [], kvStateSnapshot, currentSequenceLength, + $"Post-injection KV state validation failed: {string.Join(", ", postValidation.Issues)}"); + } + + injectOutputs.Dispose(); + + Console.WriteLine($"✅ Tool injection successful: {toolCall.Name} → {injectedTokens.Length} tokens injected"); + + return new ToolInjectionResult( + true, injectedText, injectedTokens, updatedKvState, newSequenceLength); + } + catch (Exception ex) + { + Console.WriteLine($"❌ Tool injection failed with exception: {ex.Message}"); + return new ToolInjectionResult( + false, "", [], currentKvState, currentSequenceLength, + $"Tool injection exception: {ex.Message}"); + } + } + + /// + /// Validate KV state consistency and sequence length alignment + /// + public KvStateValidationResult ValidateKvState(KvState kvState, int expectedSequenceLength) + { + var issues = new List(); + + if (kvState.AccumulatedSequenceLength != expectedSequenceLength) + { + issues.Add($"Sequence length mismatch: KvState={kvState.AccumulatedSequenceLength}, Expected={expectedSequenceLength}"); + } + + var tensors = kvState.Tensors; + if (tensors.Count > 0) + { + try + { + foreach (var tensor in tensors) + { + var shape = tensor.Tensor.GetTensorTypeAndShape().Shape; + + if (shape.Length >= 3) // [batch, heads, seq_len, head_dim] + { + var tensorSeqLength = shape[2]; + if (tensorSeqLength != expectedSequenceLength) + { + issues.Add($"Tensor sequence dimension mismatch: tensor={tensorSeqLength}, expected={expectedSequenceLength}"); + } + } + } + } + catch (Exception ex) + { + issues.Add($"Error validating tensor shapes: {ex.Message}"); + } + } + + if (kvState.Tensors.Count == 0 && expectedSequenceLength > 0) + { + issues.Add("KV state has no tensors but sequence length > 0"); + } + + return new KvStateValidationResult(issues.Count == 0, issues); + } + + /// + /// Create a snapshot of KV state for potential rollback + /// Note: This is a reference snapshot - actual rollback would require deep copying + /// + private KvState CreateKvStateSnapshot(KvState originalKvState) + { + return originalKvState; + } + + /// + /// Validate that tool injection point is safe (at token boundary) + /// + public bool IsInjectionPointSafe(int currentStep, bool isGenerationPhase) + { + return isGenerationPhase; + } +} diff --git a/OrtForge.AI.MicroBenchmarks/BgeM3ModelBenchmarks.cs b/OrtForge.AI.MicroBenchmarks/BgeM3ModelBenchmarks.cs index b692584..66ccbd8 100644 --- a/OrtForge.AI.MicroBenchmarks/BgeM3ModelBenchmarks.cs +++ b/OrtForge.AI.MicroBenchmarks/BgeM3ModelBenchmarks.cs @@ -2,7 +2,7 @@ using BenchmarkDotNet.Engines; using Microsoft.ML.OnnxRuntime; using Microsoft.ML.OnnxRuntime.Tensors; -using OrtForge.AI.Models.Astractions; +using OrtForge.AI.Models.Abstractions; using OrtForge.AI.Models.Models; using OrtForge.AI.Models.Options; diff --git a/OrtForge.AI.MicroBenchmarks/BgeM3ModelConcurrentBenchmarks.cs b/OrtForge.AI.MicroBenchmarks/BgeM3ModelConcurrentBenchmarks.cs index 0540676..0aef3fc 100644 --- a/OrtForge.AI.MicroBenchmarks/BgeM3ModelConcurrentBenchmarks.cs +++ b/OrtForge.AI.MicroBenchmarks/BgeM3ModelConcurrentBenchmarks.cs @@ -4,7 +4,7 @@ using BenchmarkDotNet.Engines; using Microsoft.ML.OnnxRuntime; using Microsoft.ML.OnnxRuntime.Tensors; -using OrtForge.AI.Models.Astractions; +using OrtForge.AI.Models.Abstractions; using OrtForge.AI.Models.Models; using OrtForge.AI.Models.Options; diff --git a/OrtForge.AI.MicroBenchmarks/BgeRerankerM3ModelBenchmarks.cs b/OrtForge.AI.MicroBenchmarks/BgeRerankerM3ModelBenchmarks.cs index e04e3dd..69f0b19 100644 --- a/OrtForge.AI.MicroBenchmarks/BgeRerankerM3ModelBenchmarks.cs +++ b/OrtForge.AI.MicroBenchmarks/BgeRerankerM3ModelBenchmarks.cs @@ -2,7 +2,7 @@ using BenchmarkDotNet.Engines; using Microsoft.ML.OnnxRuntime; using Microsoft.ML.OnnxRuntime.Tensors; -using OrtForge.AI.Models.Astractions; +using OrtForge.AI.Models.Abstractions; using OrtForge.AI.Models.Models; using OrtForge.AI.Models.Options; diff --git a/OrtForge.AI.MicroBenchmarks/BgeRerankerM3ModelConcurrentBenchmarks.cs b/OrtForge.AI.MicroBenchmarks/BgeRerankerM3ModelConcurrentBenchmarks.cs index a356dae..7298c89 100644 --- a/OrtForge.AI.MicroBenchmarks/BgeRerankerM3ModelConcurrentBenchmarks.cs +++ b/OrtForge.AI.MicroBenchmarks/BgeRerankerM3ModelConcurrentBenchmarks.cs @@ -4,7 +4,7 @@ using BenchmarkDotNet.Engines; using Microsoft.ML.OnnxRuntime; using Microsoft.ML.OnnxRuntime.Tensors; -using OrtForge.AI.Models.Astractions; +using OrtForge.AI.Models.Abstractions; using OrtForge.AI.Models.Models; using OrtForge.AI.Models.Options; diff --git a/OrtForge.AI.MicroBenchmarks/OrtForge.AI.MicroBenchmarks.csproj b/OrtForge.AI.MicroBenchmarks/OrtForge.AI.MicroBenchmarks.csproj index f404545..ee18f98 100644 --- a/OrtForge.AI.MicroBenchmarks/OrtForge.AI.MicroBenchmarks.csproj +++ b/OrtForge.AI.MicroBenchmarks/OrtForge.AI.MicroBenchmarks.csproj @@ -15,7 +15,7 @@ - + $(DefineConstants);WINDOWS @@ -33,7 +33,7 @@ - + $(DefineConstants);CUDA @@ -41,7 +41,7 @@ - + diff --git a/OrtForge.AI.MicroBenchmarks/Program.cs b/OrtForge.AI.MicroBenchmarks/Program.cs index f7f6de0..976f700 100755 --- a/OrtForge.AI.MicroBenchmarks/Program.cs +++ b/OrtForge.AI.MicroBenchmarks/Program.cs @@ -1,7 +1,4 @@ -using BenchmarkDotNet.Attributes; using BenchmarkDotNet.Configs; -using BenchmarkDotNet.Environments; -using BenchmarkDotNet.Jobs; using BenchmarkDotNet.Running; namespace OrtForge.AI.MicroBenchmarks; diff --git a/OrtForge.AI.MicroBenchmarks/VectorBenchmarks.cs b/OrtForge.AI.MicroBenchmarks/VectorBenchmarks.cs index f11440a..7fe1a91 100755 --- a/OrtForge.AI.MicroBenchmarks/VectorBenchmarks.cs +++ b/OrtForge.AI.MicroBenchmarks/VectorBenchmarks.cs @@ -42,7 +42,7 @@ public float MagnitudeVectorT() { iterations--; } - var magnitude = (float) Math.Sqrt(System.Numerics.Vector.Sum(buffer)); + var magnitude = (float) Math.Sqrt(Vector.Sum(buffer)); return magnitude; } diff --git a/OrtForge.AI.Models.Astractions/BaseModelOptions.cs b/OrtForge.AI.Models.Abstractions/BaseModelOptions.cs similarity index 92% rename from OrtForge.AI.Models.Astractions/BaseModelOptions.cs rename to OrtForge.AI.Models.Abstractions/BaseModelOptions.cs index 71326b5..88d4d99 100644 --- a/OrtForge.AI.Models.Astractions/BaseModelOptions.cs +++ b/OrtForge.AI.Models.Abstractions/BaseModelOptions.cs @@ -1,4 +1,4 @@ -namespace OrtForge.AI.Models.Astractions; +namespace OrtForge.AI.Models.Abstractions; public class BaseModelOptions { diff --git a/OrtForge.AI.Models.Astractions/ExecutionProvider.cs b/OrtForge.AI.Models.Abstractions/ExecutionProvider.cs similarity index 87% rename from OrtForge.AI.Models.Astractions/ExecutionProvider.cs rename to OrtForge.AI.Models.Abstractions/ExecutionProvider.cs index 4664d60..833e97e 100644 --- a/OrtForge.AI.Models.Astractions/ExecutionProvider.cs +++ b/OrtForge.AI.Models.Abstractions/ExecutionProvider.cs @@ -1,4 +1,4 @@ -namespace OrtForge.AI.Models.Astractions; +namespace OrtForge.AI.Models.Abstractions; [Flags] public enum ExecutionProvider diff --git a/OrtForge.AI.Models.Astractions/Extensions/VectorExtensions.cs b/OrtForge.AI.Models.Abstractions/Extensions/VectorExtensions.cs similarity index 96% rename from OrtForge.AI.Models.Astractions/Extensions/VectorExtensions.cs rename to OrtForge.AI.Models.Abstractions/Extensions/VectorExtensions.cs index 7613667..d25fbbf 100755 --- a/OrtForge.AI.Models.Astractions/Extensions/VectorExtensions.cs +++ b/OrtForge.AI.Models.Abstractions/Extensions/VectorExtensions.cs @@ -3,7 +3,7 @@ using System.Runtime.Intrinsics; using Microsoft.ML.OnnxRuntime; -namespace OrtForge.AI.Models.Astractions.Extensions; +namespace OrtForge.AI.Models.Abstractions.Extensions; public static class VectorExtensions { diff --git a/OrtForge.AI.Models.Astractions/ModelHostBase.cs b/OrtForge.AI.Models.Abstractions/ModelHostBase.cs similarity index 98% rename from OrtForge.AI.Models.Astractions/ModelHostBase.cs rename to OrtForge.AI.Models.Abstractions/ModelHostBase.cs index 50f1db7..b7340c5 100644 --- a/OrtForge.AI.Models.Astractions/ModelHostBase.cs +++ b/OrtForge.AI.Models.Abstractions/ModelHostBase.cs @@ -2,9 +2,9 @@ using Microsoft.ML.OnnxRuntime; using Microsoft.ML.OnnxRuntime.Tensors; using Microsoft.ML.Tokenizers; -using OrtForge.AI.Models.Astractions.Extensions; +using OrtForge.AI.Models.Abstractions.Extensions; -namespace OrtForge.AI.Models.Astractions; +namespace OrtForge.AI.Models.Abstractions; public abstract class ModelHostBase : IDisposable { diff --git a/OrtForge.AI.Models.Astractions/ModelInfo.cs b/OrtForge.AI.Models.Abstractions/ModelInfo.cs similarity index 83% rename from OrtForge.AI.Models.Astractions/ModelInfo.cs rename to OrtForge.AI.Models.Abstractions/ModelInfo.cs index 03c338b..fbbb801 100755 --- a/OrtForge.AI.Models.Astractions/ModelInfo.cs +++ b/OrtForge.AI.Models.Abstractions/ModelInfo.cs @@ -1,4 +1,4 @@ -namespace OrtForge.AI.Models.Astractions; +namespace OrtForge.AI.Models.Abstractions; /// /// Model information structure diff --git a/OrtForge.AI.Models.Astractions/OrtForge.AI.Models.Astractions.csproj b/OrtForge.AI.Models.Abstractions/OrtForge.AI.Models.Abstractions.csproj similarity index 87% rename from OrtForge.AI.Models.Astractions/OrtForge.AI.Models.Astractions.csproj rename to OrtForge.AI.Models.Abstractions/OrtForge.AI.Models.Abstractions.csproj index ab64ac2..ea9827a 100644 --- a/OrtForge.AI.Models.Astractions/OrtForge.AI.Models.Astractions.csproj +++ b/OrtForge.AI.Models.Abstractions/OrtForge.AI.Models.Abstractions.csproj @@ -7,8 +7,8 @@ - - + + diff --git a/OrtForge.AI.Models/Models/BgeM3Model.cs b/OrtForge.AI.Models/Models/BgeM3Model.cs index 8497859..795cdc1 100755 --- a/OrtForge.AI.Models/Models/BgeM3Model.cs +++ b/OrtForge.AI.Models/Models/BgeM3Model.cs @@ -1,7 +1,7 @@ using Microsoft.ML.OnnxRuntime; using Microsoft.ML.OnnxRuntime.Tensors; using Microsoft.ML.Tokenizers; -using OrtForge.AI.Models.Astractions; +using OrtForge.AI.Models.Abstractions; using OrtForge.AI.Models.Options; namespace OrtForge.AI.Models.Models; diff --git a/OrtForge.AI.Models/Models/BgeRerankerM3.cs b/OrtForge.AI.Models/Models/BgeRerankerM3.cs index 1e0089f..277d65f 100755 --- a/OrtForge.AI.Models/Models/BgeRerankerM3.cs +++ b/OrtForge.AI.Models/Models/BgeRerankerM3.cs @@ -1,7 +1,7 @@ using Microsoft.ML.OnnxRuntime; using Microsoft.ML.OnnxRuntime.Tensors; using Microsoft.ML.Tokenizers; -using OrtForge.AI.Models.Astractions; +using OrtForge.AI.Models.Abstractions; using OrtForge.AI.Models.Options; namespace OrtForge.AI.Models.Models; diff --git a/OrtForge.AI.Models/Options/BgeM3Options.cs b/OrtForge.AI.Models/Options/BgeM3Options.cs index 27b9b3b..a814157 100644 --- a/OrtForge.AI.Models/Options/BgeM3Options.cs +++ b/OrtForge.AI.Models/Options/BgeM3Options.cs @@ -1,5 +1,5 @@ using Microsoft.ML.OnnxRuntime.Tensors; -using OrtForge.AI.Models.Astractions; +using OrtForge.AI.Models.Abstractions; namespace OrtForge.AI.Models.Options; diff --git a/OrtForge.AI.Models/OrtForge.AI.Models.csproj b/OrtForge.AI.Models/OrtForge.AI.Models.csproj index 636c905..57a0288 100755 --- a/OrtForge.AI.Models/OrtForge.AI.Models.csproj +++ b/OrtForge.AI.Models/OrtForge.AI.Models.csproj @@ -7,7 +7,7 @@ - + diff --git a/OrtForge.AI.Runtime.CUDA/OrtForge.AI.Runtime.CUDA.csproj b/OrtForge.AI.Runtime.CUDA/OrtForge.AI.Runtime.CUDA.csproj index 041a9b0..dd434e4 100644 --- a/OrtForge.AI.Runtime.CUDA/OrtForge.AI.Runtime.CUDA.csproj +++ b/OrtForge.AI.Runtime.CUDA/OrtForge.AI.Runtime.CUDA.csproj @@ -7,7 +7,7 @@ - + diff --git a/OrtForge.AI.Runtime.ROCm/OrtForge.AI.Runtime.ROCm.csproj b/OrtForge.AI.Runtime.MigraphX/OrtForge.AI.Runtime.MigraphX.csproj similarity index 69% rename from OrtForge.AI.Runtime.ROCm/OrtForge.AI.Runtime.ROCm.csproj rename to OrtForge.AI.Runtime.MigraphX/OrtForge.AI.Runtime.MigraphX.csproj index d4031c2..5337f77 100644 --- a/OrtForge.AI.Runtime.ROCm/OrtForge.AI.Runtime.ROCm.csproj +++ b/OrtForge.AI.Runtime.MigraphX/OrtForge.AI.Runtime.MigraphX.csproj @@ -8,8 +8,8 @@ - - + + diff --git a/OrtForge.AI.UnitTests/AgentOrchestratorHelpersTests.cs b/OrtForge.AI.UnitTests/AgentOrchestratorHelpersTests.cs new file mode 100644 index 0000000..396de61 --- /dev/null +++ b/OrtForge.AI.UnitTests/AgentOrchestratorHelpersTests.cs @@ -0,0 +1,83 @@ +using OrtForge.AI.Agent.Agents; +using OrtForge.AI.Agent.Generation; + +namespace OrtForge.AI.UnitTests; + +public class AgentOrchestratorHelpersTests +{ + [Fact] + public void IsStopToken_RecognizesConfiguredTokens() + { + var config = InferenceConfig.Default; + Assert.True(AgentOrchestrator.IsStopToken(2, config)); + Assert.True(AgentOrchestrator.IsStopToken(0, config)); + Assert.False(AgentOrchestrator.IsStopToken(5, config)); + } + + [Fact] + public void IsStopSequence_DetectsConfiguredSequences() + { + var config = new InferenceConfig { StopSequences = ["", "<|end|>"] }; + Assert.True(AgentOrchestrator.IsStopSequence("helloworld", config)); + Assert.True(AgentOrchestrator.IsStopSequence("test<|end|>", config)); + Assert.False(AgentOrchestrator.IsStopSequence("nothing here", config)); + } +} + +public class ToolCallStateTests +{ + [Fact] + public void ToolCallState_DetectsCompleteToolCall() + { + var state = new ToolCallState(); + state.AppendText("TOOL_CALL\nname: test_tool\nargs: test_args\nEND_TOOL_CALL"); + + Assert.True(state.HasPendingCalls); + var call = state.GetNextPendingCall(); + Assert.NotNull(call); + Assert.Equal("test_tool", call.Name); + Assert.Equal("test_args", call.Arguments); + Assert.Equal(ToolCallStatus.Pending, call.Status); + } + + [Fact] + public void ToolCallState_HandlesIncompleteCall() + { + var state = new ToolCallState(); + state.AppendToken("TOOL_CALL"); + state.AppendToken("\nname: "); + state.AppendToken("test"); + + Assert.False(state.HasPendingCalls); + Assert.True(state.InToolCall); + } + + [Fact] + public void ToolCallState_UpdatesCallStatus() + { + var state = new ToolCallState(); + state.AppendText("TOOL_CALL\nname: test\nargs: args\nEND_TOOL_CALL"); + + var call = state.GetNextPendingCall(); + Assert.NotNull(call); + + state.UpdateCallStatus(call, ToolCallStatus.Executing); + Assert.Equal(ToolCallStatus.Executing, state.Calls[0].Status); + + state.UpdateCallStatus(call, ToolCallStatus.Completed, "result"); + Assert.Equal(ToolCallStatus.Completed, state.Calls[0].Status); + Assert.Equal("result", state.Calls[0].Result); + } + + [Fact] + public void ToolCallState_ResetClearsState() + { + var state = new ToolCallState(); + state.AppendText("TOOL_CALL\nname: test\nargs: args\nEND_TOOL_CALL"); + + Assert.True(state.HasPendingCalls); + state.Reset(); + Assert.False(state.HasPendingCalls); + Assert.False(state.InToolCall); + } +} diff --git a/OrtForge.AI.UnitTests/EmbeddingGenerationTests.cs b/OrtForge.AI.UnitTests/EmbeddingGenerationTests.cs index cfd3947..559f6c3 100755 --- a/OrtForge.AI.UnitTests/EmbeddingGenerationTests.cs +++ b/OrtForge.AI.UnitTests/EmbeddingGenerationTests.cs @@ -1,6 +1,6 @@ using System.Numerics.Tensors; using Microsoft.ML.OnnxRuntime.Tensors; -using OrtForge.AI.Models.Astractions; +using OrtForge.AI.Models.Abstractions; using OrtForge.AI.Models.Models; using OrtForge.AI.Models.Options; using Xunit.Abstractions; diff --git a/OrtForge.AI.UnitTests/InMemoryVectorStoreTests.cs b/OrtForge.AI.UnitTests/InMemoryVectorStoreTests.cs new file mode 100644 index 0000000..996fb48 --- /dev/null +++ b/OrtForge.AI.UnitTests/InMemoryVectorStoreTests.cs @@ -0,0 +1,42 @@ +using OrtForge.AI.Agent.Rag; + +namespace OrtForge.AI.UnitTests; + +public class InMemoryVectorStoreTests +{ + [Fact] + public void Upsert_AddsAndReplacesById() + { + var vs = new InMemoryVectorStore(); + vs.Upsert(new InMemoryVectorStore.Item("a", [1, 0], "Doc A", null)); + vs.Upsert(new InMemoryVectorStore.Item("b", [0, 1], "Doc B", null)); + var top = vs.TopK([1, 0], 2); + Assert.Collection(top, + item => Assert.Equal("a", item.Id), + item => Assert.Equal("b", item.Id)); + vs.Upsert(new InMemoryVectorStore.Item("a", [0, 1], "Doc A2", new Dictionary{{"v","2"}})); + top = vs.TopK([1, 0], 2); + Assert.Equal(2, top.Count); + var ids = top.Select(t => t.Id).ToHashSet(); + Assert.Contains("a", ids); + Assert.Contains("b", ids); + var a = top.First(t => t.Id == "a"); + Assert.Equal("Doc A2", a.Text); + Assert.Equal("2", a.Metadata!["v"]); + } + + [Fact] + public void TopK_ReturnsOrderedByCosineSimilarity() + { + var vs = new InMemoryVectorStore(); + vs.Upsert(new InMemoryVectorStore.Item("x", [1, 0], "X", null)); + vs.Upsert(new InMemoryVectorStore.Item("y", [0.7f, 0.7f], "Y", null)); + vs.Upsert(new InMemoryVectorStore.Item("z", [0, 1], "Z", null)); + var query = new float[] {0.9f, 0.1f}; + var top2 = vs.TopK(query, 2); + Assert.Equal("x", top2[0].Id); + Assert.Equal("y", top2[1].Id); + var top3 = vs.TopK(query, 3); + Assert.Equal(new[]{"x","y","z"}, new[]{top3[0].Id, top3[1].Id, top3[2].Id}); + } +} diff --git a/OrtForge.AI.UnitTests/KvStateTests.cs b/OrtForge.AI.UnitTests/KvStateTests.cs new file mode 100644 index 0000000..839d874 --- /dev/null +++ b/OrtForge.AI.UnitTests/KvStateTests.cs @@ -0,0 +1,72 @@ +using OrtForge.AI.Agent.LLM; + +namespace OrtForge.AI.UnitTests; + +public class KvStateTests +{ + [Fact] + public void KvState_Dispose_ClearsTensorsList() + { + // Arrange - create empty KvState (no actual OrtValues needed for this test) + var tensors = new List(); + var kvState = new KvState(tensors, initialSequenceLength: 10); + + // Act + kvState.Dispose(); + + // Assert + Assert.Empty(kvState.Tensors); + } + + [Fact] + public void KvState_CalculateTotalLengthAfterTokens_ReturnsCorrectLength() + { + // Arrange + var kvState = new KvState([], initialSequenceLength: 10); + + // Act + var totalLength = kvState.CalculateTotalLengthAfterTokens(5); + + // Assert + Assert.Equal(15, totalLength); + } + + [Fact] + public void KvState_AccumulatedSequenceLength_IsSetCorrectly() + { + // Arrange & Act + var kvState = new KvState([], initialSequenceLength: 42); + + // Assert + Assert.Equal(42, kvState.AccumulatedSequenceLength); + } + + [Fact] + public void KvState_CalculateTotalLengthAfterTokens_ThrowsForNegativeTokenCount() + { + // Arrange + var kvState = new KvState([], initialSequenceLength: 10); + + // Act & Assert + Assert.Throws(() => kvState.CalculateTotalLengthAfterTokens(-1)); + } + + [Fact] + public void KvState_Constructor_ThrowsForNegativeSequenceLength() + { + // Act & Assert + Assert.Throws(() => new KvState([], initialSequenceLength: -1)); + } + + [Fact] + public void KvState_EmptyState_HasZeroSequenceLength() + { + // Arrange & Act + var kvState = new KvState([]); + + // Assert + Assert.Equal(0, kvState.AccumulatedSequenceLength); + Assert.Empty(kvState.Tensors); + } +} + diff --git a/OrtForge.AI.UnitTests/LlamaOptimizationsTests.cs b/OrtForge.AI.UnitTests/LlamaOptimizationsTests.cs new file mode 100644 index 0000000..5fa0144 --- /dev/null +++ b/OrtForge.AI.UnitTests/LlamaOptimizationsTests.cs @@ -0,0 +1,79 @@ +using OrtForge.AI.Agent.LLM; + +namespace OrtForge.AI.UnitTests; + +public class LlamaOptimizationsTests +{ + [Fact] + public void CreateOptimalPositionIds_InitialPrompt_ReturnsSequentialIds() + { + // Arrange - initial prompt of 5 tokens, total length 5 + var totalSequenceLength = 5; + var newTokenCount = 5; + + // Act + var positionIds = LlamaOptimizations.CreateOptimalPositionIds(totalSequenceLength, newTokenCount); + + // Assert + Assert.Equal(5, positionIds.Length); + Assert.Equal(new long[] { 0, 1, 2, 3, 4 }, positionIds); + } + + [Fact] + public void CreateOptimalPositionIds_SingleNewToken_ReturnsSinglePosition() + { + // Arrange - 10 tokens already, adding 1 more + var totalSequenceLength = 10; + var newTokenCount = 1; + + // Act + var positionIds = LlamaOptimizations.CreateOptimalPositionIds(totalSequenceLength, newTokenCount); + + // Assert - should return single position ID = 9 (the 10th position, 0-indexed) + Assert.Single(positionIds); + Assert.Equal(9, positionIds[0]); + } + + [Fact] + public void CreateOptimalPositionIds_LengthMatchesNewTokenCount() + { + // Arrange - 50 tokens already, adding 10 more (e.g., new prompt) + var totalSequenceLength = 60; + var newTokenCount = 10; + + // Act + var positionIds = LlamaOptimizations.CreateOptimalPositionIds(totalSequenceLength, newTokenCount); + + // Assert - should be positions 50, 51, 52, ... 59 + Assert.Equal(newTokenCount, positionIds.Length); + for (int i = 0; i < newTokenCount; i++) + { + Assert.Equal(50 + i, positionIds[i]); + } + } + + [Fact] + public void CreateOptimalPositionIds_MultipleGenerationSteps_ReturnsCorrectPositions() + { + // Simulate generation steps + // Step 0: Initial prompt of 5 tokens + var step0Ids = LlamaOptimizations.CreateOptimalPositionIds(5, 5); + Assert.Equal(new long[] { 0, 1, 2, 3, 4 }, step0Ids); + + // Step 1: Generate 1 token, total is 6 + var step1Ids = LlamaOptimizations.CreateOptimalPositionIds(6, 1); + Assert.Single(step1Ids); + Assert.Equal(5, step1Ids[0]); + + // Step 2: Generate 1 token, total is 7 + var step2Ids = LlamaOptimizations.CreateOptimalPositionIds(7, 1); + Assert.Single(step2Ids); + Assert.Equal(6, step2Ids[0]); + + // New turn: Add 3-token prompt, total is 10 + var newTurnIds = LlamaOptimizations.CreateOptimalPositionIds(10, 3); + Assert.Equal(3, newTurnIds.Length); + Assert.Equal(new long[] { 7, 8, 9 }, newTurnIds); + } +} + diff --git a/OrtForge.AI.UnitTests/LlamaSessionTests.cs b/OrtForge.AI.UnitTests/LlamaSessionTests.cs new file mode 100644 index 0000000..142002b --- /dev/null +++ b/OrtForge.AI.UnitTests/LlamaSessionTests.cs @@ -0,0 +1,90 @@ +using OrtForge.AI.Agent.LLM; + +namespace OrtForge.AI.UnitTests; + +/// +/// Tests for LlamaSession and related classes. +/// Note: Full integration tests would require actual ONNX models. +/// +public class LlamaSessionTests +{ + [Fact] + public void StepInputs_Create_WithValidInput_ReturnsStepInputs() + { + // Arrange + var inputIds = new long[] { 1, 2, 3, 4, 5 }; + var kvState = new KvState([]); + + // Act + using var stepInputs = LlamaSession.StepInputs.Create(inputIds, kvState); + + // Assert + Assert.NotNull(stepInputs); + Assert.NotNull(stepInputs.InputIds); + } + + [Fact] + public void StepInputs_Create_WithPositionIds_IncludesPositionIds() + { + // Arrange + var inputIds = new long[] { 1, 2, 3 }; + var positionIds = new long[] { 0, 1, 2 }; + var kvState = new KvState([]); + + // Act + using var stepInputs = LlamaSession.StepInputs.Create(inputIds, kvState, positionIds); + + // Assert + Assert.NotNull(stepInputs); + Assert.NotNull(stepInputs.PositionIds); + } + + [Fact] + public void StepInputs_Create_WithAttentionMask_IncludesAttentionMask() + { + // Arrange + var inputIds = new long[] { 1, 2, 3 }; + var attentionMask = new long[] { 1, 1, 1 }; + var kvState = new KvState([]); + + // Act + using var stepInputs = LlamaSession.StepInputs.Create(inputIds, kvState, null, attentionMask); + + // Assert + Assert.NotNull(stepInputs); + Assert.NotNull(stepInputs.AttentionMask); + } + + [Fact] + public void StepInputs_Dispose_DoesNotThrow() + { + // Arrange + var inputIds = new long[] { 1, 2, 3 }; + var positionIds = new long[] { 0, 1, 2 }; + var attentionMask = new long[] { 1, 1, 1 }; + var kvState = new KvState([]); + + // Act + var stepInputs = LlamaSession.StepInputs.Create(inputIds, kvState, positionIds, attentionMask); + + // Assert - dispose should not throw + var exception = Record.Exception(() => stepInputs.Dispose()); + Assert.Null(exception); + } + + [Fact] + public void StepInputs_Create_EmptyInputIds_StillCreatesValidInputs() + { + // Arrange - edge case with single token + var inputIds = new long[] { 42 }; + var kvState = new KvState([]); + + // Act + using var stepInputs = LlamaSession.StepInputs.Create(inputIds, kvState); + + // Assert + Assert.NotNull(stepInputs); + Assert.NotNull(stepInputs.InputIds); + } +} + diff --git a/OrtForge.AI.UnitTests/OrtForge.AI.UnitTests.csproj b/OrtForge.AI.UnitTests/OrtForge.AI.UnitTests.csproj index 98de54e..b6028b0 100755 --- a/OrtForge.AI.UnitTests/OrtForge.AI.UnitTests.csproj +++ b/OrtForge.AI.UnitTests/OrtForge.AI.UnitTests.csproj @@ -17,7 +17,7 @@ - + $(DefineConstants);WINDOWS @@ -35,7 +35,7 @@ - + $(DefineConstants);CUDA @@ -43,7 +43,7 @@ - + @@ -78,7 +78,9 @@ + + diff --git a/OrtForge.AI.UnitTests/RerankerTests.cs b/OrtForge.AI.UnitTests/RerankerTests.cs index e55f4e9..ee844ff 100755 --- a/OrtForge.AI.UnitTests/RerankerTests.cs +++ b/OrtForge.AI.UnitTests/RerankerTests.cs @@ -1,6 +1,5 @@ -using Microsoft.ML.OnnxRuntime; using Microsoft.ML.OnnxRuntime.Tensors; -using OrtForge.AI.Models.Astractions; +using OrtForge.AI.Models.Abstractions; using OrtForge.AI.Models.Models; using OrtForge.AI.Models.Options; using Xunit.Abstractions; diff --git a/OrtForge.AI.UnitTests/SamplingTests.cs b/OrtForge.AI.UnitTests/SamplingTests.cs new file mode 100644 index 0000000..398e357 --- /dev/null +++ b/OrtForge.AI.UnitTests/SamplingTests.cs @@ -0,0 +1,99 @@ +using OrtForge.AI.Agent.Generation; + +namespace OrtForge.AI.UnitTests; + +public class SamplingTests +{ + [Fact] + public void Greedy_SelectsMaxIndex() + { + var logits = new float[] { -1f, 0.5f, 3.2f, 3.19f }; + var idx = Sampling.Greedy(logits); + Assert.Equal(2, idx); + } + + [Fact] + public void Sample_WithGreedyConfig_EqualsGreedy() + { + var logits = new float[] { 0.1f, 2.5f, -0.5f, 1.0f }; + var greedy = Sampling.Greedy(logits); + var config = InferenceConfig.Greedy; + var idx = Sampling.Sample(logits, config, [], new Random(42)); + Assert.Equal(greedy, idx); + } + + [Fact] + public void Sample_TopK_SamplesOnlyFromTopK() + { + var logits = new float[] { 1f, 2f, 3f, 4f, 5f }; + var config = new InferenceConfig { TopK = 3, Temperature = 1.0, Seed = 123 }; + var rng = new Random(123); + for (int t = 0; t < 100; t++) + { + var idx = Sampling.Sample(logits, config, [], rng); + Assert.Contains(idx, new[] { 2, 3, 4 }); + } + } + + [Fact] + public void Sample_LowTemperature_PrefersMax() + { + var logits = new float[] { 1f, 2f, 3f, 4f, 5f }; + var config = new InferenceConfig { TopK = 5, Temperature = 0.01, Seed = 7 }; + int favored = 0; + var rng = new Random(7); + for (int t = 0; t < 50; t++) + { + var idx = Sampling.Sample(logits, config, [], rng); + if (idx == 4) favored++; + } + Assert.True(favored > 40); + } + + [Fact] + public void Sample_WithRepetitionPenalty_ReducesRepeatedTokens() + { + var logits = new float[] { 1f, 2f, 3f, 4f, 5f }; + var previousTokens = new int[] { 4, 4, 4 }; + var config = new InferenceConfig { RepetitionPenalty = 1.2, TopK = 5, Temperature = 0.1, Seed = 42 }; + + var idx = Sampling.Sample(logits, config, previousTokens.ToList(), new Random(42)); + + Assert.NotEqual(4, idx); + } + + [Fact] + public void Sample_WithTopP_LimitsTokenSelection() + { + var logits = new float[] { 1f, 1f, 1f, 10f, 10f }; + var config = new InferenceConfig { TopP = 0.5, Temperature = 1.0, Seed = 123 }; + var rng = new Random(123); + + for (int t = 0; t < 50; t++) + { + var idx = Sampling.Sample(logits, config, [], rng); + Assert.Contains(idx, new[] { 3, 4 }); + } + } + + [Fact] + public void Sample_WithRepetitionPenaltyOfOne_DoesNotModifyLogits() + { + // Arrange - penalty of 1.0 should be a no-op + var logits = new float[] { 1f, 2f, 3f, 4f, 5f }; + var previousTokens = new int[] { 4, 4, 4 }; + + // Config with penalty = 1.0 (should be no-op) + var configWithPenalty = new InferenceConfig { RepetitionPenalty = 1.0, TopK = 5, Temperature = 0.01, Seed = 42 }; + + // Config without penalty + var configWithoutPenalty = new InferenceConfig { RepetitionPenalty = 0.0, TopK = 5, Temperature = 0.01, Seed = 42 }; + + // Both should behave the same - select token 4 (highest logit) + var idxWithPenalty = Sampling.Sample(logits, configWithPenalty, previousTokens.ToList(), new Random(42)); + var idxWithoutPenalty = Sampling.Sample(logits, configWithoutPenalty, previousTokens.ToList(), new Random(42)); + + Assert.Equal(idxWithoutPenalty, idxWithPenalty); + Assert.Equal(4, idxWithPenalty); // Both should select token 4 (highest logit, unpenalized) + } +} diff --git a/OrtForge.AI.UnitTests/SlidingWindowTests.cs b/OrtForge.AI.UnitTests/SlidingWindowTests.cs new file mode 100644 index 0000000..393718d --- /dev/null +++ b/OrtForge.AI.UnitTests/SlidingWindowTests.cs @@ -0,0 +1,103 @@ +using OrtForge.AI.Agent.Generation; + +namespace OrtForge.AI.UnitTests; + +/// +/// Tests for sliding window token history maintained across conversation turns +/// for repetition penalty purposes. +/// +public class SlidingWindowTests +{ + [Fact] + public void TokenHistory_MaintainsAcrossTurns() + { + // Arrange + var history = new TokenHistory(maxSize: 10); + + // Simulate turn 1 + history.AddTokens([1, 2, 3]); + + // Simulate turn 2 + history.AddTokens([4, 5, 6]); + + // Assert - all tokens should be in history + var tokens = history.GetTokens(); + Assert.Equal(6, tokens.Count); + Assert.Contains(1, tokens); + Assert.Contains(6, tokens); + } + + [Fact] + public void TokenHistory_EnforcesMaxSize() + { + // Arrange + var history = new TokenHistory(maxSize: 5); + + // Add more tokens than max size + history.AddTokens([1, 2, 3, 4, 5, 6, 7]); + + // Assert - should only keep last 5 + var tokens = history.GetTokens(); + Assert.Equal(5, tokens.Count); + Assert.DoesNotContain(1, tokens); + Assert.DoesNotContain(2, tokens); + Assert.Contains(7, tokens); + } + + [Fact] + public void TokenHistory_AddToken_UpdatesHistory() + { + // Arrange + var history = new TokenHistory(maxSize: 3); + + // Act + history.AddToken(1); + history.AddToken(2); + history.AddToken(3); + history.AddToken(4); // Should push out 1 + + // Assert + var tokens = history.GetTokens(); + Assert.Equal(3, tokens.Count); + Assert.DoesNotContain(1, tokens); + Assert.Contains(4, tokens); + } + + [Fact] + public void TokenHistory_Clear_ResetsHistory() + { + // Arrange + var history = new TokenHistory(maxSize: 10); + history.AddTokens([1, 2, 3, 4, 5]); + + // Act + history.Clear(); + + // Assert + Assert.Empty(history.GetTokens()); + } + + [Fact] + public void TokenHistory_DefaultMaxSize() + { + // Arrange & Act - default should be reasonable (128) + var history = new TokenHistory(); + + // Assert + Assert.Equal(128, history.MaxSize); + } + + [Fact] + public void TokenHistory_CountReflectsActualTokens() + { + // Arrange + var history = new TokenHistory(maxSize: 100); + + // Act + history.AddTokens([1, 2, 3]); + + // Assert + Assert.Equal(3, history.Count); + } +} + diff --git a/OrtForge.AI.UnitTests/ToolCallStateTests.cs b/OrtForge.AI.UnitTests/ToolCallStateTests.cs new file mode 100644 index 0000000..6d30a9e --- /dev/null +++ b/OrtForge.AI.UnitTests/ToolCallStateTests.cs @@ -0,0 +1,186 @@ +using OrtForge.AI.Agent.Agents; + +namespace OrtForge.AI.UnitTests; + +/// +/// Tests for the TOOL_CALL/END_TOOL_CALL pattern that matches the prompt format +/// +public class ToolCallStateNewPatternTests +{ + [Fact] + public void AppendToken_WithToolCallMarkers_DetectsToolCall() + { + // Arrange + var state = new ToolCallState(); + var text = @"Some text before +TOOL_CALL +name: search +args: {""query"": ""test""} +END_TOOL_CALL +Some text after"; + + // Act + state.AppendText(text); + + // Assert + Assert.Single(state.Calls); + Assert.Equal("search", state.Calls[0].Name); + Assert.Equal(@"{""query"": ""test""}", state.Calls[0].Arguments); + } + + [Fact] + public void AppendToken_StreamingTokens_DetectsToolCall() + { + // Arrange + var state = new ToolCallState(); + var tokens = new[] + { + "TOOL", + "_CALL", + "\n", + "name: ", + "calculator", + "\nargs: ", + "2+2", + "\nEND", + "_TOOL_CALL" + }; + + // Act + foreach (var token in tokens) + { + state.AppendToken(token); + } + + // Assert + Assert.Single(state.Calls); + Assert.Equal("calculator", state.Calls[0].Name); + Assert.Equal("2+2", state.Calls[0].Arguments); + } + + [Fact] + public void AppendToken_PartialMarker_DoesNotDetectUntilComplete() + { + // Arrange + var state = new ToolCallState(); + + // Act - Append partial content + state.AppendText("TOOL_CALL\nname: test\nargs: foo"); + + // Assert - Should be in tool call but not complete + Assert.True(state.InToolCall); + Assert.Empty(state.Calls); // Not complete yet + + // Complete the tool call + state.AppendText("\nEND_TOOL_CALL"); + + // Assert - Now should be detected + Assert.False(state.InToolCall); + Assert.Single(state.Calls); + } + + [Fact] + public void ParseToolCallContent_ValidContent_ReturnsToolCall() + { + // Arrange + var state = new ToolCallState(); + var content = @"TOOL_CALL +name: fetch_data +args: {""url"": ""https://example.com"", ""method"": ""GET""} +END_TOOL_CALL"; + + // Act + state.AppendText(content); + + // Assert + Assert.Single(state.Calls); + var call = state.Calls[0]; + Assert.Equal("fetch_data", call.Name); + Assert.Equal(@"{""url"": ""https://example.com"", ""method"": ""GET""}", call.Arguments); + Assert.Equal(ToolCallStatus.Pending, call.Status); + Assert.NotEmpty(call.Id); + } + + [Fact] + public void AppendToken_MultipleToolCalls_DetectsAll() + { + // Arrange + var state = new ToolCallState(); + var text = @"First tool: +TOOL_CALL +name: tool1 +args: arg1 +END_TOOL_CALL +Between tools +TOOL_CALL +name: tool2 +args: arg2 +END_TOOL_CALL +After tools"; + + // Act + state.AppendText(text); + + // Assert + Assert.Equal(2, state.Calls.Count); + Assert.Equal("tool1", state.Calls[0].Name); + Assert.Equal("tool2", state.Calls[1].Name); + } + + [Fact] + public void AppendToken_NameOnly_NoArgs_ReturnsEmptyArgs() + { + // Arrange + var state = new ToolCallState(); + var text = @"TOOL_CALL +name: no_args_tool +END_TOOL_CALL"; + + // Act + state.AppendText(text); + + // Assert + Assert.Single(state.Calls); + Assert.Equal("no_args_tool", state.Calls[0].Name); + Assert.Equal(string.Empty, state.Calls[0].Arguments); + } + + [Fact] + public void Reset_ClearsAllState() + { + // Arrange + var state = new ToolCallState(); + state.AppendText(@"TOOL_CALL +name: test +args: data +END_TOOL_CALL"); + Assert.Single(state.Calls); + + // Act + state.Reset(); + + // Assert + Assert.Empty(state.Calls); + Assert.False(state.InToolCall); + } + + [Fact] + public void GetNextPendingCall_ReturnsPendingCall() + { + // Arrange + var state = new ToolCallState(); + state.AppendText(@"TOOL_CALL +name: pending_test +args: test +END_TOOL_CALL"); + + // Act + var pending = state.GetNextPendingCall(); + + // Assert + Assert.NotNull(pending); + Assert.Equal("pending_test", pending.Name); + Assert.Equal(ToolCallStatus.Pending, pending.Status); + } +} + diff --git a/OrtForge.sln b/OrtForge.sln index 6471026..3f9a8f9 100755 --- a/OrtForge.sln +++ b/OrtForge.sln @@ -13,9 +13,9 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "docs", "docs", "{63CDC6A4-3 docs\INSTALL_AMD_ROCm.md = docs\INSTALL_AMD_ROCm.md EndProjectSection EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "OrtForge.AI.Models.Astractions", "OrtForge.AI.Models.Astractions\OrtForge.AI.Models.Astractions.csproj", "{40A4313C-6826-4E8D-9A01-DA760DE4CE26}" +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "OrtForge.AI.Models.Abstractions", "OrtForge.AI.Models.Abstractions\OrtForge.AI.Models.Abstractions.csproj", "{40A4313C-6826-4E8D-9A01-DA760DE4CE26}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "OrtForge.AI.Runtime.ROCm", "OrtForge.AI.Runtime.ROCm\OrtForge.AI.Runtime.ROCm.csproj", "{8FF1CB84-3A1F-425A-8E9D-45EF01092236}" +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "OrtForge.AI.Runtime.MigraphX", "OrtForge.AI.Runtime.MigraphX\OrtForge.AI.Runtime.MigraphX.csproj", "{8FF1CB84-3A1F-425A-8E9D-45EF01092236}" EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Files", "Solution Files", "{2683178C-EFDD-4951-B0C4-EE84EF8AFD9C}" ProjectSection(SolutionItems) = preProject @@ -37,6 +37,10 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "scripts", "scripts", "{9854 run_benchmarks_ROCm.sh = run_benchmarks_ROCm.sh EndProjectSection EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "OrtForge.AI.Agent", "OrtForge.AI.Agent\OrtForge.AI.Agent.csproj", "{F9138501-F841-4BFC-9336-C54B75F5AB7D}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "OrtForge.AI.Agent.TestApp", "OrtForge.AI.Agent.TestApp\OrtForge.AI.Agent.TestApp.csproj", "{46B86EBA-7720-43D3-B2ED-FEAAAF85AF07}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -67,9 +71,17 @@ Global {8FF1CB84-3A1F-425A-8E9D-45EF01092236}.Debug|Any CPU.Build.0 = Debug|Any CPU {8FF1CB84-3A1F-425A-8E9D-45EF01092236}.Release|Any CPU.ActiveCfg = Release|Any CPU {8FF1CB84-3A1F-425A-8E9D-45EF01092236}.Release|Any CPU.Build.0 = Release|Any CPU + {F9138501-F841-4BFC-9336-C54B75F5AB7D}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {F9138501-F841-4BFC-9336-C54B75F5AB7D}.Debug|Any CPU.Build.0 = Debug|Any CPU + {F9138501-F841-4BFC-9336-C54B75F5AB7D}.Release|Any CPU.ActiveCfg = Release|Any CPU + {F9138501-F841-4BFC-9336-C54B75F5AB7D}.Release|Any CPU.Build.0 = Release|Any CPU + {46B86EBA-7720-43D3-B2ED-FEAAAF85AF07}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {46B86EBA-7720-43D3-B2ED-FEAAAF85AF07}.Debug|Any CPU.Build.0 = Debug|Any CPU + {46B86EBA-7720-43D3-B2ED-FEAAAF85AF07}.Release|Any CPU.ActiveCfg = Release|Any CPU + {46B86EBA-7720-43D3-B2ED-FEAAAF85AF07}.Release|Any CPU.Build.0 = Release|Any CPU {EA1C56B3-FF6C-4605-BBDB-17CA16E22CDC}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {EA1C56B3-FF6C-4605-BBDB-17CA16E22CDC}.Debug|Any CPU.Build.0 = Debug|Any CPU - {EA1C56B3-FF6C-4605-BBDB-17CA16E22CDC}.Release|Any CPU.ActiveCfg = Release|Any CPU - {EA1C56B3-FF6C-4605-BBDB-17CA16E22CDC}.Release|Any CPU.Build.0 = Release|Any CPU + {EA1C56B3-FF6C-4605-BBDB-17CA16E22CDC}.Debug|Any CPU.Build.0 = Debug|Any CPU + {EA1C56B3-FF6C-4605-BBDB-17CA16E22CDC}.Release|Any CPU.ActiveCfg = Release|Any CPU + {EA1C56B3-FF6C-4605-BBDB-17CA16E22CDC}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection EndGlobal diff --git a/docs/AGENTIC_CALL_FLOW.md b/docs/AGENTIC_CALL_FLOW.md new file mode 100644 index 0000000..6ae6a11 --- /dev/null +++ b/docs/AGENTIC_CALL_FLOW.md @@ -0,0 +1,216 @@ +### ONNX Runtime GenAI: C# to Native Call Flow + +This document diagrams the high-level flow from the C# API down to the native layers, stopping at ONNX Runtime. File and symbol names are shown for orientation. + +Key C# entry points: +- `Model` (loads model/config) +- `Generator`, `GeneratorParams` (token generation) +- `Tokenizer`, `TokenizerStream` (text <-> tokens) +- `MultiModalProcessor` (image/audio preprocessing) + +Native boundaries: +- P/Invoke to `onnxruntime-genai` via `src/csharp/NativeMethods.cs` +- C API in `src/ort_genai_c.h` implemented by `src/ort_genai_c.cpp` +- C++ implementation in `src/models/*.cpp`, `src/generators.cpp`, etc. +- ONNX Runtime boundary: `OrtSession::Create`, `OrtSession::Run`, allocators in `src/models/onnxruntime_api.h`, `src/models/model.cpp` + +--- + +### Component Map (C# → P/Invoke → C API → C++ → ONNX Runtime) + +```mermaid +flowchart LR + subgraph CSharp["C# (Microsoft.ML.OnnxRuntimeGenAI)"] + CS_Model["Model"] + CS_Gen["Generator / GeneratorParams"] + CS_Tok["Tokenizer / TokenizerStream"] + CS_MMP["MultiModalProcessor"] + end + + subgraph PInvoke["P/Invoke (src/csharp/NativeMethods.cs)"] + PINV["[DllImport('onnxruntime-genai')] Oga* functions"] + end + + subgraph CAPI["C API (src/ort_genai_c.h/.cpp)"] + C_OgaCreateModel["OgaCreateModel"] + C_OgaCreateGenerator["OgaCreateGenerator"] + C_OgaTokenizer["OgaCreateTokenizer / OgaTokenizer*"] + C_OgaProcessor["OgaCreateMultiModalProcessor / OgaProcessor*"] + C_OgaGenOps["OgaGenerator_* (AppendTokens/GenerateNextToken/GetLogits)"] + end + + subgraph CPP["C++ Impl (namespace Generators)"] + CPP_Model["Model (src/models/model.cpp)"] + CPP_Gen["Generator (src/generators.cpp)"] + CPP_Tok["Tokenizer"] + CPP_Proc["MultiModalProcessor"] + end + + subgraph ORT["ONNX Runtime Boundary"] + ORT_Session["OrtSession::Create / Run"] + ORT_Allocs["Ort::Allocator, OrtMemoryInfo"] + end + + CS_Model --> PINV --> C_OgaCreateModel --> CPP_Model --> ORT_Session + CS_Gen --> PINV --> C_OgaCreateGenerator --> CPP_Gen --> ORT_Session + CS_Tok --> PINV --> C_OgaTokenizer --> CPP_Tok + CS_MMP --> PINV --> C_OgaProcessor --> CPP_Proc + CS_Gen -. runtime ops .-> PINV -.-> C_OgaGenOps -.-> CPP_Gen -.-> ORT_Allocs +``` + +--- + +### Model Construction Flow + +Relevant code: +- `src/csharp/Model.cs` → `NativeMethods.OgaCreateModel` +- `src/ort_genai_c.cpp: OgaCreateModel` → `OgaCreateModelWithRuntimeSettings` +- `Generators::CreateModel` → `Model::CreateSessionOptions`, `Model::CreateSession`, `OrtSession::Create` + +```mermaid +sequenceDiagram + autonumber + participant C# as C# Model + participant P as P/Invoke OgaCreateModel + participant C as C API (ort_genai_c.cpp) + participant CPP as Generators::Model + participant ORT as ONNX Runtime + + C#->>P: OgaCreateModel(configPath) + P->>C: OgaCreateModel + C->>C: OgaCreateModelWithRuntimeSettings(...) + C->>CPP: Generators::CreateModel(GetOrtEnv(), ...) + CPP->>CPP: CreateSessionOptionsFromConfig(...) + CPP->>ORT: OrtSession::Create(...) + ORT-->>CPP: OrtSession* + CPP-->>C: shared OgaModel + C-->>C#: IntPtr model handle +``` + +--- + +### Generation Loop Flow + +Relevant code: +- `src/csharp/Generator.cs` → `NativeMethods.OgaCreateGenerator` +- `src/ort_genai_c.cpp: OgaCreateGenerator`, `OgaGenerator_*` +- `Generators::Generator::GenerateNextToken`, `Model::Run` +- ORT calls: `OrtSession::Run` + +```mermaid +sequenceDiagram + autonumber + participant C# as C# Generator + participant P as P/Invoke Oga* + participant C as C API (ort_genai_c.cpp) + participant CPP as Generators::Generator/Model + participant ORT as ONNX Runtime + + C#->>P: OgaCreateGenerator(model, params) + P->>C: OgaCreateGenerator + C->>CPP: CreateGenerator(model, params) + CPP-->>C#: IntPtr generator handle + + loop per step + C#->>P: OgaGenerator_AppendTokens / _SetInputs + P->>C: OgaGenerator_* + C->>CPP: generator->AppendTokens / SetInputs + C#->>P: OgaGenerator_GenerateNextToken + P->>C: OgaGenerator_GenerateNextToken + C->>CPP: generator->GenerateNextToken() + CPP->>CPP: model->Run(...) + CPP->>ORT: OrtSession::Run(inputs, outputs) + ORT-->>CPP: logits/output OrtValues + CPP-->>C: expose logits/next tokens via accessors + C#->>P: OgaGenerator_GetNextTokens / _GetLogits + P->>C: OgaGenerator_* getters + C-->>C#: tokens/logits (CPU memory) + end +``` + +--- + +### Tokenizer Encode/Decode Flow + +Relevant code: +- `src/csharp/Tokenizer.cs` → `OgaCreateTokenizer`, `OgaTokenizerEncode`, `OgaTokenizerDecode` +- `src/ort_genai_c.cpp: OgaCreateTokenizer`, `OgaTokenizer*` +- `Generators::Tokenizer` + +```mermaid +sequenceDiagram + autonumber + participant C# as C# Tokenizer + participant P as P/Invoke OgaTokenizer* + participant C as C API (ort_genai_c.cpp) + participant CPP as Generators::Tokenizer + + C#->>P: OgaCreateTokenizer(model) + P->>C: OgaCreateTokenizer + C->>CPP: model->CreateTokenizer() + CPP-->>C#: IntPtr tokenizer handle + + C#->>P: OgaTokenizerEncode(str) + P->>C: OgaTokenizerEncode + C->>CPP: tokenizer->Encode(str) + CPP-->>C#: token ids + + C#->>P: OgaTokenizerDecode(tokens) + P->>C: OgaTokenizerDecode + C->>CPP: tokenizer->Decode(tokens) + CPP-->>C#: string +``` + +--- + +### MultiModal Processor (Images/Audio → NamedTensors) + +Relevant code: +- `src/csharp/MultiModalProcessor.cs` → `OgaCreateMultiModalProcessor`, `OgaProcessorProcess*` +- `src/ort_genai_c.cpp: OgaCreateMultiModalProcessor`, `OgaProcessorProcess*` +- `Generators::MultiModalProcessor` + +```mermaid +sequenceDiagram + autonumber + participant C# as C# MultiModalProcessor + participant P as P/Invoke OgaProcessor* + participant C as C API (ort_genai_c.cpp) + participant CPP as Generators::MultiModalProcessor + + C#->>P: OgaCreateMultiModalProcessor(model) + P->>C: OgaCreateMultiModalProcessor + C->>CPP: model->CreateMultiModalProcessor() + CPP-->>C#: IntPtr processor handle + + C#->>P: OgaProcessorProcessImages(prompt, images) + P->>C: OgaProcessorProcessImages + C->>CPP: processor->Process(...) + CPP-->>C#: NamedTensors + C-->>C#: IntPtr named tensors +``` + +--- + +### Error Handling (Result pattern) + +Errors from native calls surface via `OgaResult`: +- C# wrappers call `Result.VerifySuccess(NativeMethods.Oga...(...))` +- C API returns `OgaResult*` on failure; message via `OgaResultGetError` +- Typical C entry: `OGA_TRY`/`OGA_CATCH` in `src/ort_genai_c.cpp` + +```mermaid +flowchart LR + CAPI["C API call"] -->|throw std::exception| CATCH["OGA_CATCH → make OgaResult(error)"] + CATCH --> CS["C# Result.VerifySuccess → throw OnnxRuntimeGenAIException(message)"] +``` + +--- + +### Stopping Boundary + +These diagrams stop at ONNX Runtime calls within the native layer: +- `OrtSession::Create` and `OrtSession::Run` in `src/models/model.cpp` +- Allocators and device interfaces in `src/models/onnxruntime_api.h` + + diff --git a/docs/INSTALL_AMD_ROCm.md b/docs/INSTALL_AMD_ROCm.md index 6775b57..5143b91 100644 --- a/docs/INSTALL_AMD_ROCm.md +++ b/docs/INSTALL_AMD_ROCm.md @@ -41,13 +41,14 @@ Considering the above, choose your targets from the beginning. I recommend build Clone repo ```bash git clone --recursive https://github.com/ROCm/onnxruntime.git -git checkout tags/v1.22.1 cd onnxruntime +git checkout tags/v1.22.1 ``` Build for .NET only to run models ```bash -./build.sh --update --build --config Release --build_nuget --parallel --use_rocm --rocm_home /opt/rocm --skip_tests +./build.sh --update --config Release --build_nuget --parallel --use_migraphx --migraphx_home /opt/rocm --skip_tests +./build.sh --build --config Release --build_nuget --parallel --use_migraphx --migraphx_home /opt/rocm --skip_tests ``` Build for .NET and for Python stack with PyTorch and any other toolset that may utilize GPU accelerators on AMD @@ -58,7 +59,8 @@ source ./bin/activate pip install 'cmake>=3.28,<4' pip install -r requirements.txt pip install setuptools -./build.sh --update --build --config Release --build_wheel --build_nuget --parallel --use_rocm --rocm_home /opt/rocm --skip_tests +./build.sh --update --config Release --build_wheel --build_nuget --parallel --use_migraphx --migraphx_home /opt/rocm --skip_tests +./build.sh --build --config Release --build_wheel --build_nuget --parallel --use_migraphx --migraphx_home /opt/rocm --skip_tests ``` Install wheel for python to use in the venv diff --git a/docs/ONNX_AGENT_ALGORITHM.md b/docs/ONNX_AGENT_ALGORITHM.md new file mode 100644 index 0000000..ca969a4 --- /dev/null +++ b/docs/ONNX_AGENT_ALGORITHM.md @@ -0,0 +1,229 @@ +### Agent Chat on Pure ONNX Runtime: Top-Down Algorithm + +Goal: Outline how to implement an agentic chat loop using only ONNX Runtime (ORT) sessions for all model inference (LLM generation, embeddings, reranking), plus ordinary host code for memory, tools, and control flow. + +Assumptions: +- Pre/post-processing (tokenization, detokenization, tool I/O marshalling) is implemented in host code. +- All neural inference is done via ORT `OrtSession::Run` on ONNX models: LLM, embedding model, reranker or tool classifier (optional), vision/audio encoders (optional). + +--- + +### Components (Top-Down) + +```mermaid +flowchart TB + subgraph App["Application / Chat Service"] + UI["Chat UI / HTTP API"] + Orchestrator["Conversation Orchestrator (host code)"] + end + + subgraph Memory["Memory"] + ConvLog["Conversation Store (structured logs)"] + VecStore["Vector Index (ANN)"] + end + + subgraph Tools["Tools (host code)"] + T1["HTTP/DB/FS APIs"] + TAdapters["Tool Adapters (schema <-> JSON)"] + end + + subgraph ORT["ONNX Runtime Inference"] + LLM["LLM Session (Decoder/Seq2Seq)"] + Embed["Embedding Session (text/dual)\nfor retrieval/memory"] + Rerank["Reranker/Classifier (optional)"] + Vision["Vision/Audio Encoders (optional)"] + end + + UI --> Orchestrator + Orchestrator <---> ConvLog + Orchestrator <---> VecStore + Orchestrator -.-> TAdapters -.-> T1 + Orchestrator --> Embed + Orchestrator --> Rerank + Orchestrator --> Vision + Orchestrator <--> LLM +``` + +--- + +### One Chat Turn (with Tools and Memory) + +```mermaid +sequenceDiagram + autonumber + participant C as Client/UI + participant O as Orchestrator (host) + participant MEM as Memory (ConvLog/VecStore) + participant EMB as ORT Embedding Session + participant L as ORT LLM Session + participant T as Tools (Adapters -> Tool Impl) + + C->>O: send user message + O->>MEM: fetch recent convo turns + O->>EMB: Run() to embed user query + MEM-->>O: retrieve top-k docs via ANN + O->>L: Build prompt+context -> token IDs -> Run(step): logits → token + note right of L: Streaming loop: step-wise Run() + decode + + alt model suggests tool call (via structured output or function tokens) + O->>T: parse tool args -> call tool + T-->>O: tool result (JSON/text) + O->>L: Append tool result to context -> continue Run(step) + else + O-->>C: stream tokens as assistant reply + end + + O->>EMB: Run() to embed chunks of final answer (optional) + O->>MEM: write convo turn + tool results, update VecStore with embeddings + O-->>C: done +``` + +--- + +### ORT Usage: Sessions and Runs + +```mermaid +flowchart LR + subgraph Setup["Initialization (once per process)"] + Env["Create OrtEnv"] + Opts["Create OrtSessionOptions (EPs, threads, graph opts)"] + SLLM["OrtSession::Create(LLM.onnx, Opts)"] + SEmb["OrtSession::Create(Embedding.onnx, Opts)"] + SRerank["OrtSession::Create(Reranker.onnx, Opts)"] + SVision["OrtSession::Create(Encoders.onnx, Opts)"] + end + subgraph Turn["Per-turn Inference"] + Prep["Prepare inputs: token IDs, kv-cache, masks"] + RunStep["OrtSession::Run(inputs)-> logits"] + Sample["Sampling (host): greedy/top-k/top-p"] + Update["Append next token; update kv-cache"] + end + + Env --> Opts --> SLLM + Opts --> SEmb --> SRerank --> SVision + SLLM --> RunStep --> Sample --> Update --> RunStep +``` + +Inputs/Outputs (typical): +- LLM inputs: `input_ids`, `position_ids`, `attention_mask`, `past_key_values` (kv-cache tensors per layer) +- LLM outputs: `logits` (and updated `present_key_values`) +- Embedding inputs: tokenized text; outputs: dense vector(s) + +--- + +### Generation Loop (Step-wise Decoding with ORT) + +```mermaid +sequenceDiagram + autonumber + participant Host as Host Code + participant LLM as ORT LLM Session + + Host->>Host: tokenize(prompt+context) -> input_ids + Host->>LLM: Run({input_ids, masks, kv_cache=None}) + LLM-->>Host: logits, kv_cache + loop until stop + Host->>Host: sample next token from logits + Host->>Host: append to input, update attention_mask + Host->>LLM: Run({next_token, masks, kv_cache}) + LLM-->>Host: logits, kv_cache + Host->>Host: stream decoded token (optional) + alt stop token or max tokens + Host-->>Host: break + end + end +``` + +Sampling is host-implemented (no ORT call): greedy, top-k/top-p, temperature, repetition penalty, etc. KV-cache routing is model-dependent; with ORT you pass and receive the cache tensors each step. + +--- + +### Tool Use Decision Paths (Options) + +```mermaid +flowchart TB + A["LLM emits JSON/function-call tokens"] -->|Parse| B["Extract tool name + args"] + A2["Classifier/Reranker (ORT) \n decides tool vs answer"] --> B + B --> C["Execute tool (host)"] --> D["Summarize result"] + D --> E["Append to context and continue generation via LLM Run()"] +``` + +Implementation choices: +- Structured output via constrained decoding (enforce a JSON schema at sampling time, host-side) +- Separate ORT classifier to decide if a tool call is needed + +--- + +### Retrieval-Augmented Generation (RAG) with ORT + +```mermaid +sequenceDiagram + autonumber + participant O as Orchestrator + participant EMB as ORT Embedding Session + participant V as Vector Index (ANN) + participant L as ORT LLM Session + + O->>EMB: Run() embed(user query) + EMB-->>O: query vector + O->>V: ANN top-k search + V-->>O: docs/passages + O->>O: construct prompt with citations + O->>L: Run() step-wise generation + L-->>O: answer tokens +``` + +Write-back: +- Optionally embed user message and assistant answer with `EMB.Run()` and upsert to `V` for long-term memory. + +--- + +### Memory Write-Back and Summarization + +```mermaid +flowchart LR + A["Turn transcript"] --> B["Summarize (LLM Run or rules)"] --> C["Chunk & Embed (EMB Run)"] --> D["Upsert to VecStore"] + A --> E["Store raw turn in ConvLog"] +``` + +--- + +### Minimal Pseudocode (Host) + +```text +initialize OrtEnv +create sessions: llm_sess, emb_sess, (optional) rerank_sess, vision_sess + +for each chat turn: + convo_ctx = memory.fetch_recent() + retrieved = retrieve_with_embeddings(emb_sess, user_msg) + prompt = format_prompt(convo_ctx, retrieved, user_msg) + tokens, kv = tokenize(prompt), None + + while not stop: + logits, kv = llm_sess.Run(inputs(tokens.last, kv, masks)) + next_token = sample(logits) + stream(next_token) + if is_function_token(next_token): + call = parse_function(tokens) + tool_result = execute_tool(call) + tokens += tokenize(format_tool_result(tool_result)) + if stopping_condition(tokens): break + + answer = detokenize(tokens.new_segment) + memory.write_back(user_msg, answer, tool_results) + if long_term: + emb = emb_sess.Run(tokenize(answer)) + vecstore.upsert(emb, metadata) +``` + +--- + +### Notes and Tips +- Manage kv-cache tensors explicitly per model; shape/layout are model-architecture specific. +- For streaming, run step-wise decoding and surface decoded tokens as they arrive. +- Control sampling determinism by fixing seed and using greedy/beam search. +- For multi-modal inputs, run encoder sessions (vision/audio) with ORT to produce embeddings/features, then feed into the LLM session. +- For throughput, batch multiple conversations if model supports batching; maintain separate kv-cache per sequence. + + diff --git a/docs/RDNA3_GPU_COMPATIBILITY.md b/docs/RDNA3_GPU_COMPATIBILITY.md new file mode 100644 index 0000000..f4c0d19 --- /dev/null +++ b/docs/RDNA3_GPU_COMPATIBILITY.md @@ -0,0 +1,106 @@ +# RDNA3 GPU Compatibility Guide + +## Problem Overview + +The `GroupQueryAttention` operator in ONNX Runtime ROCm is optimized specifically for AMD's CDNA2 and CDNA3 data center architectures (MI250X, MI300 series). Consumer RDNA3 GPUs like the **RX 7900 XTX** are not supported by this operator, resulting in the following errors: + +``` +GroupQueryAttention currently only supports ck_tile fmha backend which only supports CDNA2 and CDNA3 archs. +GroupQueryAttention running on an unsuppoted GPU may result in hipErrorNoBinaryForGpu or hipErrorSharedObjectInitFailedshared error. +``` + +## Architecture Differences + +- **RDNA3**: Consumer gaming GPU architecture (RX 7900 XTX, RX 7800 XT, etc.) +- **CDNA2/CDNA3**: Data center compute architectures (MI250X, MI300 series, etc.) + +## Solutions + +### Option 1: Environment Variable Override (Recommended) + +Try this first - it tricks ROCm into thinking your RDNA3 GPU is a CDNA3 GPU: + +```bash +export HSA_OVERRIDE_GFX_VERSION=10.3.0 +# Then run your application +``` + +### Option 2: Use RDNA3 Compatible Mode (Built-in) + +The OrtRuntimeFactory now includes RDNA3 compatibility mode by default. You can also explicitly choose: + +```csharp +// Explicit RDNA3 compatibility mode +var session = OrtRuntimeFactory.CreateSession(modelPath, GpuCompatibilityMode.RDNA3Compatible); + +// Or CPU-only for maximum compatibility +var session = OrtRuntimeFactory.CreateSession(modelPath, GpuCompatibilityMode.CpuOnly); + +// Or standard mode for CDNA2/CDNA3 GPUs +var session = OrtRuntimeFactory.CreateSession(modelPath, GpuCompatibilityMode.Standard); +``` + +### Option 3: Limit ROCm Visibility + +If you have multiple GPUs and some are unsupported: + +```bash +export HIP_VISIBLE_DEVICES=0 # Only use first GPU +export ROCR_VISIBLE_DEVICES="0,GPU-your-gpu-uuid" +``` + +## Performance Expectations + +| Mode | GPU Usage | CPU Usage | Performance | Compatibility | +|------|-----------|-----------|-------------|---------------| +| Standard | Full | Fallback | Best | CDNA2/3 only | +| RDNA3Compatible | Partial | Fallback | Good | RDNA3 + CDNA | +| CpuOnly | None | Full | Slower | Universal | + +## Compatibility Settings Explained + +### RDNA3Compatible Mode +- Uses `GraphOptimizationLevel.ORT_ENABLE_BASIC` to avoid problematic operator fusions +- Maintains ROCm + CPU execution provider setup for automatic fallback +- Allows unsupported operators (like GroupQueryAttention) to fall back to CPU +- Maintains GPU acceleration for supported operations + +### What Runs Where +- **GPU (ROCm)**: Matrix operations, embeddings, most computations +- **CPU (Fallback)**: GroupQueryAttention operators, unsupported ops +- **Hybrid**: Tensors automatically transferred between devices + +## Troubleshooting + +### If you still get errors: +1. Verify ROCm installation: `rocminfo` +2. Check GPU visibility: `echo $HIP_VISIBLE_DEVICES` +3. Try CPU-only mode for testing +4. Enable ONNX Runtime logging for detailed operator placement + +### Performance Optimization +- Use Float16 models when possible (faster on GPU) +- Monitor GPU utilization: `rocm-smi` +- Consider batch size adjustments for RDNA3 + +## Model Compatibility + +| Model Type | RDNA3 Compatibility | Notes | +|------------|-------------------|-------| +| Llama 3.2 | ✅ Good | Uses GQA, benefits from hybrid execution | +| Llama 3.1 | ✅ Good | Uses GQA, benefits from hybrid execution | +| BGE-M3 | ✅ Excellent | No GQA operators | +| Reranker | ✅ Excellent | No GQA operators | + +## Future Improvements + +AMD is working on broader RDNA support in ROCm. Monitor these repositories: +- [ROCm ONNX Runtime](https://github.com/ROCm/onnxruntime) +- [Composable Kernels](https://github.com/ROCm/composable_kernel) + +## Getting Help + +If you continue experiencing issues: +1. Check ROCm version compatibility +2. Verify your ONNX model doesn't require CDNA-specific features +3. Consider using models exported specifically for RDNA3 diff --git a/models/01_export_model.sh b/models/01_export_model.sh new file mode 100755 index 0000000..2709fb1 --- /dev/null +++ b/models/01_export_model.sh @@ -0,0 +1,100 @@ +#!/bin/bash +# ============================================================================= +# 01_export_model.sh - Export HuggingFace model to ONNX for Inference +# ============================================================================= +# Usage: ./01_export_model.sh [options] +# +# Custom ONNX export with KV cache support using modern torch.export. +# Does NOT require optimum library. +# +# Options: +# --opset ONNX opset version (default: 21) +# --batch Batch size (default: 1) +# --no-kv-cache Disable KV cache (not recommended for inference) +# --fp32 Export in FP32 instead of FP16 +# --help Show this help +# +# Defaults optimized for LLM inference: +# - KV cache: ENABLED (essential for efficient autoregressive generation) +# - Precision: FP16 (faster, lower memory) +# - Shapes: Dynamic (any batch/sequence length) +# +# Requirements: +# pip install torch transformers onnx +# +# Examples: +# ./01_export_model.sh ./Llama3.1-8B-Instruct/hf ./onnx +# ./01_export_model.sh ./model/hf ./onnx --opset 21 +# ============================================================================= + +set -e + +# ============================================================================= +# Parse arguments - DEFAULTS OPTIMIZED FOR INFERENCE +# ============================================================================= +POSITIONAL=() +OPSET_VERSION="21" +BATCH_SIZE=1 +WITH_KV_CACHE=true +USE_FP16=true + +while [[ $# -gt 0 ]]; do + case $1 in + --opset) + OPSET_VERSION="$2" + shift 2 + ;; + --batch) + BATCH_SIZE="$2" + shift 2 + ;; + --no-kv-cache) + WITH_KV_CACHE=false + shift + ;; + --fp32) + USE_FP16=false + shift + ;; + --help|-h) + head -30 "$0" | tail -27 + exit 0 + ;; + -*) + echo "Unknown option: $1" + exit 1 + ;; + *) + POSITIONAL+=("$1") + shift + ;; + esac +done + +set -- "${POSITIONAL[@]}" + +MODEL_PATH="${1:?Usage: $0 [options]}" +OUTPUT_DIR="${2:?Usage: $0 [options]}" + +echo "==============================================" +echo "ONNX Model Export (Modern torch.export)" +echo "==============================================" +echo "Model path: $MODEL_PATH" +echo "Output dir: $OUTPUT_DIR" +echo "Opset version: $OPSET_VERSION" +echo "Precision: $([ "$USE_FP16" = true ] && echo 'FP16' || echo 'FP32')" +echo "KV cache: $([ "$WITH_KV_CACHE" = true ] && echo 'ENABLED ✓' || echo 'disabled')" +echo "==============================================" + +mkdir -p "$OUTPUT_DIR" + +# Export variables for Python +export MODEL_PATH OUTPUT_DIR OPSET_VERSION USE_FP16 WITH_KV_CACHE + +# Run Python script +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +python3 "$SCRIPT_DIR/py/export_model.py" + +echo "" +echo "Output files:" +ls -lh "$OUTPUT_DIR" diff --git a/models/02_fix_external_data.sh b/models/02_fix_external_data.sh new file mode 100755 index 0000000..188bf06 --- /dev/null +++ b/models/02_fix_external_data.sh @@ -0,0 +1,45 @@ +#!/bin/bash +# ============================================================================= +# 02_fix_external_data.sh - Convert large ONNX model to use external data file +# ============================================================================= +# Required for models > 2GB due to protobuf limits +# Usage: ./02_fix_external_data.sh +# Example: ./02_fix_external_data.sh ./Llama3.1-8B-Instruct/onnx/model.onnx +# ============================================================================= + +set -e + +MODEL_FILE="${1:?Usage: $0 }" + +if [ ! -f "$MODEL_FILE" ]; then + echo "Error: File not found: $MODEL_FILE" + exit 1 +fi + +OUTPUT_DIR=$(dirname "$MODEL_FILE") +BASENAME=$(basename "$MODEL_FILE" .onnx) +EXTERNAL_DATA_FILE="${BASENAME}.onnx.data" + +echo "==============================================" +echo "Fix External Data" +echo "==============================================" +echo "Model file: $MODEL_FILE" +echo "External data: $OUTPUT_DIR/$EXTERNAL_DATA_FILE" +echo "==============================================" + +# Check file size +FILE_SIZE=$(stat -c%s "$MODEL_FILE") +FILE_SIZE_GB=$(echo "scale=2; $FILE_SIZE / 1024 / 1024 / 1024" | bc) +echo "Current file size: ${FILE_SIZE_GB} GB" + +# Export variables for Python +export MODEL_FILE EXTERNAL_DATA_FILE FILE_SIZE + +# Run Python script +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +python3 "$SCRIPT_DIR/py/fix_external_data.py" + +echo "" +echo "Output files:" +ls -lh "$OUTPUT_DIR"/${BASENAME}* + diff --git a/models/03_validate_model.sh b/models/03_validate_model.sh new file mode 100755 index 0000000..729e3d8 --- /dev/null +++ b/models/03_validate_model.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# ============================================================================= +# 03_validate_model.sh - Validate ONNX model +# ============================================================================= +# Usage: ./03_validate_model.sh +# Example: ./03_validate_model.sh ./Llama3.1-8B-Instruct/onnx/model.onnx +# ============================================================================= + +set -e + +MODEL_FILE="${1:?Usage: $0 }" + +if [ ! -f "$MODEL_FILE" ]; then + echo "Error: File not found: $MODEL_FILE" + exit 1 +fi + +echo "==============================================" +echo "Validate ONNX Model" +echo "==============================================" +echo "Model: $MODEL_FILE" +echo "==============================================" + +# Export variables for Python +export MODEL_FILE + +# Run Python script +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +python3 "$SCRIPT_DIR/py/validate_model.py" + diff --git a/models/04_optimize_model.sh b/models/04_optimize_model.sh new file mode 100755 index 0000000..80f00d6 --- /dev/null +++ b/models/04_optimize_model.sh @@ -0,0 +1,174 @@ +#!/bin/bash +# ============================================================================= +# 04_optimize_model.sh - Optimize ONNX model for ONNX Runtime inference +# ============================================================================= +# Usage: ./04_optimize_model.sh [model_type] +# +# This script optimizes ONNX models for ONNX Runtime execution (CPU or GPU EP). +# It fuses attention patterns into efficient operators (MultiHeadAttention/GQA) +# which MIGraphX can then accelerate with Flash Attention kernels. +# +# Environment Variables: +# SKIP_FP16=true - Skip FP16 conversion (for quantized models) +# OPT_LEVEL=<0-2> - Optimization level (default: 1) +# USE_GPU=true - Use GPU for optimization (enables more fusions) +# ATTENTION_TYPE= - Force attention type: MultiHeadAttention, GroupQueryAttention +# +# Model parameters are auto-detected from config.json in the model directory. +# ============================================================================= + +set -e + +INPUT_FILE="${1:?Usage: $0 [model_type]}" +OUTPUT_FILE="${2:?Usage: $0 [model_type]}" +MODEL_TYPE="${3:-gpt_neox}" # gpt_neox is compatible with LLaMA + +if [ ! -f "$INPUT_FILE" ]; then + echo "Error: File not found: $INPUT_FILE" + exit 1 +fi + +INPUT_DIR=$(dirname "$INPUT_FILE") +INPUT_BASE=$(basename "$INPUT_FILE" .onnx) + +# Settings from environment +SKIP_FP16="${SKIP_FP16:-false}" +OPT_LEVEL="${OPT_LEVEL:-1}" + +# ============================================================================= +# Auto-detect model configuration +# ============================================================================= +CONFIG_FILE="$INPUT_DIR/config.json" +if [ -f "$CONFIG_FILE" ]; then + echo "Auto-detecting model parameters from config.json..." + + DETECTED_PARAMS=$(python3 << EOF +import json +with open("$CONFIG_FILE", "r") as f: + config = json.load(f) + +hidden_size = config.get("hidden_size", 4096) +num_heads = config.get("num_attention_heads", 32) +num_kv_heads = config.get("num_key_value_heads", num_heads) +num_layers = config.get("num_hidden_layers", 32) + +# Model variant +variants = {2048: "Llama_3.2_1B", 3072: "Llama_3.2_3B", 4096: "Llama_3.1_8B", + 8192: "Llama_3.1_70B", 16384: "Llama_3.1_405B"} +variant = variants.get(hidden_size, f"Unknown_{hidden_size}") + +print(f'MODEL_VARIANT="{variant}"') +print(f'NUM_HEADS="{num_heads}"') +print(f'HIDDEN_SIZE="{hidden_size}"') +print(f'NUM_KV_HEADS="{num_kv_heads}"') +print(f'NUM_LAYERS="{num_layers}"') +EOF +) + eval "$DETECTED_PARAMS" +else + echo "No config.json found, using defaults..." + NUM_HEADS="32" + HIDDEN_SIZE="4096" + MODEL_VARIANT="Unknown" +fi + +# ============================================================================= +# Check for quantized models (skip FP16) +# ============================================================================= +IS_QUANTIZED=false +if [[ "$INPUT_BASE" == *"int4"* ]] || [[ "$INPUT_BASE" == *"int8"* ]]; then + IS_QUANTIZED=true + SKIP_FP16=true +fi + +# Check for quantization ops in model +if [ "$IS_QUANTIZED" = false ]; then + QUANT_CHECK=$(python3 -c " +import onnx +model = onnx.load('$INPUT_FILE', load_external_data=False) +quant_ops = {'MatMulNBits', 'QLinearMatMul', 'MatMulInteger', 'DequantizeLinear'} +print('QUANTIZED' if set(n.op_type for n in model.graph.node) & quant_ops else '') +" 2>/dev/null || echo "") + [ "$QUANT_CHECK" = "QUANTIZED" ] && IS_QUANTIZED=true && SKIP_FP16=true +fi + +# ============================================================================= +# Print configuration +# ============================================================================= +echo "" +echo "==============================================" +echo "Optimize ONNX Model for ONNX Runtime" +echo "==============================================" +echo "Input: $INPUT_FILE" +echo "Output: $OUTPUT_FILE" +echo "Model: $MODEL_VARIANT" +echo "Heads: $NUM_HEADS (KV: ${NUM_KV_HEADS:-$NUM_HEADS})" +echo "Hidden size: $HIDDEN_SIZE" +echo "----------------------------------------------" +echo "FP16: $([ "$SKIP_FP16" = true ] && echo 'disabled' || echo 'enabled')" +echo "Quantized: $([ "$IS_QUANTIZED" = true ] && echo 'yes' || echo 'no')" +echo "Opt level: $OPT_LEVEL" +echo "==============================================" +echo "" + +# ============================================================================= +# Check external data +# ============================================================================= +USE_EXTERNAL="" +if [ -f "$INPUT_DIR/${INPUT_BASE}.onnx.data" ] || [ -f "$INPUT_DIR/${INPUT_BASE}.onnx_data" ]; then + echo "External data detected, will preserve in output..." + USE_EXTERNAL="--use_external_data_format" +fi + +# Check for oversized model +ONNX_SIZE=$(stat -c%s "$INPUT_FILE" 2>/dev/null || stat -f%z "$INPUT_FILE" 2>/dev/null || echo "0") +if [ "$ONNX_SIZE" -gt 2147483648 ]; then + echo "⚠️ ONNX file exceeds 2GB protobuf limit!" + echo " Run: ./02_fix_external_data.sh $INPUT_FILE" + exit 1 +fi + +# ============================================================================= +# GPU/Provider settings +# ============================================================================= +USE_GPU="${USE_GPU:-true}" +ATTENTION_TYPE="${ATTENTION_TYPE:-auto}" + +# Check for MIGraphX provider +if [ "$USE_GPU" = true ]; then + HAS_MIGRAPHX=$(python3 -c "import onnxruntime as ort; print('yes' if 'MIGraphXExecutionProvider' in ort.get_available_providers() else 'no')" 2>/dev/null || echo "no") + if [ "$HAS_MIGRAPHX" = "yes" ]; then + echo "MIGraphX EP detected - will optimize for Flash Attention" + PROVIDER="MIGraphXExecutionProvider" + else + echo "MIGraphX not available, using CPU optimization" + USE_GPU=false + PROVIDER="CPUExecutionProvider" + fi +else + PROVIDER="CPUExecutionProvider" +fi + +# ============================================================================= +# Run optimizer with FusionOptions for efficient attention +# ============================================================================= +echo "" +echo "Running ONNX Runtime transformer optimizer..." +echo " Enabling attention fusion for MIGraphX Flash Attention support" +echo "" + +# Export variables for Python +export INPUT_FILE OUTPUT_FILE MODEL_TYPE NUM_HEADS HIDDEN_SIZE NUM_KV_HEADS +export OPT_LEVEL SKIP_FP16 USE_GPU ATTENTION_TYPE + +# Run Python script +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +python3 "$SCRIPT_DIR/py/optimize_model.py" + +if [ $? -eq 0 ]; then + echo "" + ls -lh "$OUTPUT_FILE" +else + echo "❌ Optimization failed" + exit 1 +fi diff --git a/models/05_quantize_int4.sh b/models/05_quantize_int4.sh new file mode 100755 index 0000000..7ecd0f2 --- /dev/null +++ b/models/05_quantize_int4.sh @@ -0,0 +1,56 @@ +#!/bin/bash +# ============================================================================= +# 05_quantize_int4.sh - Quantize ONNX model to INT4 (4-bit weight quantization) +# ============================================================================= +# Usage: ./05_quantize_int4.sh [block_size] +# Example: ./05_quantize_int4.sh ./model.onnx ./model_int4.onnx 128 +# +# Requirements: +# - ONNX Runtime 1.20+ +# +# Block sizes: 32, 64, 128 (default), 256 +# - Smaller = better accuracy, larger model +# - Larger = smaller model, may lose some accuracy +# ============================================================================= + +set -e + +INPUT_FILE="${1:?Usage: $0 [block_size]}" +OUTPUT_FILE="${2:?Usage: $0 [block_size]}" +BLOCK_SIZE="${3:-128}" + +if [ ! -f "$INPUT_FILE" ]; then + echo "Error: File not found: $INPUT_FILE" + exit 1 +fi + +INPUT_DIR=$(dirname "$INPUT_FILE") +INPUT_BASE=$(basename "$INPUT_FILE" .onnx) + +# Check for external data +EXTERNAL_DATA="$INPUT_DIR/${INPUT_BASE}.onnx.data" +EXTERNAL_DATA_ALT="$INPUT_DIR/${INPUT_BASE}.onnx_data" +HAS_EXTERNAL=false +if [ -f "$EXTERNAL_DATA" ] || [ -f "$EXTERNAL_DATA_ALT" ]; then + HAS_EXTERNAL=true +fi + +echo "==============================================" +echo "Quantize to INT4 (4-bit Weight Quantization)" +echo "==============================================" +echo "Input: $INPUT_FILE" +echo "Output: $OUTPUT_FILE" +echo "Block size: $BLOCK_SIZE" +echo "External: $HAS_EXTERNAL" +echo "==============================================" + +# Export variables for Python +export INPUT_FILE OUTPUT_FILE BLOCK_SIZE HAS_EXTERNAL + +# Run Python script +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +python3 "$SCRIPT_DIR/py/quantize_int4.py" + +echo "" +echo "Output files:" +ls -lh "$OUTPUT_FILE"* 2>/dev/null || echo "Check output directory for files" diff --git a/models/05_quantize_int8.sh b/models/05_quantize_int8.sh new file mode 100755 index 0000000..551badb --- /dev/null +++ b/models/05_quantize_int8.sh @@ -0,0 +1,32 @@ +#!/bin/bash +# ============================================================================= +# 05_quantize_int8.sh - Quantize ONNX model to INT8 (dynamic quantization) +# ============================================================================= +# Usage: ./05_quantize_int8.sh +# Example: ./05_quantize_int8.sh ./model.onnx ./model_int8.onnx +# ============================================================================= + +set -e + +INPUT_FILE="${1:?Usage: $0 }" +OUTPUT_FILE="${2:?Usage: $0 }" + +if [ ! -f "$INPUT_FILE" ]; then + echo "Error: File not found: $INPUT_FILE" + exit 1 +fi + +echo "==============================================" +echo "Quantize to INT8 (Dynamic)" +echo "==============================================" +echo "Input: $INPUT_FILE" +echo "Output: $OUTPUT_FILE" +echo "==============================================" + +# Export variables for Python +export INPUT_FILE OUTPUT_FILE + +# Run Python script +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +python3 "$SCRIPT_DIR/py/quantize_int8.py" + diff --git a/models/06_convert_fp16.sh b/models/06_convert_fp16.sh new file mode 100755 index 0000000..2fe0140 --- /dev/null +++ b/models/06_convert_fp16.sh @@ -0,0 +1,32 @@ +#!/bin/bash +# ============================================================================= +# 06_convert_fp16.sh - Convert ONNX model to FP16 +# ============================================================================= +# Usage: ./06_convert_fp16.sh +# Example: ./06_convert_fp16.sh ./model.onnx ./model_fp16.onnx +# ============================================================================= + +set -e + +INPUT_FILE="${1:?Usage: $0 }" +OUTPUT_FILE="${2:?Usage: $0 }" + +if [ ! -f "$INPUT_FILE" ]; then + echo "Error: File not found: $INPUT_FILE" + exit 1 +fi + +echo "==============================================" +echo "Convert to FP16" +echo "==============================================" +echo "Input: $INPUT_FILE" +echo "Output: $OUTPUT_FILE" +echo "==============================================" + +# Export variables for Python +export INPUT_FILE OUTPUT_FILE + +# Run Python script +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +python3 "$SCRIPT_DIR/py/convert_fp16.py" + diff --git a/models/08_benchmark_migraphx.sh b/models/08_benchmark_migraphx.sh new file mode 100755 index 0000000..d5c3129 --- /dev/null +++ b/models/08_benchmark_migraphx.sh @@ -0,0 +1,138 @@ +#!/bin/bash +# ============================================================================= +# 08_benchmark_migraphx.sh - Benchmark ONNX model with MIGraphX EP +# ============================================================================= +# Usage: ./08_benchmark_migraphx.sh [options] +# +# Benchmarks inference performance using ONNX Runtime with MIGraphX EP. +# Wraps benchmark_migraphx.py with shell-friendly interface. +# +# Options: +# -n, --iterations Number of benchmark iterations (default: 100) +# -w, --warmup Number of warmup iterations (default: 5) +# -s, --seq-length Input sequence length (new tokens, default: 1) +# -k, --kv-length KV cache length (context tokens, default: 0) +# --exhaustive Enable exhaustive tuning +# --offload-copy Use CPU memory during compilation +# --no-cache Disable model caching +# -v, --verbose Enable verbose logging +# -q, --quiet Minimal output, only show final results +# --help Show this help +# +# Environment Variables: +# ITERATIONS= Override default iterations +# WARMUP= Override default warmup +# SEQ_LENGTH= Override default sequence length +# KV_LENGTH= Override default KV cache length +# +# Examples: +# ./08_benchmark_migraphx.sh ./Llama3.1-8B-Instruct/onnx +# ./08_benchmark_migraphx.sh ./onnx -n 500 -s 1 -k 512 +# ./08_benchmark_migraphx.sh ./onnx --seq-length 128 --quiet +# ============================================================================= + +set -e +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +# Environment defaults +ITERATIONS="${ITERATIONS:-100}" +WARMUP="${WARMUP:-5}" +SEQ_LENGTH="${SEQ_LENGTH:-1}" +KV_LENGTH="${KV_LENGTH:-0}" + +# Parse arguments +POSITIONAL=() +EXHAUSTIVE=false +OFFLOAD_COPY=false +NO_CACHE=false +VERBOSE=false +QUIET=false + +while [[ $# -gt 0 ]]; do + case $1 in + -n|--iterations) + ITERATIONS="$2" + shift 2 + ;; + -w|--warmup) + WARMUP="$2" + shift 2 + ;; + -s|--seq-length) + SEQ_LENGTH="$2" + shift 2 + ;; + -k|--kv-length) + KV_LENGTH="$2" + shift 2 + ;; + --exhaustive) + EXHAUSTIVE=true + shift + ;; + --offload-copy) + OFFLOAD_COPY=true + shift + ;; + --no-cache) + NO_CACHE=true + shift + ;; + -v|--verbose) + VERBOSE=true + shift + ;; + -q|--quiet) + QUIET=true + shift + ;; + --help|-h) + head -35 "$0" | tail -32 + exit 0 + ;; + -*) + echo "Unknown option: $1" + exit 1 + ;; + *) + POSITIONAL+=("$1") + shift + ;; + esac +done + +set -- "${POSITIONAL[@]}" +MODEL_DIR="${1:?Usage: $0 [options]}" + +if [ ! -d "$MODEL_DIR" ]; then + # Check if it's a direct path to model.onnx + if [ -f "$MODEL_DIR" ]; then + MODEL_DIR="$(dirname "$MODEL_DIR")" + else + echo "Error: Directory not found: $MODEL_DIR" + exit 1 + fi +fi + +# Verify benchmark script exists +BENCH_SCRIPT="$SCRIPT_DIR/benchmark_migraphx.py" +if [ ! -f "$BENCH_SCRIPT" ]; then + echo "Error: benchmark_migraphx.py not found in $SCRIPT_DIR" + exit 1 +fi + +# Build Python arguments +PYTHON_ARGS="$MODEL_DIR" +PYTHON_ARGS="$PYTHON_ARGS --iterations $ITERATIONS" +PYTHON_ARGS="$PYTHON_ARGS --warmup $WARMUP" +PYTHON_ARGS="$PYTHON_ARGS --seq-length $SEQ_LENGTH" +PYTHON_ARGS="$PYTHON_ARGS --kv-length $KV_LENGTH" + +[ "$EXHAUSTIVE" = true ] && PYTHON_ARGS="$PYTHON_ARGS --exhaustive-tune" +[ "$OFFLOAD_COPY" = true ] && PYTHON_ARGS="$PYTHON_ARGS --offload-copy" +[ "$NO_CACHE" = true ] && PYTHON_ARGS="$PYTHON_ARGS --no-cache" +[ "$VERBOSE" = true ] && PYTHON_ARGS="$PYTHON_ARGS --verbose" +[ "$QUIET" = true ] && PYTHON_ARGS="$PYTHON_ARGS --quiet" + +# Run benchmark +exec python3 "$BENCH_SCRIPT" $PYTHON_ARGS diff --git a/models/09_run_inference_test.sh b/models/09_run_inference_test.sh new file mode 100755 index 0000000..f4d4690 --- /dev/null +++ b/models/09_run_inference_test.sh @@ -0,0 +1,154 @@ +#!/bin/bash +# ============================================================================= +# 09_run_inference_test.sh - Test inference with ONNX Runtime +# ============================================================================= +# Usage: ./09_run_inference_test.sh [provider] [options] +# +# Runs text generation to verify the model works correctly. +# Uses autoregressive generation with growing KV cache. +# +# Providers: +# MIGraphXExecutionProvider - AMD GPU with MIGraphX (default) +# ROCMExecutionProvider - AMD GPU with ROCm +# CUDAExecutionProvider - NVIDIA GPU +# CPUExecutionProvider - CPU fallback +# +# Options: +# --prompt Custom prompt (default: "What is 2+2?") +# --seq-length Static input sequence length (default: 256) +# Used for BOTH prefill and decode stages. +# Inputs are left-padded to this size. +# --temperature Sampling temperature (default: 0.0 = greedy) +# --verbose Enable verbose ORT logging +# --no-cache Disable model caching +# --exhaustive Enable exhaustive tuning +# --offload-copy Use CPU memory during compilation +# --help Show this help +# +# KV Cache Strategy (FULLY STATIC shapes): +# ALL shapes are FIXED to avoid MIGraphX recompilation and +# hipHostRegister failures on small arrays. +# +# Fixed shapes: +# BENCHMARK-COMPATIBLE SHAPES (the only shapes that work): +# input=(1, 1), attn=(1, 257), kv=(1, h, 256, d) +# Any other shape triggers hipHostRegister failures in MIGraphX. +# Prefill is slow (1 token/step) but decode matches benchmark speed. +# and copy it into the STATIC buffer at position filled_kv. +# +# Environment Variables: +# VERBOSE=true Enable verbose ORT + MIGraphX + HIP logging +# MIGRAPHX_FP16=1 Enable FP16 mode (default: disabled for pre-FP16 models) +# MIGRAPHX_SAVE_MODEL=1 Save compiled model +# +# Examples: +# ./09_run_inference_test.sh ./Llama3.1-8B-Instruct/onnx +# ./09_run_inference_test.sh ./onnx --prompt "Explain quantum computing" +# ./09_run_inference_test.sh ./onnx --seq-length 256 --temperature 0.7 +# ============================================================================= + +set -e +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +# Parse arguments +POSITIONAL=() +PROMPT="What is 2+2?" +SEQ_LENGTH=256 # Default bucket size (max_output = seq_length) +TEMPERATURE=0.0 +VERBOSE=false +NO_CACHE=false +EXHAUSTIVE=false +OFFLOAD_COPY=true # Default to offload for large models +MIGRAPHX_FP16="${MIGRAPHX_FP16:-0}" +MIGRAPHX_SAVE="${MIGRAPHX_SAVE_MODEL:-1}" + +while [[ $# -gt 0 ]]; do + case $1 in + --prompt) + PROMPT="$2" + shift 2 + ;; + --seq-length) + SEQ_LENGTH="$2" + shift 2 + ;; + --temperature) + TEMPERATURE="$2" + shift 2 + ;; + --verbose|-v) + VERBOSE=true + shift + ;; + --no-cache) + NO_CACHE=true + shift + ;; + --exhaustive) + EXHAUSTIVE=true + shift + ;; + --offload-copy) + OFFLOAD_COPY=true + shift + ;; + --no-offload-copy) + OFFLOAD_COPY=false + shift + ;; + --help|-h) + head -40 "$0" | tail -37 + exit 0 + ;; + -*) + echo "Unknown option: $1" + exit 1 + ;; + *) + POSITIONAL+=("$1") + shift + ;; + esac +done + +set -- "${POSITIONAL[@]}" +MODEL_DIR="${1:?Usage: $0 [provider] [options]}" +# MIGraphX provider - we use GPU OrtValues to avoid hipHostRegister issues +PROVIDER="${2:-MIGraphXExecutionProvider}" + +if [ ! -d "$MODEL_DIR" ]; then + echo "Error: Directory not found: $MODEL_DIR" + exit 1 +fi + +echo "==============================================" +echo "ONNX Runtime Text Generation Test" +echo "==============================================" +echo "Model dir: $MODEL_DIR" +echo "Provider: $PROVIDER" +echo "Prompt: \"$PROMPT\"" +echo "Max context: $SEQ_LENGTH tokens" +echo "Max output: $SEQ_LENGTH tokens" +echo "Temperature: $TEMPERATURE" +if [ "$PROVIDER" = "MIGraphXExecutionProvider" ]; then + echo "FP16 convert: $MIGRAPHX_FP16" + echo "Caching: $([ "$NO_CACHE" = true ] && echo 'disabled' || echo 'enabled')" + echo "Exhaustive: $EXHAUSTIVE" + echo "Offload: $OFFLOAD_COPY" +fi +echo "==============================================" + +# Auto-detect GPU target for ROCm +GPU_TARGET=$(rocminfo 2>/dev/null | grep -oP 'gfx\d+' | head -1 || echo "") +if [ -n "$GPU_TARGET" ]; then + if [[ "$GPU_TARGET" == gfx11* ]]; then + echo "Detected RDNA3 GPU: $GPU_TARGET" + fi +fi + +export MODEL_DIR PROVIDER PROMPT SEQ_LENGTH TEMPERATURE VERBOSE NO_CACHE EXHAUSTIVE OFFLOAD_COPY +export MIGRAPHX_FP16 MIGRAPHX_SAVE GPU_TARGET + +# Run Python script +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +python3 "$SCRIPT_DIR/py/run_inference_test.py" diff --git a/models/README.md b/models/README.md new file mode 100644 index 0000000..249a6a0 --- /dev/null +++ b/models/README.md @@ -0,0 +1,396 @@ +# ONNX Model Export and Optimization Scripts + +Scripts for exporting HuggingFace models to ONNX and running inference with MIGraphX/ROCm. + +## Requirements + +```bash +pip install torch transformers onnx onnxruntime onnxconverter-common +``` + +For MIGraphX support, ensure ROCm and MIGraphX are installed. + +## Quick Start + +### Full Pipeline (Recommended) + +```bash +# Make scripts executable +chmod +x *.sh + +# Export and test with MIGraphX (default GPU workflow) +./export_pipeline.sh /path/to/Llama3.1-8B-Instruct/hf ./Llama3.1-8B-Instruct/onnx + +# Pre-compile for common KV cache lengths (recommended for production) +./export_pipeline.sh ./model/hf ./model/onnx --precompile + +# Benchmark with specific context length +./export_pipeline.sh ./model/hf ./model/onnx --benchmark-only --kv-length 512 -n 500 + +# CPU target with optimization +./export_pipeline.sh ./model/hf ./model/onnx --cpu +``` + +## Default Settings (Optimized for Inference) + +All exports use these **inference-optimized defaults**: + +| Setting | Default | Description | +|---------|---------|-------------| +| **KV Cache** | ✅ ENABLED | Essential for efficient autoregressive generation | +| **Precision** | FP16 | Faster inference, lower memory | +| **Shapes** | Dynamic | Any batch/sequence length at runtime | +| **Caching** | ✅ ENABLED | MIGraphX compiled models cached in `migraphx_cache/` | + +```python +import onnxruntime as ort + +session = ort.InferenceSession( + 'model.onnx', + providers=['MIGraphXExecutionProvider'], + provider_options=[{ + 'device_id': 0, + 'migraphx_model_cache_dir': './migraphx_cache', + }] +) + +# Works with any sequence length +outputs = session.run(None, { + 'input_ids': input_ids, # shape: (batch, any_seq_len) + 'attention_mask': attention_mask, + # ... KV cache tensors ... +}) +``` + +## MIGraphX Shape Compilation + +**Important:** MIGraphX requires fixed shapes at compile time. Each unique `(seq_length, kv_length)` combination requires a separate compiled model (~3 min each for 8B models). + +### Automatic Caching + +The MIGraphX EP automatically caches compiled models. First inference with new shapes triggers compilation; subsequent runs use the cache. + +### Pre-compilation (Recommended for Production) + +Pre-compile common shapes to avoid runtime compilation delays: + +```bash +# Pre-compile with defaults (buckets 0-64K, seq-lengths 1,4,16,64) +python precompile_shapes.py ./Llama3.1-8B-Instruct/onnx + +# Custom buckets (smaller set for faster compilation) +python precompile_shapes.py ./onnx --buckets "0,512,1024,2048,4096,8192" + +# Custom sequence lengths +python precompile_shapes.py ./onnx --seq-lengths "1,4" --buckets "0,1024,4096,16384" +``` + +### Shape Bucketing Strategy + +For efficient production use, implement shape bucketing: + +```python +BUCKETS = [0, 128, 256, 512, 1024, 2048, 4096] + +def get_bucket(actual_kv_length): + """Find smallest bucket >= actual_length""" + for b in BUCKETS: + if b >= actual_kv_length: + return b + return BUCKETS[-1] + +# Pad KV cache to bucket size for cache hits +kv_length = get_bucket(actual_context_length) +``` + +## Workflows + +### GPU Target (Default) + +``` +Export (dynamic) → Validate → Test (MIGraphX EP) → Benchmark +``` + +```bash +./export_pipeline.sh ./model/hf ./model/onnx + +# With pre-compilation: +./export_pipeline.sh ./model/hf ./model/onnx --precompile + +# With custom benchmark settings: +./export_pipeline.sh ./model/hf ./model/onnx --seq-length 1 --kv-length 512 -n 500 +``` + +### CPU Target + +``` +Export (dynamic) → Validate → Optimize (FP16) → Test +``` + +```bash +./export_pipeline.sh ./model/hf ./model/onnx --cpu +``` + +### INT4/INT8 Quantization (CPU Only) + +```bash +# INT4 (~75% size reduction) +./export_pipeline.sh ./model/hf ./model/onnx --int4 + +# INT8 (~50% size reduction) +./export_pipeline.sh ./model/hf ./model/onnx --int8 +``` + +**Note**: Quantized models use operators MIGraphX doesn't support. Use CPU for quantized inference. + +## Benchmark Script + +The Python benchmark script provides detailed performance metrics: + +```bash +# Basic benchmark (100 iterations) +python benchmark_migraphx.py ./Llama3.1-8B-Instruct/onnx + +# With context (simulates decoding with 512-token history) +python benchmark_migraphx.py ./onnx --seq-length 1 --kv-length 512 + +# Extended benchmark with verbose logging +python benchmark_migraphx.py ./onnx -n 500 --verbose + +# Quick test with minimal output +python benchmark_migraphx.py ./onnx -n 50 --quiet +``` + +### Benchmark Options + +| Option | Default | Description | +|--------|---------|-------------| +| `-n, --iterations` | 100 | Number of benchmark iterations | +| `-w, --warmup` | 5 | Warmup iterations | +| `--seq-length` | 1 | Input sequence length (new tokens) | +| `--kv-length` | 0 | KV cache length (context tokens) | +| `--exhaustive-tune` | off | Exhaustive MIGraphX tuning | +| `--offload-copy` | off | Use CPU memory during compilation | +| `-v, --verbose` | off | Verbose ORT logging | +| `-q, --quiet` | off | Minimal output | +| `--no-cache` | off | Disable model caching | + +### Benchmark Output + +``` +============================================================ +Results +============================================================ +Iterations: 100 +Input tokens: 1 +Context tokens: 512 + +Average latency: 25.43ms +Std deviation: 1.23ms +Min latency: 23.12ms +Max latency: 31.45ms + +P50 latency: 25.21ms +P90 latency: 26.89ms +P99 latency: 29.12ms + +Throughput: 39.3 inferences/sec +Tokens/sec: 39.3 (output tokens) +``` + +## Scripts Reference + +| Script | Description | +|--------|-------------| +| `export_pipeline.sh` | **Main orchestration script** - runs full workflow | +| `01_export_model.sh` | Export HuggingFace model to ONNX (dynamic shapes) | +| `02_fix_external_data.sh` | Convert large models (>2GB) to external data format | +| `03_validate_model.sh` | Validate ONNX model structure | +| `04_optimize_model.sh` | Optimize for ONNX Runtime (attention fusion + FP16) | +| `05_quantize_int4.sh` | INT4 weight quantization | +| `05_quantize_int8.sh` | INT8 dynamic quantization | +| `06_convert_fp16.sh` | Convert weights to FP16 (standalone) | +| `precompile_shapes.py` | **Pre-compile MIGraphX for multiple shapes** | +| `08_benchmark_migraphx.sh` | Benchmark wrapper script | +| `09_run_inference_test.sh` | Quick inference test | +| `benchmark_migraphx.py` | **Python benchmark script** with detailed metrics | + +## Manual Step-by-Step + +```bash +chmod +x *.sh + +# 1. Export model to ONNX (FP16 + KV cache by default) +./01_export_model.sh /path/to/model/hf ./output + +# 2. Fix external data (if model > 2GB) +./02_fix_external_data.sh ./output/model.onnx + +# 3. Validate +./03_validate_model.sh ./output/model.onnx + +# 4. Test inference with MIGraphX +./09_run_inference_test.sh ./output MIGraphXExecutionProvider + +# 5. Pre-compile common shapes (uses defaults: buckets 0-64K, seq 1,4,16,64) +python precompile_shapes.py ./output + +# 6. Benchmark with context +python benchmark_migraphx.py ./output --seq-length 1 --kv-length 512 -n 100 +``` + +## Pipeline Options + +### Target Selection + +| Option | Description | +|--------|-------------| +| `--gpu` | Target GPU with MIGraphX (default) | +| `--cpu` | Target CPU | +| `--int4` | INT4 quantization (CPU only) | +| `--int8` | INT8 quantization (CPU only) | + +### Export Options + +| Option | Description | +|--------|-------------| +| `--opset ` | ONNX opset version (default: auto-detect, max 21) | +| `--no-kv-cache` | Disable KV cache (not recommended for inference) | +| `--fp32` | Export in FP32 instead of FP16 | + +### MIGraphX Options + +| Option | Description | +|--------|-------------| +| `--precompile` | Pre-compile for common KV cache lengths | +| `--exhaustive` | Enable exhaustive tuning (slower compile, faster inference) | +| `--offload-copy` | Use CPU memory during compilation | + +### Benchmarking Options + +| Option | Description | +|--------|-------------| +| `--seq-length ` | Input sequence length (default: 1) | +| `--kv-length ` | KV cache length / context (default: 0) | +| `--iterations ` | Benchmark iterations (default: 100) | +| `--skip-benchmark` | Skip benchmarking step | +| `--benchmark-only` | Only run benchmark (model must exist) | +| `--verbose` | Enable verbose logging | + +### Other Options + +| Option | Description | +|--------|-------------| +| `--dry-run` | Show what would be executed | +| `-h, --help` | Show help | + +## Environment Variables + +### MIGraphX Options + +| Variable | Default | Description | +|----------|---------|-------------| +| `MIGRAPHX_FP16` | `0` | Enable FP16 conversion (not needed for FP16 models) | + +### Benchmark Options + +| Variable | Default | Description | +|----------|---------|-------------| +| `SEQ_LENGTH` | `1` | Input sequence length | +| `KV_LENGTH` | `0` | KV cache length | +| `ITERATIONS` | `100` | Number of iterations | +| `WARMUP` | `5` | Warmup iterations | + +## Examples + +```bash +# Basic export and test (FP16 + KV cache enabled by default) +./export_pipeline.sh ./model/hf ./model/onnx + +# Export with pre-compilation for production +./export_pipeline.sh ./model/hf ./model/onnx --precompile + +# Benchmark with 512-token context (simulates decoding) +python benchmark_migraphx.py ./model/onnx --seq-length 1 --kv-length 512 -n 500 + +# Pre-compile with defaults (9 buckets × 4 seq-lengths = 36 shapes) +python precompile_shapes.py ./model/onnx + +# Quick inference test with verbose logging +./09_run_inference_test.sh ./model/onnx MIGraphXExecutionProvider --verbose + +# Export without KV cache (not recommended) +./01_export_model.sh ./model/hf ./output --no-kv-cache + +# Export in FP32 precision +./01_export_model.sh ./model/hf ./output --fp32 +``` + +## Supported Models (Auto-Detected) + +| Model | hidden_size | num_heads | num_kv_heads | num_layers | +|-------|-------------|-----------|--------------|------------| +| **Llama 3.2 1B** | 2048 | 32 | 8 | 16 | +| **Llama 3.2 3B** | 3072 | 24 | 8 | 28 | +| **Llama 3.1 8B** | 4096 | 32 | 8 | 32 | +| **Llama 3.1 70B** | 8192 | 64 | 8 | 80 | +| **Llama 3.1 405B** | 16384 | 128 | 8 | 126 | +| **Mistral 7B** | 4096 | 32 | 8 | 32 | + +## Execution Providers + +| Provider | Use Case | +|----------|----------| +| `MIGraphXExecutionProvider` | AMD GPUs with MIGraphX (recommended) | +| `ROCMExecutionProvider` | AMD GPUs with ROCm (deprecated in ORT 1.23+) | +| `CUDAExecutionProvider` | NVIDIA GPUs | +| `CPUExecutionProvider` | CPU fallback | + +## Troubleshooting + +### Model > 2GB protobuf error +```bash +./02_fix_external_data.sh ./output/model.onnx +``` + +### MIGraphX falls back to CPU +Check if all operators are supported: +```bash +python benchmark_migraphx.py ./model/onnx --verbose 2>&1 | grep -i "fallback\|cpu" +``` + +### Slow first inference +MIGraphX JIT-compiles on first run. Pre-compile to avoid: +```bash +python precompile_shapes.py ./model/onnx +``` + +### INT4 not working with MIGraphX +INT4 uses `GatherBlockQuantized` which MIGraphX doesn't support. Use CPU: +```bash +./09_run_inference_test.sh ./model/onnx CPUExecutionProvider +``` + +### Different KV lengths cause recompilation +MIGraphX requires fixed shapes. Use shape bucketing: +```bash +# Pre-compile all default shapes +python precompile_shapes.py ./model/onnx + +# Then pad actual KV cache to nearest bucket at runtime +``` + +### Out of memory during compilation +Use offload copy to use CPU memory during compilation: +```bash +python benchmark_migraphx.py ./model/onnx --offload-copy +# Or +./export_pipeline.sh ./model/hf ./model/onnx --offload-copy +``` + +### Verbose logging for debugging +```bash +python benchmark_migraphx.py ./model/onnx --verbose +# Or +./09_run_inference_test.sh ./model/onnx MIGraphXExecutionProvider --verbose +``` diff --git a/models/benchmark_migraphx.py b/models/benchmark_migraphx.py new file mode 100755 index 0000000..60ae778 --- /dev/null +++ b/models/benchmark_migraphx.py @@ -0,0 +1,245 @@ +#!/usr/bin/env python3 +"""MIGraphX benchmark script for ONNX models with KV cache.""" + +import argparse +import json +import os +import time +import numpy as np +import onnxruntime as ort + +# Log severity levels: +# 0 = VERBOSE (all messages) +# 1 = INFO +# 2 = WARNING (default, shows WARNING and above) +# 3 = ERROR +# 4 = FATAL + + +def detect_model_dtype(model_path): + """Detect if model uses FP16 or FP32 by checking input types.""" + import onnx + model = onnx.load(model_path, load_external_data=False) + + for inp in model.graph.input: + elem_type = inp.type.tensor_type.elem_type + # Check tensor inputs (skip int64 inputs like input_ids) + if elem_type == onnx.TensorProto.FLOAT16: + return np.float16 + elif elem_type == onnx.TensorProto.FLOAT: + return np.float32 + + # Default to float16 for modern models + return np.float16 + + +def main(): + parser = argparse.ArgumentParser(description="Benchmark MIGraphX inference") + parser.add_argument("model_dir", help="Directory containing model.onnx and export_info.json") + parser.add_argument("--iterations", "-n", type=int, default=100, help="Number of benchmark iterations (default: 100)") + parser.add_argument("--warmup", "-w", type=int, default=5, help="Number of warmup iterations (default: 5)") + parser.add_argument("--seq-length", type=int, default=256, + help="Bucket size: prompt padded to this, KV cache = 2×this (default: 256)") + parser.add_argument("--no-cache", action="store_true", help="Disable model caching") + parser.add_argument("--convert-fp16", action="store_true", help="Force FP32->FP16 conversion (not needed if model is already FP16)") + parser.add_argument("--exhaustive-tune", action="store_true", help="Enable exhaustive tuning") + parser.add_argument("--offload-copy", action="store_true", help="Use CPU memory during compilation (reduces GPU memory usage)") + parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose logging (shows all ORT messages)") + parser.add_argument("--log-level", type=int, default=2, choices=[0, 1, 2, 3, 4], + help="Log severity level: 0=VERBOSE, 1=INFO, 2=WARNING (default), 3=ERROR, 4=FATAL") + parser.add_argument("--quiet", "-q", action="store_true", help="Only show final results (no per-iteration output)") + args = parser.parse_args() + + # Configure logging - must be done before creating any session + log_level = 0 if args.verbose else args.log_level + ort.set_default_logger_severity(log_level) + log_level_names = {0: "VERBOSE", 1: "INFO", 2: "WARNING", 3: "ERROR", 4: "FATAL"} + if not args.quiet: + print(f"ORT Log Level: {log_level_names.get(log_level, log_level)}") + + model_path = os.path.join(args.model_dir, "model.onnx") + info_path = os.path.join(args.model_dir, "export_info.json") + cache_path = os.path.join(args.model_dir, "migraphx_cache") + + if not os.path.exists(model_path): + print(f"Error: Model not found at {model_path}") + return 1 + + if not os.path.exists(info_path): + print(f"Error: Export info not found at {info_path}") + return 1 + + with open(info_path) as f: + info = json.load(f) + + num_layers = info["num_layers"] + num_kv_heads = info["num_kv_heads"] + head_dim = info["head_dim"] + + # Detect model dtype + model_dtype = detect_model_dtype(model_path) + dtype_name = "FP16" if model_dtype == np.float16 else "FP32" + + # Benchmark simulates decode step (seq_len=1, kv_len=bucket) + # This represents generating tokens after a prompt of `bucket` tokens + seq_len = 1 # Decode: one token at a time + kv_len = args.seq_length # KV cache = past context (the prompt) + + print("=" * 60) + print("MIGraphX Benchmark (Decode Phase)") + print("=" * 60) + print(f"Model: {model_path}") + print(f"Model dtype: {dtype_name}") + print(f"Layers: {num_layers}, KV Heads: {num_kv_heads}, Head Dim: {head_dim}") + print(f"Decode: seq_len=1, kv_len={kv_len}") + print(f" (simulates generating after {kv_len}-token prompt)") + print(f"Iterations: {args.iterations} (warmup: {args.warmup})") + print(f"Force FP16 conversion: {args.convert_fp16}") + print(f"Caching: {not args.no_cache}") + print(f"Exhaustive Tune: {args.exhaustive_tune}") + print(f"Offload Copy (CPU compile): {args.offload_copy}") + print() + + # Configure provider - only enable fp16 conversion if explicitly requested + # Models already in FP16 don't need conversion (saves memory) + provider_options = { + "device_id": "0", + "migraphx_fp16_enable": "1" if args.convert_fp16 else "0", + "migraphx_exhaustive_tune": "1" if args.exhaustive_tune else "0", + "migraphx_offload_copy": "1" if args.offload_copy else "0", + } + + if not args.no_cache: + os.makedirs(cache_path, exist_ok=True) + provider_options["migraphx_model_cache_dir"] = cache_path + print(f"Cache path: {cache_path}") + + # Create session - MIGraphX only, no CPU fallback + print("\nCreating session (MIGraphX only, no fallback)...") + t0 = time.time() + sess_options = ort.SessionOptions() + sess_options.log_severity_level = log_level + sess_options.log_verbosity_level = 10 if args.verbose else 0 # Higher = more verbose + + try: + session = ort.InferenceSession( + model_path, + sess_options, + providers=["MIGraphXExecutionProvider"], + provider_options=[provider_options], + ) + except Exception as e: + print(f"\nERROR: MIGraphX session creation failed!") + print(f"Exception: {e}") + print("\nThis means MIGraphX is not working properly.") + return 1 + + session_time = time.time() - t0 + print(f"Session created in {session_time:.2f}s") + + active_providers = session.get_providers() + print(f"Active providers: {active_providers}") + + if "MIGraphXExecutionProvider" not in active_providers: + print("\nERROR: MIGraphX is not active!") + return 1 + + # Build inputs for decode benchmark + # Only include inputs that the model actually expects + model_inputs = session.get_inputs() + input_names = [inp.name for inp in model_inputs] + + dtype = model_dtype + attn_len = seq_len + kv_len # attention covers current + past + + feed = {} + + if "input_ids" in input_names: + feed["input_ids"] = np.ones((1, seq_len), dtype=np.int64) + + if "attention_mask" in input_names: + feed["attention_mask"] = np.ones((1, attn_len), dtype=np.int64) + + if "position_ids" in input_names: + # Position for decode = kv_len (next position after past context) + feed["position_ids"] = np.array([[kv_len]], dtype=np.int64) + + # KV cache tensors (filled with random data to simulate real cache) + for i in range(num_layers): + key_name = f"past_key_values.{i}.key" + value_name = f"past_key_values.{i}.value" + if key_name in input_names: + feed[key_name] = np.random.randn(1, num_kv_heads, kv_len, head_dim).astype(dtype) + if value_name in input_names: + feed[value_name] = np.random.randn(1, num_kv_heads, kv_len, head_dim).astype(dtype) + + # Calculate memory footprint + total_bytes = sum(v.nbytes for v in feed.values()) + print(f"\nInputs: {len(feed)} tensors, {total_bytes / 1024 / 1024:.2f} MB") + + # Warmup + print(f"Running {args.warmup} warmup iterations...") + warmup_times = [] + for i in range(args.warmup): + t0 = time.time() + outputs = session.run(None, feed) + warmup_times.append(time.time() - t0) + if not args.quiet: + print(f" Warmup {i+1}: {warmup_times[-1]*1000:.2f}ms") + + print(f"Warmup avg: {np.mean(warmup_times)*1000:.2f}ms") + print(f"Output shape: {outputs[0].shape}") + + # Benchmark + print(f"\nBenchmarking ({args.iterations} iterations)...") + times = [] + + # Progress reporting + report_interval = max(1, args.iterations // 10) # Report ~10 times + + for i in range(args.iterations): + t0 = time.time() + outputs = session.run(None, feed) + elapsed = time.time() - t0 + times.append(elapsed) + + if not args.quiet and ((i + 1) % report_interval == 0 or i == 0): + avg_so_far = np.mean(times) * 1000 + print(f" [{i+1}/{args.iterations}] Current: {elapsed*1000:.2f}ms, Avg: {avg_so_far:.2f}ms") + + # Results + times_ms = np.array(times) * 1000 + avg_ms = np.mean(times_ms) + min_ms = np.min(times_ms) + max_ms = np.max(times_ms) + std_ms = np.std(times_ms) + p50_ms = np.percentile(times_ms, 50) + p90_ms = np.percentile(times_ms, 90) + p99_ms = np.percentile(times_ms, 99) + + print() + print("=" * 60) + print("Results (Decode Phase)") + print("=" * 60) + print(f"Iterations: {args.iterations}") + print(f"Decode shape: seq={seq_len}, kv={kv_len}") + print(f"Context length: {kv_len} tokens") + print() + print(f"Average latency: {avg_ms:.2f}ms") + print(f"Std deviation: {std_ms:.2f}ms") + print(f"Min latency: {min_ms:.2f}ms") + print(f"Max latency: {max_ms:.2f}ms") + print() + print(f"P50 latency: {p50_ms:.2f}ms") + print(f"P90 latency: {p90_ms:.2f}ms") + print(f"P99 latency: {p99_ms:.2f}ms") + print() + print(f"Throughput: {1000/avg_ms:.1f} inferences/sec") + print(f"Tokens/sec: {args.seq_length * 1000/avg_ms:.1f} (output tokens)") + print() + + return 0 + + +if __name__ == "__main__": + exit(main()) diff --git a/models/check_migraphx_support.sh b/models/check_migraphx_support.sh new file mode 100755 index 0000000..ac60b29 --- /dev/null +++ b/models/check_migraphx_support.sh @@ -0,0 +1,333 @@ +#!/bin/bash +# ============================================================================= +# check_migraphx_support.sh - Check MIGraphX compatibility and operator support +# ============================================================================= +# Usage: ./check_migraphx_support.sh [model.onnx] +# +# Without arguments: runs GPU and MIGraphX diagnostics only +# With model path: also checks operator support for the model +# ============================================================================= + +set -e + +MODEL_FILE="${1:-}" + +echo "==============================================" +echo "MIGraphX Compatibility Check" +echo "==============================================" + +# GPU Information +echo "" +echo "[1] GPU Information" +echo "----------------------------------------------" +GPU_TARGET=$(rocminfo 2>/dev/null | grep -oP 'gfx\d+' | head -1 || echo "unknown") +GPU_NAME=$(rocminfo 2>/dev/null | grep "Marketing Name:" | head -1 | cut -d: -f2 | xargs || echo "unknown") +echo "GPU Target: $GPU_TARGET" +echo "GPU Name: $GPU_NAME" + +# ROCm Version +echo "" +echo "[2] ROCm / MIGraphX Version" +echo "----------------------------------------------" +ROCM_VERSION=$(cat /opt/rocm/.info/version 2>/dev/null || echo "not found") +echo "ROCm: $ROCM_VERSION" + +if command -v migraphx-driver &> /dev/null; then + MIGRAPHX_VERSION=$(migraphx-driver --version 2>/dev/null | head -1 || echo "error") + echo "MIGraphX: $MIGRAPHX_VERSION" +else + echo "MIGraphX: migraphx-driver not found" +fi + +# ONNX Runtime +echo "" +echo "[3] ONNX Runtime" +echo "----------------------------------------------" +python3 -c " +import onnxruntime as ort +print(f'Version: {ort.__version__}') +print(f'Providers: {ort.get_available_providers()}') +has_migraphx = 'MIGraphXExecutionProvider' in ort.get_available_providers() +print(f'MIGraphX EP: {\"✓ Available\" if has_migraphx else \"✗ Not available\"}')" 2>/dev/null || echo "ONNX Runtime not installed" + +# Simple MIGraphX test +echo "" +echo "[4] MIGraphX Compilation Test" +echo "----------------------------------------------" +python3 << 'PYTEST' +import os +import sys + +try: + import onnxruntime as ort + import tempfile + import numpy as np + + # Create minimal ONNX model for testing (use opset 17 for max compatibility) + import onnx + from onnx import helper, TensorProto + + X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 4]) + Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 4]) + relu_node = helper.make_node('Relu', ['X'], ['Y']) + graph = helper.make_graph([relu_node], 'test', [X], [Y]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid('', 17)]) + model.ir_version = 8 # Compatible with older ONNX Runtime builds + + with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as f: + onnx.save(model, f.name) + temp_path = f.name + + # Test MIGraphX + sess_options = ort.SessionOptions() + provider_options = {'device_id': 0, 'migraphx_fp16_enable': False} + + session = ort.InferenceSession( + temp_path, + sess_options, + providers=['MIGraphXExecutionProvider'], + provider_options=[provider_options] + ) + + # Run inference + x = np.random.randn(1, 4).astype(np.float32) + result = session.run(None, {'X': x}) + + os.unlink(temp_path) + + actual = session.get_providers() + if 'MIGraphXExecutionProvider' in actual: + print("✓ MIGraphX compilation: SUCCESS") + print("✓ MIGraphX inference: SUCCESS") + else: + print(f"⚠ Fell back to: {actual}") + +except Exception as e: + print(f"✗ MIGraphX test failed: {e}") + import traceback + traceback.print_exc() +PYTEST + +# Check for model file +if [ -z "$MODEL_FILE" ]; then + echo "" + echo "==============================================" + echo "Done (no model specified)" + echo "==============================================" + echo "" + echo "To check operator support for a model:" + echo " $0 " + exit 0 +fi + +if [ ! -f "$MODEL_FILE" ]; then + echo "" + echo "Error: File not found: $MODEL_FILE" + exit 1 +fi + +echo "" +echo "==============================================" +echo "Model Operator Support Check" +echo "==============================================" +echo "Model: $MODEL_FILE" + +# Method 1: Try to parse with migraphx-driver +echo "" +echo "Method 1: migraphx-driver parse test" +echo "----------------------------------------------" +if command -v migraphx-driver &> /dev/null; then + echo "Running: migraphx-driver read --onnx $MODEL_FILE" + migraphx-driver read --onnx "$MODEL_FILE" 2>&1 | head -100 || true +else + echo "migraphx-driver not found" +fi + +# Method 2: Check operators against known MIGraphX support list +echo "" +echo "Method 2: Operator analysis" +echo "----------------------------------------------" + +python3 << EOF +import onnx +import os + +model_path = "$MODEL_FILE" + +print(f"Loading model: {model_path}") +model = onnx.load(model_path, load_external_data=False) + +# Count operators +op_counts = {} +for node in model.graph.node: + op_counts[node.op_type] = op_counts.get(node.op_type, 0) + 1 + +print(f"\nModel has {len(model.graph.node)} nodes, {len(op_counts)} unique operators") + +# Known MIGraphX supported operators (as of MIGraphX 2.x) +# This list is approximate - check MIGraphX docs for exact support +MIGRAPHX_SUPPORTED = { + # Basic + 'Add', 'Sub', 'Mul', 'Div', 'Pow', 'Sqrt', 'Exp', 'Log', + 'Abs', 'Neg', 'Ceil', 'Floor', 'Round', + 'Relu', 'LeakyRelu', 'Elu', 'Selu', 'Sigmoid', 'Tanh', 'Softmax', 'LogSoftmax', + 'Clip', 'Min', 'Max', 'Sum', 'Mean', + # Reduction + 'ReduceSum', 'ReduceMean', 'ReduceMax', 'ReduceMin', 'ReduceProd', + 'ReduceL1', 'ReduceL2', 'ReduceLogSum', 'ReduceLogSumExp', + # Matrix + 'MatMul', 'Gemm', 'MatMulInteger', + # Convolution + 'Conv', 'ConvTranspose', 'AveragePool', 'MaxPool', 'GlobalAveragePool', 'GlobalMaxPool', + # Normalization + 'BatchNormalization', 'InstanceNormalization', 'LRN', + 'LayerNormalization', # Limited support + # Shape + 'Reshape', 'Flatten', 'Squeeze', 'Unsqueeze', 'Transpose', + 'Concat', 'Split', 'Slice', 'Gather', 'GatherElements', + 'Shape', 'Size', 'Tile', 'Expand', 'Pad', + # Cast/Convert + 'Cast', 'CastLike', + # Logic + 'Equal', 'Less', 'Greater', 'LessOrEqual', 'GreaterOrEqual', + 'And', 'Or', 'Not', 'Xor', 'Where', + # Other common + 'Identity', 'Dropout', 'Constant', 'ConstantOfShape', + 'Range', 'Einsum', + # Attention (limited) + 'Attention', 'MultiHeadAttention', +} + +# Operators with known issues in MIGraphX +MIGRAPHX_PROBLEMATIC = { + 'SimplifiedLayerNormalization', # May not be supported + 'RotaryEmbedding', # Custom op + 'GatherND', # Limited support + 'ScatterND', # Limited support + 'NonZero', # Dynamic output shape + 'Loop', 'If', 'Scan', # Control flow + 'LSTM', 'GRU', 'RNN', # Recurrent (limited) + 'Resize', # Some modes not supported + 'GridSample', # Limited +} + +print("\n" + "=" * 60) +print("OPERATOR SUPPORT ANALYSIS") +print("=" * 60) + +supported = {} +unsupported = {} +problematic = {} +unknown = {} + +for op, count in sorted(op_counts.items(), key=lambda x: -x[1]): + if op in MIGRAPHX_SUPPORTED: + supported[op] = count + elif op in MIGRAPHX_PROBLEMATIC: + problematic[op] = count + elif op.startswith('com.') or op.startswith('ai.') or 'Custom' in op: + unsupported[op] = count + else: + unknown[op] = count + +print(f"\n✅ SUPPORTED ({len(supported)} types, {sum(supported.values())} nodes):") +for op, count in sorted(supported.items(), key=lambda x: -x[1])[:15]: + print(f" {op}: {count}") +if len(supported) > 15: + print(f" ... and {len(supported) - 15} more") + +if problematic: + print(f"\n⚠️ PROBLEMATIC ({len(problematic)} types, {sum(problematic.values())} nodes):") + for op, count in sorted(problematic.items(), key=lambda x: -x[1]): + print(f" {op}: {count}") + +if unsupported: + print(f"\n❌ UNSUPPORTED ({len(unsupported)} types, {sum(unsupported.values())} nodes):") + for op, count in sorted(unsupported.items(), key=lambda x: -x[1]): + print(f" {op}: {count}") + +if unknown: + print(f"\n❓ UNKNOWN STATUS ({len(unknown)} types, {sum(unknown.values())} nodes):") + for op, count in sorted(unknown.items(), key=lambda x: -x[1]): + print(f" {op}: {count}") + +# Check for dynamic shapes (problematic for MIGraphX) +print("\n" + "=" * 60) +print("DYNAMIC SHAPE ANALYSIS") +print("=" * 60) + +dynamic_inputs = [] +for inp in model.graph.input: + shape = [] + if inp.type.tensor_type.shape.dim: + for dim in inp.type.tensor_type.shape.dim: + if dim.dim_param: + shape.append(dim.dim_param) + elif dim.dim_value: + shape.append(dim.dim_value) + else: + shape.append('?') + if any(isinstance(s, str) for s in shape): + dynamic_inputs.append((inp.name, shape)) + +if dynamic_inputs: + print("⚠️ Model has dynamic input shapes:") + for name, shape in dynamic_inputs: + print(f" {name}: {shape}") + print("\n MIGraphX requires fixed shapes. Dynamic shapes may cause issues.") +else: + print("✅ All inputs have fixed shapes") + +# Check data types +print("\n" + "=" * 60) +print("DATA TYPE ANALYSIS") +print("=" * 60) + +dtype_map = { + 1: 'float32', 2: 'uint8', 3: 'int8', 4: 'uint16', 5: 'int16', + 6: 'int32', 7: 'int64', 9: 'bool', 10: 'float16', 11: 'double', + 12: 'uint32', 13: 'uint64', 14: 'complex64', 15: 'complex128', + 16: 'bfloat16' +} + +initializer_dtypes = {} +for init in model.graph.initializer: + dtype = dtype_map.get(init.data_type, f'unknown({init.data_type})') + initializer_dtypes[dtype] = initializer_dtypes.get(dtype, 0) + 1 + +print("Initializer (weight) data types:") +for dtype, count in sorted(initializer_dtypes.items(), key=lambda x: -x[1]): + print(f" {dtype}: {count}") + +if 'float16' in initializer_dtypes: + print("\n⚠️ Model has FP16 weights - ensure MIGraphX FP16 mode is enabled") + +print("\n" + "=" * 60) +print("SUMMARY") +print("=" * 60) + +total_nodes = len(model.graph.node) +supported_nodes = sum(supported.values()) +problematic_nodes = sum(problematic.values()) +unsupported_nodes = sum(unsupported.values()) +unknown_nodes = sum(unknown.values()) + +print(f"Total nodes: {total_nodes}") +print(f"Likely supported: {supported_nodes} ({100*supported_nodes/total_nodes:.1f}%)") +print(f"Potentially problematic: {problematic_nodes} ({100*problematic_nodes/total_nodes:.1f}%)") +print(f"Likely unsupported: {unsupported_nodes} ({100*unsupported_nodes/total_nodes:.1f}%)") +print(f"Unknown: {unknown_nodes} ({100*unknown_nodes/total_nodes:.1f}%)") + +if problematic_nodes > 0 or unsupported_nodes > 0 or unknown_nodes > total_nodes * 0.1: + print("\n⚠️ This model may have compatibility issues with MIGraphX") + print(" Try:") + print(" 1. Check if operators are supported in your MIGraphX version") + print(" 2. Use CPU provider for testing: CPUExecutionProvider") + print(" 3. File an issue with MIGraphX for unsupported operators") +EOF + +echo "" +echo "==============================================" +echo "Done" +echo "==============================================" + diff --git a/models/export_pipeline.sh b/models/export_pipeline.sh new file mode 100755 index 0000000..5954279 --- /dev/null +++ b/models/export_pipeline.sh @@ -0,0 +1,448 @@ +#!/bin/bash +# ============================================================================= +# export_pipeline.sh - ONNX Export and Inference Pipeline +# ============================================================================= +# Usage: ./export_pipeline.sh [options] +# +# Workflows: +# GPU (default): Export → Validate → Test (MIGraphX EP) → Benchmark +# CPU (--cpu): Export → Validate → Optimize (FP16) → Test +# INT4 (--int4): Export → Validate → INT4 Quantize → Optimize → Test +# INT8 (--int8): Export → Validate → INT8 Quantize → Optimize → Test +# +# Defaults (optimized for inference): +# - KV cache: ENABLED (essential for autoregressive generation) +# - Precision: FP16 (faster, lower memory) +# - Shapes: Dynamic (any batch/sequence length) +# +# Options: +# --gpu Target MIGraphX (default) +# --cpu Target ONNX Runtime CPU +# --int4 INT4 quantization (CPU only) +# --int8 INT8 quantization (CPU only) +# --opset ONNX opset version (default: auto-detect) +# --no-kv-cache Disable KV cache (not recommended) +# --fp32 Export in FP32 instead of FP16 +# --skip-benchmark Skip benchmark step +# --benchmark-only Only run benchmark (model must exist) +# --precompile Pre-compile MIGraphX for common shapes +# --buckets Bucket sizes for precompile (default: 256) +# --seq-length Bucket size for testing (default: 256) +# KV cache = 2 × seq-length, max output = seq-length +# --iterations Benchmark iterations (default: 100) +# --exhaustive Enable exhaustive MIGraphX tuning +# --offload-copy Use CPU memory for MIGraphX compilation +# --verbose Enable verbose logging +# --dry-run Show commands without executing +# -h, --help Show this help +# +# Examples: +# ./export_pipeline.sh ./Llama3.1-8B/hf ./Llama3.1-8B/onnx +# ./export_pipeline.sh ./model/hf ./model/onnx --precompile +# ./export_pipeline.sh ./model/hf ./model/onnx --benchmark-only -n 500 +# ============================================================================= + +set -e +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +# Colors +RED='\033[0;31m'; GREEN='\033[0;32m'; YELLOW='\033[1;33m' +BLUE='\033[0;34m'; CYAN='\033[0;36m'; NC='\033[0m' + +print_header() { echo -e "\n${BLUE}══════════════════════════════════════════════════════════════════${NC}\n${BLUE} $1${NC}\n${BLUE}══════════════════════════════════════════════════════════════════${NC}"; } +print_step() { echo -e "${CYAN}▶ $1${NC}"; } +print_ok() { echo -e "${GREEN}✅ $1${NC}"; } +print_warn() { echo -e "${YELLOW}⚠️ $1${NC}"; } +print_err() { echo -e "${RED}❌ $1${NC}"; } + +show_help() { head -45 "$0" | tail -43; exit 0; } + +# ============================================================================= +# Defaults - OPTIMIZED FOR INFERENCE +# ============================================================================= +TARGET="gpu" +OPSET="" +NO_KV_CACHE=false +USE_FP32=false +SKIP_BENCHMARK=false +BENCHMARK_ONLY=false +PRECOMPILE=false +DRY_RUN=false +SEQ_LENGTH=256 # Bucket size (KV cache = 2 × this, max output = this) +BUCKETS="256" # Bucket sizes for precompile +ITERATIONS=100 +EXHAUSTIVE=false +OFFLOAD_COPY=true # Default: offload to CPU during compile +VERBOSE=false + +# ============================================================================= +# Parse Arguments +# ============================================================================= +POSITIONAL=() +while [[ $# -gt 0 ]]; do + case $1 in + --gpu) TARGET="gpu"; shift ;; + --cpu) TARGET="cpu"; shift ;; + --int4) TARGET="int4"; shift ;; + --int8) TARGET="int8"; shift ;; + --opset) OPSET="$2"; shift 2 ;; + --no-kv-cache) NO_KV_CACHE=true; shift ;; + --fp32) USE_FP32=true; shift ;; + --skip-benchmark) SKIP_BENCHMARK=true; shift ;; + --benchmark-only) BENCHMARK_ONLY=true; shift ;; + --precompile) PRECOMPILE=true; shift ;; + --buckets) BUCKETS="$2"; shift 2 ;; + --seq-length|-s) SEQ_LENGTH="$2"; shift 2 ;; + --iterations|-n) ITERATIONS="$2"; shift 2 ;; + --exhaustive) EXHAUSTIVE=true; shift ;; + --offload-copy) OFFLOAD_COPY=true; shift ;; + --no-offload-copy) OFFLOAD_COPY=false; shift ;; + --verbose|-v) VERBOSE=true; shift ;; + --dry-run) DRY_RUN=true; shift ;; + -h|--help) show_help ;; + -*) print_err "Unknown option: $1"; exit 1 ;; + *) POSITIONAL+=("$1"); shift ;; + esac +done +set -- "${POSITIONAL[@]}" + +if [ ${#POSITIONAL[@]} -lt 2 ]; then + print_err "Usage: $0 [options]" + exit 1 +fi + +MODEL_PATH="$1" +OUTPUT_DIR="$2" + +# ============================================================================= +# Validate +# ============================================================================= +if [ "$BENCHMARK_ONLY" = false ]; then + [ ! -d "$MODEL_PATH" ] && print_err "Model path not found: $MODEL_PATH" && exit 1 +fi + +for script in 01_export_model.sh 03_validate_model.sh 08_benchmark_migraphx.sh 09_run_inference_test.sh; do + [ ! -x "$SCRIPT_DIR/$script" ] && chmod +x "$SCRIPT_DIR/$script" +done + +mkdir -p "$OUTPUT_DIR" + +# ============================================================================= +# Auto-detect ONNX opset version if not specified +# ============================================================================= +if [ -z "$OPSET" ]; then + OPSET=$(python3 -c " +import onnx +latest = onnx.defs.onnx_opset_version() +print(min(latest, 21)) +" 2>/dev/null || echo "21") + OPSET_SOURCE="auto-detected" +else + OPSET_SOURCE="specified" +fi + +# ============================================================================= +# Configuration Summary +# ============================================================================= +print_header "Pipeline Configuration" +echo "" +echo " Model: $MODEL_PATH" +echo " Output: $OUTPUT_DIR" +echo " Target: $TARGET" +echo " Opset: $OPSET ($OPSET_SOURCE)" +echo " Precision: $([ "$USE_FP32" = true ] && echo 'FP32' || echo 'FP16 ✓')" +echo " KV cache: $([ "$NO_KV_CACHE" = true ] && echo 'disabled' || echo 'ENABLED ✓')" +echo " Shapes: dynamic" +echo "" +echo " Inference settings:" +echo " - Bucket size: $SEQ_LENGTH (prompt length / context length)" +echo " - Iterations: $ITERATIONS" +[ "$EXHAUSTIVE" = true ] && echo " - Exhaustive tuning: enabled" +[ "$OFFLOAD_COPY" = true ] && echo " - Offload copy: enabled (CPU memory during compile)" +[ "$PRECOMPILE" = true ] && echo " - Pre-compile buckets: $BUCKETS" +echo "" + +case $TARGET in + gpu) + if [ "$PRECOMPILE" = true ]; then + echo " Workflow: Export → Validate → Pre-compile → Test → Benchmark" + else + echo " Workflow: Export → Validate → Test (MIGraphX EP) → Benchmark" + fi + echo "" + echo " Optimized for inference:" + echo " - KV cache enabled for efficient autoregressive generation" + echo " - FP16 precision for speed and lower memory" + echo " - Pre-allocated KV cache (2 × bucket size)" + ;; + cpu) echo " Workflow: Export → Validate → Optimize → Test" ;; + int4) echo " Workflow: Export → Validate → INT4 Quantize → Optimize → Test" ;; + int8) echo " Workflow: Export → Validate → INT8 Quantize → Optimize → Test" ;; +esac +echo "" + +# ============================================================================= +# Helper +# ============================================================================= +run_cmd() { + local desc="$1"; shift + print_step "$desc" + if [ "$DRY_RUN" = true ]; then + echo " [DRY RUN] $*" + else + "$@" || { print_err "$desc failed"; exit 1; } + fi + print_ok "$desc" +} + +# ============================================================================= +# Build common benchmark arguments +# ============================================================================= +build_bench_args() { + local args="" + args="$args --seq-length $SEQ_LENGTH" + args="$args --iterations $ITERATIONS" + [ "$EXHAUSTIVE" = true ] && args="$args --exhaustive-tune" + [ "$OFFLOAD_COPY" = true ] && args="$args --offload-copy" + [ "$VERBOSE" = true ] && args="$args --verbose" + echo "$args" +} + +# ============================================================================= +# Skip to benchmark if requested +# ============================================================================= +if [ "$BENCHMARK_ONLY" = true ]; then + MODEL_ONNX="$OUTPUT_DIR/model.onnx" + [ ! -f "$MODEL_ONNX" ] && print_err "Model not found: $MODEL_ONNX" && exit 1 + + print_header "Benchmark Only Mode" + + BENCH_ARGS=$(build_bench_args) + run_cmd "Benchmark" python3 "$SCRIPT_DIR/benchmark_migraphx.py" "$OUTPUT_DIR" $BENCH_ARGS + + print_ok "Benchmark complete!" + exit 0 +fi + +# ============================================================================= +# Step 1: Export (Optimized for Inference) +# ============================================================================= +print_header "Step 1: Export Model" + +MODEL_ONNX="$OUTPUT_DIR/model.onnx" + +build_export_args() { + local args="" + [ -n "$OPSET" ] && args="$args --opset $OPSET" + [ "$NO_KV_CACHE" = true ] && args="$args --no-kv-cache" + [ "$USE_FP32" = true ] && args="$args --fp32" + echo "$args" +} + +if [ -f "$MODEL_ONNX" ]; then + print_warn "Model exists: $MODEL_ONNX" + read -p " Re-export? [y/N] " -n 1 -r + echo + if [[ $REPLY =~ ^[Yy]$ ]]; then + rm -f "$MODEL_ONNX" "$MODEL_ONNX.data" "${MODEL_ONNX}_data" + EXPORT_ARGS=$(build_export_args) + run_cmd "Export to ONNX (FP16 + KV cache)" "$SCRIPT_DIR/01_export_model.sh" "$MODEL_PATH" "$OUTPUT_DIR" $EXPORT_ARGS + else + print_ok "Using existing model" + fi +else + EXPORT_ARGS=$(build_export_args) + run_cmd "Export to ONNX (FP16 + KV cache)" "$SCRIPT_DIR/01_export_model.sh" "$MODEL_PATH" "$OUTPUT_DIR" $EXPORT_ARGS +fi + +# ============================================================================= +# Step 2: Validate +# ============================================================================= +print_header "Step 2: Validate Model" +run_cmd "Validate ONNX" "$SCRIPT_DIR/03_validate_model.sh" "$MODEL_ONNX" + +# ============================================================================= +# Step 3+: Target-specific workflow +# ============================================================================= +case $TARGET in + # ========================================================================= + # GPU: ONNX Runtime with MIGraphXExecutionProvider + # ========================================================================= + gpu) + STEP=3 + + # Pre-compile FIRST if requested (so test uses cached shapes) + if [ "$PRECOMPILE" = true ]; then + print_header "Step $STEP: Pre-compile MIGraphX (Cache Shapes)" + echo " Pre-compiling shapes for bucket: $BUCKETS" + echo " KV cache sizes: $(echo $BUCKETS | tr ',' '\n' | while read b; do echo -n "$((b*2)) "; done)" + echo "" + + if [ -f "$SCRIPT_DIR/precompile_shapes.py" ]; then + PRECOMPILE_ARGS="$OUTPUT_DIR --buckets $BUCKETS" + [ "$EXHAUSTIVE" = true ] && PRECOMPILE_ARGS="$PRECOMPILE_ARGS --exhaustive-tune" + [ "$OFFLOAD_COPY" = false ] && PRECOMPILE_ARGS="$PRECOMPILE_ARGS --no-offload-copy" + [ "$VERBOSE" = true ] && PRECOMPILE_ARGS="$PRECOMPILE_ARGS --verbose" + + run_cmd "Pre-compile shapes" python3 "$SCRIPT_DIR/precompile_shapes.py" $PRECOMPILE_ARGS + else + print_warn "precompile_shapes.py not found, skipping pre-compilation" + fi + STEP=$((STEP + 1)) + fi + + print_header "Step $STEP: Test Inference (MIGraphX EP)" + echo " Bucket size: $SEQ_LENGTH (prompt padded to this)" + echo " KV cache: $((SEQ_LENGTH * 2)) (pre-allocated)" + echo " Max output: $SEQ_LENGTH tokens" + echo "" + + # Build test args + TEST_ARGS="--seq-length $SEQ_LENGTH" + [ "$EXHAUSTIVE" = true ] && TEST_ARGS="$TEST_ARGS --exhaustive" + [ "$OFFLOAD_COPY" = true ] && TEST_ARGS="$TEST_ARGS --offload-copy" + [ "$VERBOSE" = true ] && TEST_ARGS="$TEST_ARGS --verbose" + + run_cmd "Test inference" "$SCRIPT_DIR/09_run_inference_test.sh" "$OUTPUT_DIR" "MIGraphXExecutionProvider" $TEST_ARGS + STEP=$((STEP + 1)) + + if [ "$SKIP_BENCHMARK" = false ]; then + print_header "Step $STEP: Benchmark" + + BENCH_ARGS=$(build_bench_args) + run_cmd "Benchmark" python3 "$SCRIPT_DIR/benchmark_migraphx.py" "$OUTPUT_DIR" $BENCH_ARGS + fi + + BEST_MODEL="$MODEL_ONNX" + ;; + + # ========================================================================= + # CPU: ONNX Runtime with FP16 optimization + # ========================================================================= + cpu) + print_header "Step 3: Optimize for ONNX Runtime CPU" + echo " Fusing attention patterns and converting to FP16..." + + MODEL_OPT="$OUTPUT_DIR/model_optimized.onnx" + + if [ -f "$MODEL_OPT" ]; then + print_warn "Optimized model exists: $MODEL_OPT" + read -p " Re-run optimization? [y/N] " -n 1 -r + echo + if [[ $REPLY =~ ^[Yy]$ ]]; then + rm -f "$MODEL_OPT" "$MODEL_OPT.data" "${MODEL_OPT}_data" + USE_GPU=false run_cmd "Optimize (attention fusion + FP16)" "$SCRIPT_DIR/04_optimize_model.sh" "$MODEL_ONNX" "$MODEL_OPT" "gpt_neox" + else + print_ok "Using existing optimized model" + fi + else + USE_GPU=false run_cmd "Optimize (attention fusion + FP16)" "$SCRIPT_DIR/04_optimize_model.sh" "$MODEL_ONNX" "$MODEL_OPT" "gpt_neox" + fi + + if [ ! -f "$MODEL_OPT" ]; then + print_err "Optimized model not found: $MODEL_OPT" + exit 1 + fi + + print_header "Step 4: Inference Test" + run_cmd "Test inference" "$SCRIPT_DIR/09_run_inference_test.sh" "$OUTPUT_DIR" "CPUExecutionProvider" --seq-length $SEQ_LENGTH + + BEST_MODEL="$MODEL_OPT" + ;; + + # ========================================================================= + # INT4: Quantize then optimize + # ========================================================================= + int4) + print_header "Step 3: INT4 Quantization" + + MODEL_INT4="$OUTPUT_DIR/model_int4.onnx" + run_cmd "Quantize to INT4" "$SCRIPT_DIR/05_quantize_int4.sh" "$MODEL_ONNX" "$MODEL_INT4" 128 + + print_header "Step 4: Optimize INT4 Model" + MODEL_OPT="$OUTPUT_DIR/model_int4_optimized.onnx" + SKIP_FP16=true run_cmd "Optimize (no FP16)" "$SCRIPT_DIR/04_optimize_model.sh" "$MODEL_INT4" "$MODEL_OPT" + + print_header "Step 5: Inference Test" + run_cmd "Test inference" "$SCRIPT_DIR/09_run_inference_test.sh" "$OUTPUT_DIR" "CPUExecutionProvider" --seq-length $SEQ_LENGTH + + BEST_MODEL="$MODEL_OPT" + ;; + + # ========================================================================= + # INT8: Quantize then optimize + # ========================================================================= + int8) + print_header "Step 3: INT8 Quantization" + + MODEL_INT8="$OUTPUT_DIR/model_int8.onnx" + run_cmd "Quantize to INT8" "$SCRIPT_DIR/05_quantize_int8.sh" "$MODEL_ONNX" "$MODEL_INT8" + + print_header "Step 4: Optimize INT8 Model" + MODEL_OPT="$OUTPUT_DIR/model_int8_optimized.onnx" + SKIP_FP16=true run_cmd "Optimize (no FP16)" "$SCRIPT_DIR/04_optimize_model.sh" "$MODEL_INT8" "$MODEL_OPT" + + print_header "Step 5: Inference Test" + run_cmd "Test inference" "$SCRIPT_DIR/09_run_inference_test.sh" "$OUTPUT_DIR" "CPUExecutionProvider" --seq-length $SEQ_LENGTH + + BEST_MODEL="$MODEL_OPT" + ;; +esac + +# ============================================================================= +# Summary +# ============================================================================= +print_header "Pipeline Complete" +echo "" +echo " Best model: $BEST_MODEL" +echo "" +echo " Output files:" +ls -lh "$OUTPUT_DIR"/*.onnx "$OUTPUT_DIR"/*.data 2>/dev/null | sed 's/^/ /' || true +echo "" + +# Show cache directory if present +if [ -d "$OUTPUT_DIR/migraphx_cache" ]; then + echo " MIGraphX cache:" + ls -lh "$OUTPUT_DIR/migraphx_cache"/*.mxr 2>/dev/null | head -5 | sed 's/^/ /' || echo " (empty)" + CACHE_COUNT=$(ls "$OUTPUT_DIR/migraphx_cache"/*.mxr 2>/dev/null | wc -l || echo "0") + [ "$CACHE_COUNT" -gt 5 ] && echo " ... and $((CACHE_COUNT - 5)) more" + echo "" +fi + +case $TARGET in + gpu) + echo " Usage with ONNX Runtime (Python):" + echo " ────────────────────────────────────────────────────────────" + echo " import onnxruntime as ort" + echo " " + echo " session = ort.InferenceSession(" + echo " '$BEST_MODEL'," + echo " providers=['MIGraphXExecutionProvider']," + echo " provider_options=[{" + echo " 'device_id': 0," + echo " 'migraphx_model_cache_dir': '$OUTPUT_DIR/migraphx_cache'," + echo " }]" + echo " )" + echo " " + echo " # Use pre-compiled bucket size for KV cache" + echo " # KV cache = 2 × bucket, max output = bucket" + echo " outputs = session.run(None, {" + echo " 'input_ids': input_ids, # (1, bucket_size)" + echo " 'attention_mask': attn_mask, # (1, bucket_size + kv_cache_size)" + echo " # ... KV cache tensors (1, heads, kv_cache_size, head_dim) ..." + echo " })" + echo " ────────────────────────────────────────────────────────────" + echo "" + echo " Quick test:" + echo " ./09_run_inference_test.sh $OUTPUT_DIR --seq-length 256" + echo "" + if [ "$PRECOMPILE" != true ]; then + echo " Pre-compile for production (recommended):" + echo " python precompile_shapes.py $OUTPUT_DIR --buckets '256,512,1024'" + fi + ;; + cpu|int4|int8) + echo " Usage: Load $BEST_MODEL with ONNX Runtime CPUExecutionProvider" + ;; +esac +echo "" diff --git a/models/patches/migraphx_memory_optimization.patch b/models/patches/migraphx_memory_optimization.patch new file mode 100644 index 0000000..fd7d131 --- /dev/null +++ b/models/patches/migraphx_memory_optimization.patch @@ -0,0 +1,239 @@ +diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +index a59347841b..c93eff8a1d 100644 +--- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc ++++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +@@ -151,6 +151,7 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv + model_cache_path_{info.model_cache_dir}, + t_{info.target_device.c_str()}, + exhaustive_tune_{info.exhaustive_tune}, ++ offload_copy_{info.offload_copy}, + metadef_id_generator_{ModelMetadefIdGenerator::Create()}, + external_alloc_{info.external_alloc}, + external_free_{info.external_free}, +@@ -179,6 +180,7 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv + GET_ENV_STRING(migraphx_env_vars::kModelCachePath, model_cache_path_); + GET_ENV_BOOL(migraphx_env_vars::kDumpModelOps, dump_model_ops_); + GET_ENV_BOOL(migraphx_env_vars::kExhaustiveTune, exhaustive_tune_); ++ GET_ENV_BOOL(migraphx_env_vars::kOffloadCopy, offload_copy_); + + // Verify configuration correctness and adjust accordingly. + +@@ -714,10 +716,20 @@ std::unique_ptr MIGraphXExecutionProvider::GetSubGraph(const st + int input_order = 0; + int output_order = 0; + ++ // Collect initializers separately - MIGraphX embeds them as constants in the compiled model ++ // so ORT doesn't need to allocate VRAM for them ++ std::vector initializers; ++ + for (const auto& index : graph_nodes_index) { + sub_graph->Nodes().push_back(index); + const auto& node = graph.GetNode(index); + for (const auto& input : node->InputDefs()) { ++ // Check if this input is an initializer (weight/constant) ++ // If so, add to initializers list and skip fused_inputs ++ if (graph.IsConstantInitializer(input->Name(), true)) { ++ initializers.push_back(input->Name()); ++ continue; ++ } + const auto& it = fused_outputs.find(input); + if (it != fused_outputs.end()) { + fused_outputs.erase(it); +@@ -729,6 +741,11 @@ std::unique_ptr MIGraphXExecutionProvider::GetSubGraph(const st + } + + for (const auto& input : node->ImplicitInputDefs()) { ++ // Check if this input is an initializer (weight/constant) ++ if (graph.IsConstantInitializer(input->Name(), true)) { ++ initializers.push_back(input->Name()); ++ continue; ++ } + const auto& it = fused_outputs.find(input); + if (it != fused_outputs.end()) { + fused_outputs.erase(it); +@@ -835,6 +852,12 @@ std::unique_ptr MIGraphXExecutionProvider::GetSubGraph(const st + } + } + ++ // Mark initializers as constant - tells ORT not to allocate VRAM for them ++ // MIGraphX will embed these weights into the compiled model ++ for (const auto& initializer : initializers) { ++ meta_def->constant_initializers().push_back(initializer); ++ } ++ + for (const auto& output : output_names) { + meta_def->outputs().push_back(output); + } +@@ -1248,13 +1271,13 @@ void calibrate_and_quantize(migraphx::program& prog, + + void compile_program(migraphx::program& prog, + const migraphx::target& t, +- bool exhaustive_tune) { +- LOGS_DEFAULT(WARNING) << "Model Compile: Begin"; ++ bool exhaustive_tune, ++ bool offload_copy = false) { + migraphx::compile_options co; + co.set_fast_math(false); + co.set_exhaustive_tune_flag(exhaustive_tune); ++ co.set_offload_copy(offload_copy); + prog.compile(t, co); +- LOGS_DEFAULT(WARNING) << "Model Compile: Complete"; + } + + std::string to_hex(const uint64_t v) { +@@ -1320,6 +1343,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& + input_name_index[input_defs[i]->Name()] = i; + } + ++ // Create ONNX buffer from the fused subgraph + auto model = graph_body_viewer.CreateModel(*GetLogger()); + auto model_proto = model->ToProto(); + graph_body_viewer.ToProto(*model_proto->mutable_graph(), true, true); +@@ -1343,7 +1367,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& + + if (!no_input_shape) { + if (!load_precompiled_model(prog, model_cache_file)) { +- LOGS_DEFAULT(VERBOSE) << "No input shapes detected quantizing model"; ++ LOGS_DEFAULT(WARNING) << "MIGraphX: No cache found, compiling model (this may take a while)"; + #ifndef ENABLE_TRAINING_CORE + #ifdef HAVE_MIGRAPHX_API_ONNX_OPTIONS_SET_EXTERNAL_DATA_PATH + options.set_external_data_path(model_path_.parent_path().string()); +@@ -1354,8 +1378,10 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& + + calibrate_and_quantize(prog, t_, quant_params, fp16_enable_, bf16_enable_, int8_enable_, + fp8_enable_, int8_calibration_cache_available_, dynamic_range_map_); +- compile_program(prog, t_, exhaustive_tune_); ++ compile_program(prog, t_, exhaustive_tune_, offload_copy_); + save_compiled_model(prog, model_cache_file); ++ } else { ++ LOGS_DEFAULT(WARNING) << "MIGraphX: Loaded compiled model from cache"; + } + + auto prog_output_shapes = prog.get_output_shapes(); +@@ -1365,10 +1391,9 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& + } + } + +- // compile the program ++ // Store compiled program and ONNX buffer (needed for dynamic shape recompilation) + map_progs_[fused_node.Name()] = prog; +- +- map_onnx_string_[fused_node.Name()] = onnx_string_buffer; ++ map_onnx_string_[fused_node.Name()] = std::move(onnx_string_buffer); + map_input_index_[fused_node.Name()] = input_name_index; + map_no_input_shape_[fused_node.Name()] = no_input_shape; + NodeComputeInfo compute_info; +@@ -1428,6 +1453,26 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& + param_shapes = prog.get_parameter_shapes(); + auto prog_output_shapes = prog.get_output_shapes(); + ++ // Diagnostic: Compare what MIGraphX needs vs what ORT provides ++ // This helps identify if we can skip allocating weight tensors ++ size_t mgx_runtime_inputs = 0; ++ size_t mgx_outputs = 0; ++ size_t mgx_missing = 0; ++ for (auto&& name : param_shapes.names()) { ++ std::string name_str(name); ++ if (name_str.find("#output_") != std::string::npos) { ++ mgx_outputs++; ++ } else if (map_input_name_index.count(name_str) > 0) { ++ mgx_runtime_inputs++; ++ } else { ++ mgx_missing++; ++ LOGS_DEFAULT(WARNING) << "MIGraphX param not in ORT inputs: " << name_str; ++ } ++ } ++ LOGS_DEFAULT(WARNING) << "MIGraphX runtime: " << mgx_runtime_inputs << " inputs, " ++ << mgx_outputs << " outputs, " << mgx_missing << " missing. " ++ << "ORT provides: " << map_input_name_index.size() << " inputs"; ++ + // check whether input shapes match with shapes of program inputs + // migraphx::onnx_options cmp_options; + if (param_shapes.size() > 0) { +@@ -1497,7 +1542,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& + } + calibrate_and_quantize(prog, t, quant_params, fp16_enable, bf16_enable, int8_enable, + fp8_enable, int8_calibration_cache_available, map_dynamic_range); +- compile_program(prog, t, exhaustive_tune_); ++ compile_program(prog, t, exhaustive_tune_, offload_copy_); + save_compiled_model(prog, model_cache_file); + } + +diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +index 99f790b9f9..eafdcbf8c4 100644 +--- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h ++++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +@@ -33,6 +33,7 @@ constexpr auto kCachePath = "ORT_MIGRAPHX_CACHE_PATH"sv; + constexpr auto kINT8UseNativeMIGraphXCalibrationTable = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE"sv; + constexpr auto kExhaustiveTune = "ORT_MIGRAPHX_EXHAUSTIVE_TUNE"sv; + constexpr auto kModelCachePath = "ORT_MIGRAPHX_MODEL_CACHE_PATH"sv; ++constexpr auto kOffloadCopy = "ORT_MIGRAPHX_OFFLOAD_COPY"sv; + } // namespace migraphx_env_vars + + // Information to construct kernel function state. +@@ -98,6 +99,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { + {std::string{migraphx_provider_option::kInt8CalibTable}, MakeStringWithClassicLocale(int8_calibration_table_name_)}, + {std::string{migraphx_provider_option::kInt8UseNativeCalibTable}, MakeStringWithClassicLocale(int8_use_native_calibration_table_)}, + {std::string{migraphx_provider_option::kExhaustiveTune}, MakeStringWithClassicLocale(exhaustive_tune_)}, ++ {std::string{migraphx_provider_option::kOffloadCopy}, MakeStringWithClassicLocale(offload_copy_)}, + {std::string{migraphx_provider_option::kMemLimit}, MakeStringWithClassicLocale(mem_limit_)}, + {std::string{migraphx_provider_option::kArenaExtendStrategy}, EnumToName(arena_extend_strategy_mapping, arena_extend_strategy_)}, + {std::string{migraphx_provider_option::kGpuExternalAlloc}, MakeStringWithClassicLocale(external_alloc_)}, +@@ -125,6 +127,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { + hipStream_t stream_ = nullptr; + hipDeviceProp_t device_prop_{}; + bool exhaustive_tune_ = false; ++ bool offload_copy_ = false; + mutable std::filesystem::path model_path_{}; + size_t mem_limit_{std::numeric_limits::max()}; + ArenaExtendStrategy arena_extend_strategy_{ArenaExtendStrategy::kNextPowerOfTwo}; +diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc +index 33ef366eb1..e3b2e9056c 100644 +--- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc ++++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc +@@ -70,6 +70,7 @@ MIGraphXExecutionProviderInfo::MIGraphXExecutionProviderInfo(const ProviderOptio + .AddAssignmentToReference(migraphx_provider_option::kInt8UseNativeCalibTable, int8_use_native_calibration_table) + .AddAssignmentToReference(migraphx_provider_option::kInt8CalibTable, int8_calibration_table_name) + .AddAssignmentToReference(migraphx_provider_option::kExhaustiveTune, exhaustive_tune) ++ .AddAssignmentToReference(migraphx_provider_option::kOffloadCopy, offload_copy) + .AddAssignmentToReference(migraphx_provider_option::kMemLimit, mem_limit) + .AddAssignmentToEnumReference(migraphx_provider_option::kArenaExtendStrategy, arena_extend_strategy_mapping, arena_extend_strategy) + .Parse(options)); +@@ -97,6 +98,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions() const { + {std::string{migraphx_provider_option::kMemLimit}, MakeStringWithClassicLocale(mem_limit)}, + {std::string{migraphx_provider_option::kArenaExtendStrategy}, EnumToName(arena_extend_strategy_mapping, arena_extend_strategy)}, + {std::string{migraphx_provider_option::kExhaustiveTune}, MakeStringWithClassicLocale(exhaustive_tune)}, ++ {std::string{migraphx_provider_option::kOffloadCopy}, MakeStringWithClassicLocale(offload_copy)}, + {std::string{migraphx_provider_option::kGpuExternalAlloc}, MakeStringWithClassicLocale(external_alloc)}, + {std::string{migraphx_provider_option::kGpuExternalFree}, MakeStringWithClassicLocale(external_free)}, + {std::string{migraphx_provider_option::kGpuExternalEmptyCache}, MakeStringWithClassicLocale(external_empty_cache)}, +diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h +index 414254aaa2..cee458aa2f 100644 +--- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h ++++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h +@@ -34,6 +34,7 @@ constexpr auto kGpuExternalAlloc = "migraphx_external_alloc"sv; + constexpr auto kGpuExternalFree = "migraphx_external_free"sv; + constexpr auto kGpuExternalEmptyCache = "migraphx_external_empty_cache"sv; + constexpr auto kModelCacheDir = "migraphx_model_cache_dir"sv; ++constexpr auto kOffloadCopy = "migraphx_offload_copy"sv; + } // namespace migraphx_provider_option + + extern const EnumNameMapping arena_extend_strategy_mapping; +@@ -50,6 +51,7 @@ struct MIGraphXExecutionProviderInfo { + bool int8_use_native_calibration_table{false}; + std::filesystem::path model_cache_dir{}; + bool exhaustive_tune{false}; ++ bool offload_copy{false}; + + size_t mem_limit{std::numeric_limits::max()}; + ArenaExtendStrategy arena_extend_strategy{ArenaExtendStrategy::kNextPowerOfTwo}; +@@ -85,7 +87,8 @@ struct std::hash<::onnxruntime::MIGraphXExecutionProviderInfo> { + (static_cast(info.int8_enable) << 19) ^ + (static_cast(info.int8_use_native_calibration_table) << 20) ^ + (static_cast(info.exhaustive_tune) << 21) ^ +- (static_cast(info.bf16_enable) << 22); ++ (static_cast(info.bf16_enable) << 22) ^ ++ (static_cast(info.offload_copy) << 23); + + onnxruntime::HashCombine(data, value); + diff --git a/models/patches/migraphx_offload_copy.patch b/models/patches/migraphx_offload_copy.patch new file mode 100644 index 0000000..6312d94 --- /dev/null +++ b/models/patches/migraphx_offload_copy.patch @@ -0,0 +1,153 @@ +diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +index abc1234..def5678 100644 +--- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc ++++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +@@ -165,6 +165,7 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv + } + + GET_ENV_BOOL(migraphx_env_vars::kFP16Enable, fp16_enable_); ++ GET_ENV_BOOL(migraphx_env_vars::kOffloadCopy, offload_copy_); + + GET_ENV_BOOL(migraphx_env_vars::kBF16Enable, bf16_enable_); + +@@ -1246,12 +1247,15 @@ void calibrate_and_quantize(migraphx::program& prog, + + void compile_program(migraphx::program& prog, + const migraphx::target& t, +- bool exhaustive_tune) { ++ bool exhaustive_tune, ++ bool offload_copy) { + LOGS_DEFAULT(WARNING) << "Model Compile: Begin"; ++ LOGS_DEFAULT(WARNING) << " offload_copy: " << (offload_copy ? "true" : "false"); + migraphx::compile_options co; + co.set_fast_math(false); + co.set_exhaustive_tune_flag(exhaustive_tune); ++ co.set_offload_copy(offload_copy); + prog.compile(t, co); + LOGS_DEFAULT(WARNING) << "Model Compile: Complete"; + } +@@ -1354,7 +1358,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& + + calibrate_and_quantize(prog, t_, quant_params, fp16_enable_, bf16_enable_, int8_enable_, + fp8_enable_, int8_calibration_cache_available_, dynamic_range_map_); +- compile_program(prog, t_, exhaustive_tune_); ++ compile_program(prog, t_, exhaustive_tune_, offload_copy_); + save_compiled_model(prog, model_cache_file); + } + +@@ -1497,7 +1501,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& + } + calibrate_and_quantize(prog, t, quant_params, fp16_enable, bf16_enable, int8_enable, + fp8_enable, int8_calibration_cache_available, map_dynamic_range); +- compile_program(prog, t, exhaustive_tune_); ++ compile_program(prog, t, exhaustive_tune_, offload_copy_); + save_compiled_model(prog, model_cache_file); + } + +diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +index abc1234..def5678 100644 +--- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h ++++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +@@ -32,6 +32,7 @@ constexpr auto kCachePath = "ORT_MIGRAPHX_CACHE_PATH"sv; + constexpr auto kINT8UseNativeMIGraphXCalibrationTable = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE"sv; + constexpr auto kExhaustiveTune = "ORT_MIGRAPHX_EXHAUSTIVE_TUNE"sv; + constexpr auto kModelCachePath = "ORT_MIGRAPHX_MODEL_CACHE_PATH"sv; ++constexpr auto kOffloadCopy = "ORT_MIGRAPHX_OFFLOAD_COPY"sv; + } // namespace migraphx_env_vars + + // Information to construct kernel function state. +@@ -56,6 +57,7 @@ struct MIGraphXFuncState { + std::filesystem::path model_cache_dir; + bool dump_model_ops = false; + bool exhaustive_tune = false; ++ bool offload_copy = false; + }; + + // Logical device representation. +@@ -99,6 +101,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { + {std::string{migraphx_provider_option::kInt8CalibTable}, MakeStringWithClassicLocale(int8_calibration_table_name_)}, + {std::string{migraphx_provider_option::kInt8UseNativeCalibTable}, MakeStringWithClassicLocale(int8_use_native_calibration_table_)}, + {std::string{migraphx_provider_option::kExhaustiveTune}, MakeStringWithClassicLocale(exhaustive_tune_)}, ++ {std::string{migraphx_provider_option::kOffloadCopy}, MakeStringWithClassicLocale(offload_copy_)}, + {std::string{migraphx_provider_option::kMemLimit}, MakeStringWithClassicLocale(mem_limit_)}, + {std::string{migraphx_provider_option::kArenaExtendStrategy}, EnumToName(arena_extend_strategy_mapping, arena_extend_strategy_)}, + {std::string{migraphx_provider_option::kGpuExternalAlloc}, MakeStringWithClassicLocale(external_alloc_)}, +@@ -125,6 +128,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { + hipStream_t stream_ = nullptr; + hipDeviceProp_t device_prop_{}; + bool exhaustive_tune_ = false; ++ bool offload_copy_ = false; + mutable std::filesystem::path model_path_{}; + size_t mem_limit_{std::numeric_limits::max()}; + ArenaExtendStrategy arena_extend_strategy_{ArenaExtendStrategy::kNextPowerOfTwo}; +diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h +index abc1234..def5678 100644 +--- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h ++++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h +@@ -34,6 +34,7 @@ constexpr auto kGpuExternalAlloc = "migraphx_external_alloc"sv; + constexpr auto kGpuExternalFree = "migraphx_external_free"sv; + constexpr auto kGpuExternalEmptyCache = "migraphx_external_empty_cache"sv; + constexpr auto kModelCacheDir = "migraphx_model_cache_dir"sv; ++constexpr auto kOffloadCopy = "migraphx_offload_copy"sv; + } // namespace migraphx_provider_option + + extern const EnumNameMapping arena_extend_strategy_mapping; +@@ -50,6 +51,7 @@ struct MIGraphXExecutionProviderInfo { + std::string int8_calibration_table_name{}; + bool int8_use_native_calibration_table{false}; + std::filesystem::path model_cache_dir{}; ++ bool offload_copy{false}; + bool exhaustive_tune{false}; + + size_t mem_limit{std::numeric_limits::max()}; +@@ -85,7 +87,8 @@ struct std::hash<::onnxruntime::MIGraphXExecutionProviderInfo> { + (static_cast(info.fp16_enable) << 18) ^ + (static_cast(info.int8_enable) << 19) ^ + (static_cast(info.int8_use_native_calibration_table) << 20) ^ +- (static_cast(info.exhaustive_tune) << 21) ^ ++ (static_cast(info.exhaustive_tune) << 21) ^ ++ (static_cast(info.offload_copy) << 23) ^ + (static_cast(info.bf16_enable) << 22); + + onnxruntime::HashCombine(data, value); +diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc +index abc1234..def5678 100644 +--- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc ++++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc +@@ -70,6 +70,7 @@ MIGraphXExecutionProviderInfo::MIGraphXExecutionProviderInfo(const ProviderOptio + .AddAssignmentToReference(migraphx_provider_option::kInt8CalibTable, int8_calibration_table_name) + .AddAssignmentToReference(migraphx_provider_option::kExhaustiveTune, exhaustive_tune) + .AddAssignmentToReference(migraphx_provider_option::kMemLimit, mem_limit) ++ .AddAssignmentToReference(migraphx_provider_option::kOffloadCopy, offload_copy) + .AddAssignmentToEnumReference(migraphx_provider_option::kArenaExtendStrategy, arena_extend_strategy_mapping, arena_extend_strategy) + .Parse(options)); + } +@@ -81,6 +82,7 @@ MIGraphXExecutionProviderInfo::MIGraphXExecutionProviderInfo(const OrtMIGraphXPr + fp8_enable{options.migraphx_fp8_enable != 0}, + int8_enable{options.migraphx_int8_enable != 0}, + exhaustive_tune{options.migraphx_exhaustive_tune != 0}, ++ offload_copy{options.migraphx_offload_copy != 0}, + mem_limit{options.migraphx_mem_limit}, + arena_extend_strategy{options.migraphx_arena_extend_strategy} { + } +@@ -98,6 +100,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions() const { + {std::string{migraphx_provider_option::kMemLimit}, MakeStringWithClassicLocale(mem_limit)}, + {std::string{migraphx_provider_option::kArenaExtendStrategy}, EnumToName(arena_extend_strategy_mapping, arena_extend_strategy)}, + {std::string{migraphx_provider_option::kExhaustiveTune}, MakeStringWithClassicLocale(exhaustive_tune)}, ++ {std::string{migraphx_provider_option::kOffloadCopy}, MakeStringWithClassicLocale(offload_copy)}, + {std::string{migraphx_provider_option::kGpuExternalAlloc}, MakeStringWithClassicLocale(external_alloc)}, + {std::string{migraphx_provider_option::kGpuExternalFree}, MakeStringWithClassicLocale(external_free)}, + {std::string{migraphx_provider_option::kGpuExternalEmptyCache}, MakeStringWithClassicLocale(external_empty_cache)}, +diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h +index abc1234..def5678 100644 +--- a/include/onnxruntime/core/session/onnxruntime_c_api.h ++++ b/include/onnxruntime/core/session/onnxruntime_c_api.h +@@ -xxx,6 +xxx,7 @@ typedef struct OrtMIGraphXProviderOptions { + int migraphx_fp8_enable; + int migraphx_int8_enable; + int migraphx_exhaustive_tune; ++ int migraphx_offload_copy; // Enable offload copy (use CPU memory during compilation) + const char* migraphx_int8_calibration_table_name; + int migraphx_int8_use_native_calibration_table; + size_t migraphx_mem_limit; + diff --git a/models/precompile_shapes.py b/models/precompile_shapes.py new file mode 100755 index 0000000..0fbdcd9 --- /dev/null +++ b/models/precompile_shapes.py @@ -0,0 +1,496 @@ +#!/usr/bin/env python3 +"""Pre-compile MIGraphX models for common KV cache lengths. + +MIGraphX requires fixed shapes at compile time. This script pre-compiles +and caches models for common context lengths to avoid runtime recompilation. + +Each unique (seq_length, kv_length) combination is compiled once and cached. +Subsequent runs with cached shapes load instantly. + +IMPORTANT: Each shape is compiled in a completely SEPARATE SUBPROCESS to ensure +complete memory cleanup between compilations, preventing GPU OOM errors. + +Usage: + python precompile_shapes.py [options] + +Examples: + # Use defaults (decode + prefill shapes) + python precompile_shapes.py ./Llama3.1-8B-Instruct/onnx + + # Custom decode shapes only + python precompile_shapes.py ./onnx --buckets "512,1024,2048" --prefill-lengths "" + + # Custom prefill shapes for longer prompts + python precompile_shapes.py ./onnx --prefill-lengths "512,1024,2048,4096,8192" +""" + +import argparse +import json +import os +import subprocess +import sys +import time + + +def detect_model_dtype(model_path: str) -> str: + """Detect if model uses FP16 or FP32. Returns string for subprocess.""" + import onnx + model = onnx.load(model_path, load_external_data=False) + for inp in model.graph.input: + elem_type = inp.type.tensor_type.elem_type + if elem_type == onnx.TensorProto.FLOAT16: + return "float16" + elif elem_type == onnx.TensorProto.FLOAT: + return "float32" + return "float16" + + +def compile_in_subprocess( + model_path: str, + cache_path: str, + num_layers: int, + num_kv_heads: int, + head_dim: int, + dtype_str: str, + seq_len: int, + kv_len: int, + exhaustive_tune: bool, + offload_copy: bool, + verbose: bool, +) -> tuple[float, str]: + """ + Compile a single shape in a completely separate subprocess. + + This ensures ALL memory (GPU and CPU) is released when the subprocess exits. + + Returns: (time_taken, status: "compiled" | "cached" | "failed:reason") + """ + # Build the Python script to run in subprocess + script = f''' +import sys +import os +import time +import gc +import glob +import traceback + +# Import numpy and ort inside subprocess +import numpy as np +import onnxruntime as ort + +# Parameters passed from parent +model_path = {repr(model_path)} +cache_path = {repr(cache_path)} +num_layers = {num_layers} +num_kv_heads = {num_kv_heads} +head_dim = {head_dim} +dtype = np.{dtype_str} +seq_len = {seq_len} +kv_len = {kv_len} +verbose = {verbose} +exhaustive_tune = {exhaustive_tune} +offload_copy = {offload_copy} + +# Always use verbose logging for debugging +log_level = 0 # VERBOSE +ort.set_default_logger_severity(log_level) + +print(f"DEBUG: seq_len={{seq_len}}, kv_len={{kv_len}}", file=sys.stderr) +print(f"DEBUG: num_layers={{num_layers}}, num_kv_heads={{num_kv_heads}}, head_dim={{head_dim}}", file=sys.stderr) +print(f"DEBUG: dtype={{dtype}}", file=sys.stderr) + +# Session options +sess_options = ort.SessionOptions() +sess_options.log_severity_level = log_level +sess_options.log_verbosity_level = 10 # Maximum verbosity + +# Provider options +provider_options = {{ + "device_id": "0", + "migraphx_fp16_enable": "0", + "migraphx_model_cache_dir": cache_path, + "migraphx_exhaustive_tune": "1" if exhaustive_tune else "0", + "migraphx_offload_copy": "1" if offload_copy else "0", +}} +print(f"DEBUG: provider_options={{provider_options}}", file=sys.stderr) + +try: + # Create session + print("DEBUG: Creating session...", file=sys.stderr) + session = ort.InferenceSession( + model_path, + sess_options, + providers=["MIGraphXExecutionProvider"], + provider_options=[provider_options], + ) + print(f"DEBUG: Session created, providers={{session.get_providers()}}", file=sys.stderr) + + # Verify MIGraphX is active + if "MIGraphXExecutionProvider" not in session.get_providers(): + print("RESULT:failed:MIGraphX not active") + sys.exit(1) + + # Get model input/output info + model_inputs = session.get_inputs() + model_outputs = session.get_outputs() + input_names = [inp.name for inp in model_inputs] + + print(f"DEBUG: Model has {{len(model_inputs)}} inputs, {{len(model_outputs)}} outputs", file=sys.stderr) + print(f"DEBUG: First 5 input names: {{input_names[:5]}}", file=sys.stderr) + if len(input_names) > 5: + print(f"DEBUG: ... and {{len(input_names) - 5}} more inputs", file=sys.stderr) + + # Print expected shapes for first few inputs + for inp in model_inputs[:5]: + print(f"DEBUG: Input '{{inp.name}}': shape={{inp.shape}}, type={{inp.type}}", file=sys.stderr) + + # Total attention length = seq_len + kv_len + attn_len = seq_len + kv_len + + # Use simple numpy arrays like the working benchmark script + feed = {{}} + + if "input_ids" in input_names: + feed["input_ids"] = np.ones((1, seq_len), dtype=np.int64) + print(f"DEBUG: input_ids shape={{feed['input_ids'].shape}}", file=sys.stderr) + + if "attention_mask" in input_names: + feed["attention_mask"] = np.ones((1, attn_len), dtype=np.int64) + print(f"DEBUG: attention_mask shape={{feed['attention_mask'].shape}}", file=sys.stderr) + + if "position_ids" in input_names: + # Position for decode = kv_len (next position after past context) + feed["position_ids"] = np.array([[kv_len]], dtype=np.int64) if seq_len == 1 else np.arange(seq_len, dtype=np.int64).reshape(1, -1) + print(f"DEBUG: position_ids shape={{feed['position_ids'].shape}}", file=sys.stderr) + + # KV cache tensors (use random data like benchmark to simulate real cache) + kv_count = 0 + for i in range(num_layers): + key_name = f"past_key_values.{{i}}.key" + value_name = f"past_key_values.{{i}}.value" + if key_name in input_names: + feed[key_name] = np.random.randn(1, num_kv_heads, kv_len, head_dim).astype(dtype) + kv_count += 1 + if value_name in input_names: + feed[value_name] = np.random.randn(1, num_kv_heads, kv_len, head_dim).astype(dtype) + + print(f"DEBUG: Created {{kv_count}} KV cache pairs with kv_len={{kv_len}}", file=sys.stderr) + print(f"DEBUG: Total feed tensors: {{len(feed)}}", file=sys.stderr) + + # Verify all required inputs are provided + missing = [name for name in input_names if name not in feed] + if missing: + print(f"DEBUG: WARNING - Missing inputs: {{missing}}", file=sys.stderr) + + # Run inference to trigger compilation (use simple session.run like benchmark) + print("DEBUG: Running inference...", file=sys.stderr) + t0 = time.time() + try: + outputs = session.run(None, feed) + elapsed = time.time() - t0 + print(f"DEBUG: Inference completed in {{elapsed:.2f}}s", file=sys.stderr) + print(f"DEBUG: Output shapes: {{[o.shape for o in outputs[:3]]}}", file=sys.stderr) + except Exception as run_err: + error_msg = str(run_err) + # Check if this is the HIP registration error that happens after successful compilation + # The model is compiled and cached successfully, just running inference fails + if "register_on_gpu" in error_msg or "Failed to call function" in error_msg: + elapsed = time.time() - t0 + print(f"DEBUG: Inference failed with HIP error after {{elapsed:.2f}}s", file=sys.stderr) + print(f"DEBUG: This is a known MIGraphX issue - model IS compiled and cached successfully", file=sys.stderr) + # Check if cache file was created + cache_files_count = len(glob.glob(os.path.join(cache_path, "*.mxr"))) + # Model was likely compiled since inference was attempted + # Report as compiled (not failed) since the cache was written + print(f"DEBUG: Cache has {{cache_files_count}} .mxr files - treating as success", file=sys.stderr) + print(f"RESULT:compiled:{{elapsed:.1f}}") + sys.exit(0) # Exit successfully - compilation worked + else: + print(f"DEBUG: Inference FAILED: {{run_err}}", file=sys.stderr) + print(f"DEBUG: Traceback:", file=sys.stderr) + traceback.print_exc(file=sys.stderr) + raise + + # Determine if this was a compile or cache hit + if elapsed > 10: + print(f"RESULT:compiled:{{elapsed:.1f}}") + else: + print(f"RESULT:cached:{{elapsed*1000:.0f}}") + + # Explicit cleanup before exit + del session + del feed + del sess_options + gc.collect() + +except Exception as e: + print(f"RESULT:failed:{{str(e)[:200]}}") + sys.exit(1) +''' + + t0 = time.time() + + try: + # Run in completely separate subprocess + result = subprocess.run( + [sys.executable, '-c', script], + capture_output=True, + text=True, + timeout=900, # 15 minute timeout per shape + env={**os.environ, 'PYTHONUNBUFFERED': '1'}, + ) + + elapsed = time.time() - t0 + + # Parse output for RESULT line + output = result.stdout + result.stderr + for line in output.split('\n'): + if line.startswith('RESULT:'): + parts = line.split(':', 2) + if len(parts) >= 2: + status = parts[1] + detail = parts[2] if len(parts) > 2 else "" + + if status == "compiled": + return elapsed, "compiled" + elif status == "cached": + return elapsed, "cached" + else: + # Show debug output on failure + print("\n--- DEBUG OUTPUT (FAILED) ---", file=sys.stderr) + if result.stderr: + for dbg_line in result.stderr.split('\n'): + if dbg_line.strip(): + print(f" {dbg_line}", file=sys.stderr) + print("--- END DEBUG OUTPUT ---\n", file=sys.stderr) + return elapsed, f"failed:{detail}" + + # No RESULT line found + if result.returncode != 0: + # Show full debug output on failure + if verbose or True: # Always show on failure + print("\n--- DEBUG OUTPUT ---", file=sys.stderr) + if result.stderr: + for line in result.stderr.split('\n'): + if line.strip(): + print(f" {line}", file=sys.stderr) + print("--- END DEBUG OUTPUT ---\n", file=sys.stderr) + + # Get error message for status + err = result.stderr.strip() + if err: + # Find the most relevant error line + for line in reversed(err.split('\n')): + if 'FAILED' in line or 'Error' in line or 'error' in line: + return elapsed, f"failed:{line[:150]}" + return elapsed, f"failed:{err[-200:]}" + return elapsed, f"failed:exit code {result.returncode}" + + # Success but no status - assume compiled + return elapsed, "compiled" + + except subprocess.TimeoutExpired: + return time.time() - t0, "failed:timeout (15min)" + except Exception as e: + return time.time() - t0, f"failed:{e}" + + +def main(): + parser = argparse.ArgumentParser( + description="Pre-compile MIGraphX for multiple KV cache lengths", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__ + ) + parser.add_argument("model_dir", + help="Directory containing model.onnx and export_info.json") + parser.add_argument("--buckets", type=str, + default="256,512,1024,2048,4096,8192,16384,32768", + help="Comma-separated input bucket sizes. Prefill uses bucket, decode uses 2*bucket. " + "Default: 256,512,1024,2048,4096,8192,16384,32768") + parser.add_argument("--seq-lengths", type=str, default="1,256", + help="Comma-separated input sequence lengths for DECODE (default: 1)") + parser.add_argument("--prefill-lengths", type=str, default="", + help="Additional prefill lengths beyond buckets (default: none, use --buckets)") + parser.add_argument("--exhaustive-tune", action="store_true", + help="Enable exhaustive tuning (slower compile, faster runtime)") + parser.add_argument("--no-offload-copy", action="store_true", + help="Disable CPU memory offload during compilation (uses more GPU memory)") + parser.add_argument("--verbose", "-v", action="store_true", + help="Enable verbose ORT logging") + parser.add_argument("--quiet", "-q", action="store_true", + help="Minimal output") + args = parser.parse_args() + + # Parse shape lists (handle empty strings) + buckets = [int(x.strip()) for x in args.buckets.split(",") if x.strip()] + seq_lengths = [int(x.strip()) for x in args.seq_lengths.split(",") if x.strip()] + prefill_lengths = [int(x.strip()) for x in args.prefill_lengths.split(",") if x.strip()] + + model_path = os.path.join(args.model_dir, "model.onnx") + info_path = os.path.join(args.model_dir, "export_info.json") + cache_path = os.path.join(args.model_dir, "migraphx_cache") + + if not os.path.exists(model_path): + print(f"Error: Model not found at {model_path}") + return 1 + + if not os.path.exists(info_path): + print(f"Error: export_info.json not found at {info_path}") + return 1 + + with open(info_path) as f: + info = json.load(f) + + num_layers = info["num_layers"] + num_kv_heads = info["num_kv_heads"] + head_dim = info["head_dim"] + dtype_str = detect_model_dtype(model_path) + + os.makedirs(cache_path, exist_ok=True) + + # Build shape list for MIGraphX compilation: + # KV cache represents ACTUAL past context length (not pre-allocated buffer) + # + # For each bucket size B: + # 1. PREFILL: seq_len=B, kv_len=0 (process prompt, no past context) + # 2. DECODE: seq_len=1, kv_len=B (generate after prefill, past=B tokens) + # 3. DECODE: seq_len=1, kv_len=2*B (generate more, past=2*B tokens) + # + # This covers: prompt up to B tokens, then generate up to B more tokens + + shapes = [] + + # NOTE: kv_len=0 (true prefill with empty KV cache) is SKIPPED + # because HIP cannot register 0-element tensors. First inference + # will JIT compile for the actual prefill shape. + # + # We pre-compile DECODE shapes for fast generation after prefill. + + for bucket in sorted(buckets): + # DECODE: after prefill, kv_len = bucket (prompt is now in cache) + for seq_len in sorted(seq_lengths): + shapes.append(("decode", seq_len, bucket)) + + # DECODE: after generating more, kv_len = 2*bucket + for seq_len in sorted(seq_lengths): + shapes.append(("decode", seq_len, 2 * bucket)) + + # Add any additional prefill lengths (as decode kv_lengths) + for prompt_len in sorted(prefill_lengths): + if prompt_len not in buckets: + for seq_len in sorted(seq_lengths): + shapes.append(("decode", seq_len, prompt_len)) + shapes.append(("decode", seq_len, 2 * prompt_len)) + + total_shapes = len(shapes) + offload_copy = not args.no_offload_copy + + # Collect unique kv_lengths for display + kv_lengths = sorted(set(s[2] for s in shapes)) + + if not args.quiet: + print("=" * 60) + print("MIGraphX Shape Pre-compilation (DECODE only)") + print("=" * 60) + print(f"Model: {model_path}") + print(f"Model dtype: {dtype_str.upper()}") + print(f"Cache: {cache_path}") + print() + print(f"INPUT BUCKETS: {sorted(buckets)}") + print(f"KV CACHE SIZES: {kv_lengths}") + print() + print(f"DECODE shapes ({total_shapes}):") + print(f" seq_lengths: {seq_lengths}") + print(f" kv_lengths: {kv_lengths}") + print() + print(f"Total shapes: {total_shapes}") + print(f"Exhaustive tuning: {args.exhaustive_tune}") + print(f"Offload copy: {offload_copy} (CPU memory during compile)") + print() + print("STRATEGY: For bucket B, pre-compile decode shapes:") + print(" - Decode: seq=1, kv=B (after prefill)") + print(" - Decode: seq=1, kv=2*B (after generating B tokens)") + print() + print("NOTE: Prefill (kv=0) is NOT pre-compiled - HIP cannot register") + print(" 0-element tensors. First inference will JIT compile prefill.") + print() + print("NOTE: Each shape compiled in SEPARATE SUBPROCESS for memory isolation") + print() + + total_time = 0 + compiled = 0 + cached = 0 + failed = 0 + + for current, (phase, seq_len, kv_len) in enumerate(shapes, 1): + if not args.quiet: + print(f"[{current}/{total_shapes}] DECODE seq={seq_len}, kv={kv_len}...", + end=" ", flush=True) + + t, status = compile_in_subprocess( + model_path=model_path, + cache_path=cache_path, + num_layers=num_layers, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + dtype_str=dtype_str, + seq_len=seq_len, + kv_len=kv_len, + exhaustive_tune=args.exhaustive_tune, + offload_copy=offload_copy, + verbose=args.verbose, + ) + + total_time += t + + if status == "compiled": + compiled += 1 + if not args.quiet: + print(f"COMPILED in {t:.1f}s") + elif status == "cached": + cached += 1 + if not args.quiet: + print(f"cached ({t*1000:.0f}ms)") + elif status.startswith("failed:"): + failed += 1 + reason = status[7:] # Remove "failed:" prefix + if not args.quiet: + print(f"FAILED: {reason}") + + if not args.quiet: + print() + print("=" * 60) + print("Pre-compilation complete!") + print("=" * 60) + print(f"Total combinations: {total_shapes}") + print(f"Newly compiled: {compiled}") + print(f"Already cached: {cached}") + print(f"Failed: {failed}") + print(f"Total time: {total_time:.1f}s") + print(f"Cache location: {cache_path}") + print() + + # List cached files + try: + cache_files = [f for f in os.listdir(cache_path) if f.endswith('.mxr')] + if cache_files and not args.quiet: + print(f"Cached files ({len(cache_files)}):") + total_size = 0 + for f in sorted(cache_files)[:10]: + size_mb = os.path.getsize(os.path.join(cache_path, f)) / 1024 / 1024 + total_size += size_mb + print(f" {f} ({size_mb:.1f} MB)") + if len(cache_files) > 10: + # Calculate total size including remaining files + for f in sorted(cache_files)[10:]: + total_size += os.path.getsize(os.path.join(cache_path, f)) / 1024 / 1024 + print(f" ... and {len(cache_files) - 10} more") + print(f"\nTotal cache size: {total_size:.1f} MB") + except Exception: + pass + + return 0 if failed == 0 else 1 + + +if __name__ == "__main__": + exit(main()) diff --git a/models/py/convert_fp16.py b/models/py/convert_fp16.py new file mode 100644 index 0000000..4e68779 --- /dev/null +++ b/models/py/convert_fp16.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python3 +""" +convert_fp16.py - Convert ONNX model to FP16 +""" + +import onnx +import os +from onnxconverter_common import float16 +from pathlib import Path + + +def main(): + input_file = os.environ['INPUT_FILE'] + output_file = os.environ['OUTPUT_FILE'] + + print("Loading model...") + model = onnx.load(input_file, load_external_data=True) + + print("Converting to FP16...") + model_fp16 = float16.convert_float_to_float16( + model, + keep_io_types=True, # Keep inputs/outputs as FP32 for compatibility + ) + + print("Saving model...") + onnx.save(model_fp16, output_file) + + input_size = Path(input_file).stat().st_size / (1024**3) + output_size = Path(output_file).stat().st_size / (1024**3) + reduction = (1 - output_size / input_size) * 100 + + print(f"\n✅ Conversion complete!") + print(f" Input size: {input_size:.2f} GB") + print(f" Output size: {output_size:.2f} GB") + print(f" Reduction: {reduction:.1f}%") + + +if __name__ == '__main__': + main() diff --git a/models/py/export_model.py b/models/py/export_model.py new file mode 100644 index 0000000..dbf054e --- /dev/null +++ b/models/py/export_model.py @@ -0,0 +1,428 @@ +#!/usr/bin/env python3 +""" +export_model.py - Export HuggingFace model to ONNX for Inference + +Custom ONNX export with KV cache support using modern torch.export. +Does NOT require optimum library. + +IMPORTANT NOTES: +- Exports on CPU by default (set FORCE_CPU_EXPORT=false to use GPU) +- CPU export is RECOMMENDED for stability, especially with PyTorch nightly builds +- ONNX models are device-agnostic: CPU-exported models run fine on GPU/MIGraphX +- For PyTorch stable (non-nightly), either CPU or GPU export works +- For ROCm nightly builds, CPU export avoids potential issues +""" + +import sys +import os +import json +import gc +import torch +import onnx +from pathlib import Path +from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM +from transformers.cache_utils import DynamicCache, DynamicLayer + + +# ============================================================================ +# Export-friendly wrapper that takes flat tensor inputs +# Based on Optimum's approach: flatten KV cache to individual tensors +# ============================================================================ +class OnnxExportWrapper(torch.nn.Module): + """ + Wrapper for ONNX export that converts flat KV cache tensors to DynamicCache. + + Input signature: + - input_ids: (batch, seq_len) - token IDs + - attention_mask: (batch, seq_len) - attention mask + - past_seq_len: (256,) - padded tensor with past sequence length in first element (used to compute position_ids) + - past_kv_flat: tuple of 2*num_layers tensors, each (batch, num_kv_heads, past_seq, head_dim) + + Output signature: + - logits: (batch, seq_len, vocab_size) + - present_kv_flat: tuple of 2*num_layers tensors + + Note: position_ids is computed internally from past_seq_len[0] to avoid MIGraphX + hipHostRegister failures. The past_seq_len input is padded to 256 elements (2048 bytes) + to meet MIGraphX minimum buffer size requirements for hipHostRegister. + """ + + def __init__(self, model, num_layers, num_kv_heads, head_dim, dtype): + super().__init__() + self.model = model + self.num_layers = num_layers + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.dtype = dtype + + def forward(self, input_ids, attention_mask, past_seq_len_tensor, past_kv_flat): + """ + Forward pass with flat KV cache tensors as a tuple. + Computes position_ids internally to avoid hipHostRegister issues with small buffers. + + Args: + input_ids: (batch, seq_len) + attention_mask: (batch, seq_len) + past_seq_len_tensor: (256,) padded tensor with past sequence length in first element + past_kv_flat: tuple of KV cache tensors + """ + # Reconstruct DynamicCache from flat tensors + past_key_values = DynamicCache() + + if past_kv_flat is not None and len(past_kv_flat) > 0: + for i in range(self.num_layers): + key = past_kv_flat[2 * i] # (batch, num_kv_heads, past_seq, head_dim) + value = past_kv_flat[2 * i + 1] + past_key_values.update(key, value, i) + + # Compute position_ids internally from past_seq_len + # past_seq_len_tensor is padded to 256 elements to avoid hipHostRegister failures + # Extract the first element using pure tensor operations (no .item() to avoid CPU copy) + batch_size = input_ids.shape[0] + seq_len = input_ids.shape[1] + + # Extract scalar using tensor indexing (stays on device, no CPU transfer) + past_seq_len_scalar = past_seq_len_tensor[0:1] # (1,) tensor + + # Create position_ids: [past_seq_len, past_seq_len+1, ..., past_seq_len+seq_len-1] + # Use broadcasting to add past_seq_len to arange + position_ids = torch.arange(0, seq_len, dtype=torch.long, device=input_ids.device).unsqueeze(0) + position_ids = position_ids + past_seq_len_scalar # Broadcasting addition + position_ids = position_ids.expand(batch_size, -1) + + # Call model with computed position_ids + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=True, + return_dict=True, + ) + + logits = outputs.logits + present_kv = outputs.past_key_values + + # Flatten present_key_values for output + flat_outputs = [logits] + for i in range(len(present_kv.layers)): + layer = present_kv.layers[i] + flat_outputs.append(layer.keys) # (batch, num_kv_heads, total_seq, head_dim) + flat_outputs.append(layer.values) + + return tuple(flat_outputs) + + +def main(): + # Read from environment variables + model_path = os.environ['MODEL_PATH'] + output_dir = Path(os.environ['OUTPUT_DIR']) + opset_version = int(os.environ['OPSET_VERSION']) + use_fp16 = os.environ['USE_FP16'] == "true" + with_kv_cache = os.environ['WITH_KV_CACHE'] == "true" + + print(f"[1/6] Loading model configuration...") + config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + + # Extract model info + model_type = getattr(config, 'model_type', 'unknown') + hidden_size = getattr(config, 'hidden_size', 0) + num_heads = getattr(config, 'num_attention_heads', 0) + num_kv_heads = getattr(config, 'num_key_value_heads', num_heads) + num_layers = getattr(config, 'num_hidden_layers', 0) + vocab_size = getattr(config, 'vocab_size', 0) + max_position = getattr(config, 'max_position_embeddings', 4096) + head_dim = hidden_size // num_heads + + variants = { + 2048: "Llama 3.2 1B", + 3072: "Llama 3.2 3B", + 4096: "Llama 3.1 8B / Mistral 7B", + 8192: "Llama 3.1 70B", + 16384: "Llama 3.1 405B", + } + model_variant = variants.get(hidden_size, f"Unknown ({model_type})") + + print(f" Model: {model_variant}") + print(f" Type: {model_type}") + print(f" Hidden size: {hidden_size}") + print(f" Attention: {num_heads} heads, {num_kv_heads} KV heads") + print(f" Head dim: {head_dim}") + print(f" Layers: {num_layers}") + print(f" Vocab: {vocab_size}") + + print(f"\n[2/6] Loading tokenizer...") + tokenizer = AutoTokenizer.from_pretrained( + model_path, + trust_remote_code=True, + fix_mistral_regex=True, # Fix incorrect regex pattern in Llama/Mistral tokenizers + ) + tokenizer.save_pretrained(output_dir) + + print(f"\n[3/6] Loading model ({'FP16' if use_fp16 else 'FP32'})...") + dtype = torch.float16 if use_fp16 else torch.float32 + + # Device selection for export + # Note: ONNX export only traces the graph - optimization happens at inference time + # GPU export is faster for large models but may have stability issues with nightly builds + force_cpu_export = os.environ.get('FORCE_CPU_EXPORT', 'false') == 'true' + + if force_cpu_export: + device = "cpu" + print(f" Using CPU for export (stable)") + else: + device = "cuda" if torch.cuda.is_available() else "cpu" + print(f" Using GPU for export (faster, uses ROCm)") + if device == "cuda": + print(f" Note: If export fails, try FORCE_CPU_EXPORT=true") + + model = AutoModelForCausalLM.from_pretrained( + model_path, + torch_dtype=dtype, + trust_remote_code=True, + use_cache=with_kv_cache, + attn_implementation="eager", # Required for ONNX export + ) + model.eval() + model.to(device) + + print(f" Device: {device}") + print(f" Parameters: {sum(p.numel() for p in model.parameters()) / 1e9:.2f}B") + + print(f"\n[4/6] Creating export wrapper...") + + wrapper = OnnxExportWrapper(model, num_layers, num_kv_heads, head_dim, dtype) + wrapper.eval() + + print(f" ✓ Export wrapper created") + print(f" KV cache: {num_layers} layers × 2 (key + value) = {2 * num_layers} tensors") + + print(f"\n[5/6] Preparing ONNX export...") + + # Create dummy inputs + batch_size = 1 + seq_len = 256 # Must be >= MIN_SEQ_LEN to satisfy Dim constraints + past_seq_len = 512 if with_kv_cache else 0 + + dummy_input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device) + # Use seq_len for attention_mask to match dynamic_shapes (batch handling requires consistent dims) + dummy_attention_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device) + # past_seq_len as padded tensor (256 elements = 2048 bytes to avoid hipHostRegister failures) + # Only first element is used; rest is padding + dummy_past_seq_len = torch.zeros(256, dtype=torch.int64, device=device) + dummy_past_seq_len[0] = past_seq_len + + # Create KV cache inputs as a tuple + past_kv_list = [] + + # Use past_seq_len as scalar input instead of position_ids array + # This avoids hipHostRegister failures on small buffers + input_names = ["input_ids", "attention_mask", "past_seq_len"] + output_names = ["logits"] + + dynamic_axes = { + "input_ids": {0: "batch_size", 1: "sequence_length"}, + "attention_mask": {0: "batch_size", 1: "sequence_length"}, + # past_seq_len is a fixed-size padded tensor (256 elements) - no dynamic axes + "logits": {0: "batch_size", 1: "sequence_length"}, + } + + if with_kv_cache and past_seq_len > 0: + kv_shape = (batch_size, num_kv_heads, past_seq_len, head_dim) + print(f" KV cache input shape: {kv_shape}") + + for i in range(num_layers): + # Input past KV + key_name = f"past_key_values.{i}.key" + value_name = f"past_key_values.{i}.value" + input_names.extend([key_name, value_name]) + + past_kv_list.append(torch.randn(kv_shape, dtype=dtype, device=device)) + past_kv_list.append(torch.randn(kv_shape, dtype=dtype, device=device)) + + dynamic_axes[key_name] = {0: "batch_size", 2: "past_sequence_length"} + dynamic_axes[value_name] = {0: "batch_size", 2: "past_sequence_length"} + + # Output present KV + present_key_name = f"present.{i}.key" + present_value_name = f"present.{i}.value" + output_names.extend([present_key_name, present_value_name]) + + dynamic_axes[present_key_name] = {0: "batch_size", 2: "total_sequence_length"} + dynamic_axes[present_value_name] = {0: "batch_size", 2: "total_sequence_length"} + + past_kv_tuple = tuple(past_kv_list) if past_kv_list else () + dummy_inputs = (dummy_input_ids, dummy_attention_mask, dummy_past_seq_len, past_kv_tuple) + + print(f" Input tensors: {len(input_names)}") + print(f" Output tensors: {len(output_names)}") + print(f" past_seq_len (scalar): {past_seq_len} (position_ids computed internally)") + + # Verify wrapper works + print(f"\n Verifying wrapper forward pass...") + with torch.no_grad(): + test_output = wrapper(dummy_input_ids, dummy_attention_mask, dummy_past_seq_len, past_kv_tuple) + print(f" ✓ Forward pass successful") + print(f" Logits shape: {test_output[0].shape}") + if with_kv_cache: + print(f" Present KV[0].key shape: {test_output[1].shape}") + expected_kv_len = past_seq_len + seq_len + actual_kv_len = test_output[1].shape[2] + if actual_kv_len == expected_kv_len: + print(f" ✓ KV cache outputs ALL positions: {actual_kv_len} = {past_seq_len} + {seq_len}") + else: + print(f" ⚠ KV cache length mismatch: {actual_kv_len} (expected {expected_kv_len})") + + print(f"\n[6/6] Exporting to ONNX (opset {opset_version})...") + print(f" This may take several minutes for large models...") + + output_file = output_dir / "model.onnx" + + # Use dynamo=True for opset 21 with dynamic_shapes + from torch.export import Dim + + # CRITICAL: MIGraphX hipHostRegister bug - even 1024 bytes may fail + # HIP memory pool seems to have 2KB minimum allocation + # Testing with 256 elements = 2048 bytes + MIN_SEQ_LEN = 256 # Minimum sequence length to avoid hipHostRegister failure + + batch_dim = Dim("batch_size", min=1, max=64) + seq_dim = Dim("sequence_length", min=MIN_SEQ_LEN, max=4096) + past_seq_dim = Dim("past_sequence_length", min=0, max=131072) + + # Build dynamic_shapes matching input structure: (input_ids, attention_mask, position_ids, past_kv_tuple) + kv_dynamic_shapes = [] + if with_kv_cache and past_seq_len > 0: + for i in range(num_layers): + kv_dynamic_shapes.append({0: batch_dim, 2: past_seq_dim}) # key + kv_dynamic_shapes.append({0: batch_dim, 2: past_seq_dim}) # value + + # CRITICAL: All current sequence dimensions must use the same seq_dim + # past_seq_len is a scalar (no dynamic shape) + # position_ids is computed internally from past_seq_len to avoid hipHostRegister bug + dynamic_shapes_tuple = ( + {0: batch_dim, 1: seq_dim}, # input_ids + {0: batch_dim, 1: seq_dim}, # attention_mask (must match input_ids dim) + None, # past_seq_len (scalar, no dynamic shape) + tuple(kv_dynamic_shapes), # past_kv_flat tuple + ) + + # Export with dynamo=True (modern torch.export path) + # If this fails with nightly builds, try: dynamo=False with old export path + use_dynamo = os.environ.get('USE_DYNAMO', 'true') == 'true' + + if use_dynamo: + print(f" Using dynamo export (torch.export path, recommended)") + torch.onnx.export( + wrapper, + dummy_inputs, + str(output_file), + input_names=input_names, + output_names=output_names, + opset_version=opset_version, + dynamo=True, + dynamic_shapes=dynamic_shapes_tuple, + external_data=True, + report=True, + ) + print(f" ✓ ONNX export complete (dynamo, opset {opset_version})") + else: + print(f" Using legacy export (fallback for nightly issues)") + torch.onnx.export( + wrapper, + dummy_inputs, + str(output_file), + input_names=input_names, + output_names=output_names, + opset_version=opset_version, + dynamic_axes=dynamic_axes, + do_constant_folding=False, + external_data=True, + ) + print(f" ✓ ONNX export complete (legacy, opset {opset_version})") + + # Verify ONNX model + print(f"\n Verifying ONNX model...") + try: + onnx_model = onnx.load(str(output_file), load_external_data=False) + onnx.checker.check_model(onnx_model) + print(f" ✓ ONNX model structure is valid") + + print(f"\n ONNX Model Inputs ({len(onnx_model.graph.input)}):") + for inp in onnx_model.graph.input[:5]: + print(f" - {inp.name}") + if len(onnx_model.graph.input) > 5: + print(f" ... and {len(onnx_model.graph.input) - 5} more") + + print(f"\n ONNX Model Outputs ({len(onnx_model.graph.output)}):") + for out in onnx_model.graph.output[:5]: + print(f" - {out.name}") + if len(onnx_model.graph.output) > 5: + print(f" ... and {len(onnx_model.graph.output) - 5} more") + + except Exception as e: + print(f" ⚠ Could not verify: {e}") + + # Calculate sizes + data_files = list(output_dir.glob("model*.onnx*")) + total_size = sum(f.stat().st_size for f in data_files if f.exists()) + + # Save export info + export_info = { + "export_method": "torch.onnx.export with OnnxExportWrapper", + "shape_mode": "dynamic", + "precision": "fp16" if use_fp16 else "fp32", + "opset_version": opset_version, + "with_kv_cache": with_kv_cache, + "num_layers": num_layers, + "num_heads": num_heads, + "num_kv_heads": num_kv_heads, + "head_dim": head_dim, + "hidden_size": hidden_size, + "vocab_size": vocab_size, + "max_position_embeddings": max_position, + "model_variant": model_variant, + "model_type": model_type, + "input_names": input_names, + "output_names": output_names, + "dynamic_dims": { + "batch_size": "Variable batch size (1-64)", + "sequence_length": f"Current input sequence length ({MIN_SEQ_LEN}-4096, min=64 avoids MIGraphX hipHostRegister bug)", + "past_sequence_length": "Previous tokens in KV cache (0-131072)", + "total_sequence_length": "past_sequence_length + sequence_length", + }, + "kv_cache_info": { + "shape": f"(batch_size, {num_kv_heads}, sequence_length, {head_dim})", + "num_layers": num_layers, + "inputs_per_layer": 2, + "total_kv_inputs": 2 * num_layers, + } if with_kv_cache else None, + } + + with open(output_dir / "export_info.json", "w") as f: + json.dump(export_info, f, indent=2) + + # Clean up + del model, wrapper + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + print(f"\n{'='*60}") + print("✅ Export complete!") + print(f"{'='*60}") + print(f" Output directory: {output_dir}") + print(f" Total size: {total_size / (1024**3):.2f} GB") + print(f" position_ids: INCLUDED (enables full KV cache output)") + if with_kv_cache: + print(f" KV cache: {num_layers} layers × 2 (key+value)") + print(f" KV shape: (batch, {num_kv_heads}, seq_len, {head_dim})") + print(f"\n Dynamic dimensions:") + print(f" - batch_size: 1-64") + print(f" - sequence_length: {MIN_SEQ_LEN}-4096 (min={MIN_SEQ_LEN} avoids MIGraphX hipHostRegister bug)") + print(f" - past_sequence_length: 0-131072 (KV cache)") + print(f"{'='*60}") + + +if __name__ == '__main__': + main() diff --git a/models/py/fix_external_data.py b/models/py/fix_external_data.py new file mode 100644 index 0000000..b67a779 --- /dev/null +++ b/models/py/fix_external_data.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +""" +fix_external_data.py - Convert large ONNX model to use external data file + +Required for models > 2GB due to protobuf limits +""" + +import onnx +from onnx.external_data_helper import convert_model_to_external_data +from pathlib import Path +import os +import sys + +def main(): + model_file = Path(os.environ['MODEL_FILE']) + output_dir = model_file.parent + external_data_file = os.environ['EXTERNAL_DATA_FILE'] + file_size = int(os.environ['FILE_SIZE']) + + # For very large files (>2GB), we need special handling + if file_size > 2 * 1024 * 1024 * 1024: + print("Large model detected (>2GB). Using graph-only loading...") + print("This preserves external data references without loading weights into memory.") + + try: + # Load graph structure only (don't load external data into memory) + model = onnx.load(str(model_file), load_external_data=False) + + # Check if model already references external data + has_external_refs = False + for tensor in model.graph.initializer: + if tensor.HasField('data_location') and tensor.data_location == onnx.TensorProto.EXTERNAL: + has_external_refs = True + break + + if has_external_refs: + print("✅ Model already uses external data references.") + print(" External data file should contain the weights.") + + # Verify external data file exists + ext_path = output_dir / external_data_file + if ext_path.exists(): + ext_size = ext_path.stat().st_size + print(f" External data file: {ext_size / (1024**3):.2f} GB") + else: + print(f"⚠️ External data file not found: {ext_path}") + print(" Model may be corrupted or missing weight data.") + sys.exit(1) + else: + print("Model has embedded weights. Converting to external data format...") + + # Convert to external data + convert_model_to_external_data( + model, + all_tensors_to_one_file=True, + location=external_data_file, + size_threshold=1024, + convert_attribute=False + ) + + # Save the model with external data + print(f"Saving model with external data: {external_data_file}") + onnx.save_model( + model, + str(model_file), + save_as_external_data=True, + all_tensors_to_one_file=True, + location=external_data_file, + size_threshold=1024, + ) + print("✅ Done!") + + except Exception as e: + print(f"Error: {e}") + print("") + print("For models >2GB with embedded weights, try these alternatives:") + print("1. Re-export the model with external data from the start") + print("2. Use: python -m onnx.tools.update_inputs_outputs_dims") + sys.exit(1) + else: + print("Loading model (this may take a while for large models)...") + model = onnx.load(str(model_file), load_external_data=True) + + print(f"Saving with external data: {external_data_file}") + onnx.save_model( + model, + str(model_file), + save_as_external_data=True, + all_tensors_to_one_file=True, + location=external_data_file, + size_threshold=1024, + ) + + print("✅ Done!") + + +if __name__ == '__main__': + main() diff --git a/models/py/optimize_model.py b/models/py/optimize_model.py new file mode 100644 index 0000000..e20d6ce --- /dev/null +++ b/models/py/optimize_model.py @@ -0,0 +1,199 @@ +#!/usr/bin/env python3 +""" +optimize_model.py - Optimize ONNX model for ONNX Runtime inference + +This script optimizes ONNX models for ONNX Runtime execution (CPU or GPU EP). +It fuses attention patterns into efficient operators (MultiHeadAttention/GQA) +which MIGraphX can then accelerate with Flash Attention kernels. +""" + +import os +import sys +from pathlib import Path + + +def main(): + # Input parameters + input_file = os.environ['INPUT_FILE'] + output_file = os.environ['OUTPUT_FILE'] + model_type = os.environ['MODEL_TYPE'] + num_heads = int(os.environ['NUM_HEADS']) + hidden_size = int(os.environ['HIDDEN_SIZE']) + num_kv_heads = int(os.environ['NUM_KV_HEADS']) + opt_level = int(os.environ['OPT_LEVEL']) + skip_fp16 = os.environ['SKIP_FP16'] == "true" + use_gpu = os.environ['USE_GPU'] == "true" + attention_type = os.environ['ATTENTION_TYPE'] + + input_path = Path(input_file) + output_path = Path(output_file) + input_dir = input_path.parent + + # Check for external data files + external_data_files = list(input_dir.glob(f"{input_path.stem}*.data")) + \ + list(input_dir.glob(f"{input_path.stem}*_data")) + has_external_data = len(external_data_files) > 0 + + # Calculate total model size + total_size = input_path.stat().st_size + for ext_file in external_data_files: + total_size += ext_file.stat().st_size + total_size_gb = total_size / (1024**3) + + # Force external data for large models + use_external = has_external_data or total_size_gb > 1.5 + + print(f"Configuration:") + print(f" Model type: {model_type}") + print(f" Num heads: {num_heads}") + print(f" Num KV heads: {num_kv_heads}") + print(f" Hidden size: {hidden_size}") + print(f" Model size: {total_size_gb:.2f} GB") + print(f" External data: {use_external}") + print(f" Use GPU: {use_gpu}") + print(f" FP16: {not skip_fp16}") + print(f" Opt level: {opt_level}") + print(f" Attention type: {attention_type}") + print() + + try: + from onnxruntime.transformers import optimizer + from onnxruntime.transformers.fusion_options import FusionOptions, AttentionOpType + + # Create FusionOptions with attention fusion enabled + fusion_options = FusionOptions(model_type) + + # Enable attention fusion for MIGraphX Flash Attention + fusion_options.enable_attention = True + fusion_options.use_multi_head_attention = True + fusion_options.enable_rotary_embeddings = True # Important for LLaMA RoPE + fusion_options.enable_shape_inference = True + + # Set attention operator type based on model architecture + if attention_type == "auto": + # Auto-detect: Use GQA if num_kv_heads < num_heads (LLaMA 3.x uses GQA) + if num_kv_heads < num_heads: + print(f" Detected GQA (KV heads {num_kv_heads} < Q heads {num_heads})") + fusion_options.attention_op_type = AttentionOpType.GroupQueryAttention + else: + print(f" Using MultiHeadAttention (standard MHA)") + fusion_options.attention_op_type = AttentionOpType.MultiHeadAttention + elif attention_type == "GroupQueryAttention": + fusion_options.attention_op_type = AttentionOpType.GroupQueryAttention + elif attention_type == "MultiHeadAttention": + fusion_options.attention_op_type = AttentionOpType.MultiHeadAttention + elif attention_type == "PagedAttention": + fusion_options.attention_op_type = AttentionOpType.PagedAttention + else: + fusion_options.attention_op_type = AttentionOpType.Attention + + print(f" Attention op: {fusion_options.attention_op_type}") + print() + + # Run optimizer + print("Optimizing model...") + print(" (This may take several minutes for large models)") + optimized_model = optimizer.optimize_model( + input=input_file, + model_type=model_type, + num_heads=num_heads, + hidden_size=hidden_size, + optimization_options=fusion_options, + opt_level=opt_level, + use_gpu=use_gpu, + only_onnxruntime=True, # Use only ONNX Runtime optimizations + ) + + # Convert to FP16 if enabled (skip symbolic inference for large models) + if not skip_fp16: + print("Converting to FP16...") + try: + optimized_model.convert_float_to_float16( + keep_io_types=True, # Keep input/output as FP32 for compatibility + use_symbolic_shape_infer=(total_size_gb < 2.0), # Skip for large models + ) + except Exception as e: + print(f" Warning: FP16 conversion had issues: {e}") + print(" Continuing with partial FP16 conversion...") + + # Save model with external data for large models + print(f"Saving to {output_file}...") + if use_external: + print(" Using external data format (model > 2GB)") + # Create external data filename + external_data_name = output_path.stem + ".onnx.data" + optimized_model.save_model_to_file( + str(output_file), + use_external_data_format=True, + all_tensors_to_one_file=True, + location=external_data_name, + size_threshold=1024, # Externalize tensors > 1KB + convert_attribute=False, + ) + else: + optimized_model.save_model_to_file(str(output_file)) + + # Report fusion results + print() + print("=" * 50) + print("Optimization Results") + print("=" * 50) + + # Count fused operators + import onnx + model = onnx.load(output_file, load_external_data=False) + op_counts = {} + for node in model.graph.node: + op_counts[node.op_type] = op_counts.get(node.op_type, 0) + 1 + + # Report attention-related ops + attention_ops = ['Attention', 'MultiHeadAttention', 'GroupQueryAttention', 'PagedAttention'] + found_attention = False + for op in attention_ops: + if op in op_counts: + print(f" ✅ {op}: {op_counts[op]} (FUSED - Flash Attention compatible)") + found_attention = True + + if not found_attention: + # Check for unfused attention pattern + unfused_ops = ['MatMul', 'Softmax'] + if all(op in op_counts for op in unfused_ops): + print(f" ⚠️ No fused attention operators found") + print(f" MatMul: {op_counts.get('MatMul', 0)}, Softmax: {op_counts.get('Softmax', 0)}") + print(f" Attention patterns may not have been fused") + + # Report total ops + total_ops = sum(op_counts.values()) + print(f"\n Total operators: {total_ops}") + + # Top operators + sorted_ops = sorted(op_counts.items(), key=lambda x: -x[1])[:10] + print(f" Top operators:") + for op, count in sorted_ops: + print(f" {op}: {count}") + + # Calculate output size + print() + out_path = Path(output_file) + out_size = out_path.stat().st_size + ext_data_path = out_path.parent / (out_path.stem + ".onnx.data") + if ext_data_path.exists(): + ext_size = ext_data_path.stat().st_size + print(f" Output model: {out_size / (1024**2):.1f} MB") + print(f" External data: {ext_size / (1024**3):.2f} GB") + print(f" Total size: {(out_size + ext_size) / (1024**3):.2f} GB") + else: + print(f" Output size: {out_size / (1024**3):.2f} GB") + + print() + print("✅ Optimization complete!") + + except Exception as e: + print(f"❌ Optimization failed: {e}") + import traceback + traceback.print_exc() + sys.exit(1) + + +if __name__ == '__main__': + main() diff --git a/models/py/quantize_int4.py b/models/py/quantize_int4.py new file mode 100644 index 0000000..4d7653c --- /dev/null +++ b/models/py/quantize_int4.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python3 +""" +quantize_int4.py - Quantize ONNX model to INT4 (4-bit weight quantization) +""" + +import sys +import os +from pathlib import Path + + +def main(): + input_file = os.environ['INPUT_FILE'] + output_file = os.environ['OUTPUT_FILE'] + block_size = int(os.environ['BLOCK_SIZE']) + has_external = os.environ['HAS_EXTERNAL'] == "true" + + input_path = Path(input_file) + output_path = Path(output_file) + + # Check for INT4 support - use matmul_nbits_quantizer (correct module name) + try: + from onnxruntime.quantization import matmul_nbits_quantizer + from onnxruntime.quantization.matmul_nbits_quantizer import MatMulNBitsQuantizer, DefaultWeightOnlyQuantConfig + print("✓ Found MatMulNBitsQuantizer") + except ImportError as e: + print(f"❌ INT4 quantization not available: {e}") + print("") + print(" Requires ONNX Runtime 1.20+") + print(" pip install onnxruntime>=1.20") + print("") + print(" Or use INT8 quantization instead:") + print(" ./05_quantize_int8.sh ") + print("") + sys.exit(1) + + # Perform INT4 quantization + print("") + print("Performing INT4 quantization...") + + print("Step 1: Loading model...") + import onnx + try: + model = onnx.load(str(input_path), load_external_data=True) + print(f" Loaded model with {len(model.graph.node)} nodes") + except Exception as e: + print(f" Error loading model: {e}") + sys.exit(1) + + print("Step 2: Checking model compatibility...") + + # Check if model has been optimized with FP16 Cast nodes inserted + init_names = {init.name for init in model.graph.initializer} + matmuls = [n for n in model.graph.node if n.op_type == 'MatMul'] + matmuls_with_const_weight = 0 + has_precision_cast = False + + for mm in matmuls: + if len(mm.input) >= 2: + weight_input = mm.input[1] + if weight_input in init_names: + matmuls_with_const_weight += 1 + if 'InsertedPrecisionFreeCast' in weight_input: + has_precision_cast = True + + pct_quantizable = (matmuls_with_const_weight / len(matmuls) * 100) if matmuls else 0 + print(f" MatMul nodes: {len(matmuls)}") + print(f" Quantizable: {matmuls_with_const_weight} ({pct_quantizable:.0f}%)") + + if has_precision_cast or pct_quantizable < 50: + print("") + print(" ⚠ WARNING: This model appears to be FP16-optimized.") + print(" The optimizer inserted Cast nodes that block weight quantization.") + print("") + print(" For INT4 quantization, use the base model BEFORE optimization:") + print(" ./05_quantize_int4.sh ./path/to/model.onnx ./output_int4.onnx") + print("") + print(" Then optimize the INT4 model WITHOUT --float16:") + print(" python3 -m onnxruntime.transformers.optimizer ...") + print("") + if pct_quantizable == 0: + print(" ❌ No quantizable MatMul nodes found. Exiting.") + sys.exit(1) + print(" Continuing with partial quantization...") + print("") + + print(f"Step 3: Creating INT4 quantizer (block_size={block_size})...") + + from onnxruntime.quantization import QuantFormat + + quantizer = MatMulNBitsQuantizer( + model, + block_size=block_size, + is_symmetric=True, + accuracy_level=4, + op_types_to_quantize=("MatMul", "Gather"), # Explicitly quantize MatMul and Gather ops + quant_format=QuantFormat.QOperator, + ) + + print("Step 4: Running quantization...") + print(" This may take several minutes for large models...") + quantizer.process() + + print("Step 5: Saving quantized model...") + use_external_out = has_external or (len(model.graph.initializer) > 100) + quantizer.model.save_model_to_file(str(output_path), use_external_data_format=use_external_out) + + # Calculate and report sizes + print("") + print("Calculating size reduction...") + + def get_model_size(path): + """Get total model size including external data.""" + p = Path(path) + size = p.stat().st_size if p.exists() else 0 + for ext in ['.onnx.data', '.onnx_data', '_data']: + ext_file = p.parent / (p.stem + ext) + if ext_file.exists(): + size += ext_file.stat().st_size + break + return size + + input_size = get_model_size(input_path) + output_size = get_model_size(output_path) + + input_gb = input_size / (1024**3) + output_gb = output_size / (1024**3) + reduction = (1 - output_size / input_size) * 100 if input_size > 0 else 0 + + print(f"") + print(f"✅ INT4 Quantization complete!") + print(f" Input size: {input_gb:.2f} GB") + print(f" Output size: {output_gb:.2f} GB") + print(f" Reduction: {reduction:.1f}%") + print(f" Expected: ~75% reduction for INT4") + + +if __name__ == '__main__': + main() diff --git a/models/py/quantize_int8.py b/models/py/quantize_int8.py new file mode 100644 index 0000000..fa8ab3c --- /dev/null +++ b/models/py/quantize_int8.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 +""" +quantize_int8.py - Quantize ONNX model to INT8 (dynamic quantization) +""" + +import onnx +import os +import sys +from onnxruntime.quantization import quantize_dynamic, QuantType +from onnxruntime.quantization.shape_inference import quant_pre_process +from pathlib import Path + + +def main(): + input_file = os.environ['INPUT_FILE'] + output_file = os.environ['OUTPUT_FILE'] + input_path = Path(input_file) + output_path = Path(output_file) + + print("Quantizing model to INT8...") + print("This may take a while for large models...") + + # Check for external data + external_data_file = input_path.parent / (input_path.stem + ".onnx.data") + external_data_file_alt = input_path.parent / (input_path.stem + ".onnx_data") + has_external_data = external_data_file.exists() or external_data_file_alt.exists() + + if has_external_data: + print("Model has external data, using model path for quantization...") + + # Try preprocessing first + try: + print("Step 1: Preprocessing model...") + preprocessed_file = str(input_path.parent / (input_path.stem + "_preprocessed.onnx")) + + quant_pre_process( + input_model_path=input_file, + output_model_path=preprocessed_file, + skip_symbolic_shape=True, # Skip if symbolic shape inference fails + ) + quantize_input = preprocessed_file + print(" Preprocessing complete") + except Exception as e: + print(f" Preprocessing skipped: {e}") + quantize_input = input_file + + # Perform quantization + try: + print("Step 2: Quantizing to INT8...") + quantize_dynamic( + model_input=quantize_input, + model_output=output_file, + weight_type=QuantType.QInt8, + extra_options={ + "MatMulConstBOnly": True, + }, + use_external_data_format=has_external_data, + ) + except Exception as e: + print(f"Dynamic quantization failed: {e}") + print("Trying with per-channel quantization disabled...") + try: + quantize_dynamic( + model_input=quantize_input, + model_output=output_file, + weight_type=QuantType.QInt8, + per_channel=False, + extra_options={ + "MatMulConstBOnly": True, + }, + use_external_data_format=has_external_data, + ) + except Exception as e2: + print(f"Quantization failed: {e2}") + print("\n❌ INT8 quantization is not supported for this model architecture.") + print(" Consider using FP16 instead (06_convert_fp16.sh)") + sys.exit(1) + + # Cleanup preprocessed file if it exists + preprocessed_path = input_path.parent / (input_path.stem + "_preprocessed.onnx") + if preprocessed_path.exists(): + os.remove(preprocessed_path) + preprocessed_data = preprocessed_path.parent / (preprocessed_path.stem + ".onnx.data") + if preprocessed_data.exists(): + os.remove(preprocessed_data) + + # Calculate sizes + input_size = input_path.stat().st_size + if has_external_data: + if external_data_file.exists(): + input_size += external_data_file.stat().st_size + elif external_data_file_alt.exists(): + input_size += external_data_file_alt.stat().st_size + + output_size = output_path.stat().st_size + output_data = output_path.parent / (output_path.stem + ".onnx.data") + if output_data.exists(): + output_size += output_data.stat().st_size + + input_size_gb = input_size / (1024**3) + output_size_gb = output_size / (1024**3) + reduction = (1 - output_size / input_size) * 100 if input_size > 0 else 0 + + print(f"\n✅ Quantization complete!") + print(f" Input size: {input_size_gb:.2f} GB") + print(f" Output size: {output_size_gb:.2f} GB") + print(f" Reduction: {reduction:.1f}%") + + +if __name__ == '__main__': + main() diff --git a/models/py/run_inference_test.py b/models/py/run_inference_test.py new file mode 100644 index 0000000..22007f3 --- /dev/null +++ b/models/py/run_inference_test.py @@ -0,0 +1,699 @@ +#!/usr/bin/env python3 +""" +run_inference_test.py - Test inference with ONNX Runtime + +Runs text generation to verify the model works correctly. +Uses autoregressive generation with growing KV cache. +""" + +import os +import sys +import onnxruntime as ort +import numpy as np +from pathlib import Path +import time +import json +import subprocess +from transformers import AutoTokenizer + +# Get environment variables +model_dir = Path(os.environ['MODEL_DIR']) +provider = os.environ['PROVIDER'] +prompt = os.environ.get('PROMPT', 'What is 2+2?') +seq_length = int(os.environ.get('SEQ_LENGTH', '256')) # Bucket size +# Max output = bucket size (KV cache = 2*bucket covers input + output) +max_tokens = seq_length +max_kv_len = seq_length # Maximum KV cache length +temperature = float(os.environ.get('TEMPERATURE', '0.0')) +verbose = os.environ.get('VERBOSE', 'false') == 'true' +no_cache = os.environ.get('NO_CACHE', 'false') == 'true' +exhaustive = os.environ.get('EXHAUSTIVE', 'false') == 'true' +offload_copy = os.environ.get('OFFLOAD_COPY', 'true') == 'true' +migraphx_fp16 = os.environ.get('MIGRAPHX_FP16', '0') == '1' +migraphx_save = os.environ.get('MIGRAPHX_SAVE', '1') == '1' +gpu_target = os.environ.get('GPU_TARGET', '') + +# Configure logging +log_level = 0 if verbose else 2 +ort.set_default_logger_severity(log_level) + +if gpu_target: + print(f"GPU target: {gpu_target}") + +# Load export info if available +export_info = {} +export_info_path = model_dir / "export_info.json" +if export_info_path.exists(): + with open(export_info_path) as f: + export_info = json.load(f) + print(f"Export info: {export_info.get('shape_mode', 'unknown')} shapes") + if export_info.get('model_variant'): + print(f"Model: {export_info['model_variant']}") + +# Find model file +model_file = None +for candidate in ["model.onnx", "model_optimized.onnx"]: + if (model_dir / candidate).exists(): + model_file = model_dir / candidate + break + +if model_file is None: + onnx_files = list(model_dir.glob("*.onnx")) + if onnx_files: + model_file = onnx_files[0] + +if model_file is None: + print(f"Error: No .onnx file found in {model_dir}") + sys.exit(1) + +print(f"\nModel file: {model_file}") +print(f"Available providers: {ort.get_available_providers()}") + +# Check GPU memory before loading +try: + result = subprocess.run(['rocm-smi', '--showmeminfo', 'vram'], + capture_output=True, text=True, timeout=5) + if result.returncode == 0: + print("\nGPU Memory before model load:") + for line in result.stdout.strip().split('\n'): + if 'Used' in line or 'GPU' in line: + print(f" {line.strip()}") +except: + pass + +# Enable verbose logging for debugging +if verbose: + # ORT verbose logging + os.environ['ORT_LOG_LEVEL'] = 'VERBOSE' + # MIGraphX verbose logging + os.environ['MIGRAPHX_TRACE_COMPILE'] = '1' + os.environ['MIGRAPHX_TRACE_EVAL'] = '1' + os.environ['MIGRAPHX_TRACE_GPU_ALLOC'] = '1' + # HIP verbose + os.environ['AMD_LOG_LEVEL'] = '4' + os.environ['HIP_TRACE_API'] = '1' + +# Configure session options +sess_options = ort.SessionOptions() +sess_options.log_severity_level = 0 if verbose else log_level # 0=VERBOSE +sess_options.log_verbosity_level = 10 if verbose else 0 + +# CRITICAL: Disable graph optimizations to avoid hipHostRegister issues +# MIGraphX's optimization may be inserting problematic copy operations +sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL +print("Graph optimizations: DISABLED (workaround for MIGraphX hipHostRegister bug)") + +# Enable profiling for detailed timing +if verbose: + sess_options.enable_profiling = True + print("Verbose logging enabled (ORT + MIGraphX + HIP)") + +# Configure provider options +if provider == "MIGraphXExecutionProvider": + cache_path = str(model_dir / "migraphx_cache") + + # Define minimum sequence length to avoid hipHostRegister bug + MIN_SEQ_FOR_MIGRAPHX = 256 # 2048 bytes minimum + + # MIGraphX options MUST be strings, not booleans/integers + # ALWAYS enable offload_copy to fix hipHostRegister failures on small buffers + # (attention_mask at 4KB fails GPU registration without this) + # MIGraphX provider options - trying to work around hipHostRegister bug + provider_options = { + 'device_id': '0', + 'migraphx_fp16_enable': '1' if migraphx_fp16 else '0', + 'migraphx_exhaustive_tune': '0', # Disable exhaustive tuning + 'migraphx_offload_copy': '1', # Should handle small buffers + # Note: migraphx_enable_gpu is not a valid option, removed + } + + print(f"\nAttempting workaround for MIGraphX hipHostRegister bug...") + print(f" - Graph optimizations disabled") + print(f" - offload_copy enabled") + print(f" - Testing with min buffer size: {MIN_SEQ_FOR_MIGRAPHX * 8} bytes") + + if not no_cache: + os.makedirs(cache_path, exist_ok=True) + provider_options['migraphx_model_cache_dir'] = cache_path + print(f"MIGraphX cache: {cache_path}") + + print(f"\nMIGraphX options:") + for k, v in provider_options.items(): + print(f" {k}: {v}") + + providers = [provider] + provider_options_list = [provider_options] + +elif provider == "ROCMExecutionProvider": + providers = [provider] + provider_options_list = [{ + 'device_id': 0, + 'tunable_op_enable': True, + 'tunable_op_tuning_enable': False, + }] +elif provider == "CUDAExecutionProvider": + providers = [provider] + provider_options_list = [{'device_id': 0}] +else: + providers = [provider] + provider_options_list = [{}] + +# Check if we should use IOBinding to avoid hipHostRegister +use_io_binding = provider == "MIGraphXExecutionProvider" +if use_io_binding: + print("\nUsing IOBinding to pre-allocate inputs on GPU (avoids hipHostRegister)") + +# Create session +print(f"\nCreating session with {provider}...") +print(" (First run may take time for MIGraphX compilation)") + +start_load = time.time() + +try: + print(f"\nAttempting to create session with providers: {providers}") + print(f"Provider options: {provider_options_list}") + + session = ort.InferenceSession( + str(model_file), + sess_options, + providers=providers, + provider_options=provider_options_list + ) + load_time = time.time() - start_load + print(f"Session created in {load_time:.2f}s") + +except Exception as e: + print(f"❌ {provider} failed: {e}") + print(f"\n For MIGraphX issues, try:") + print(f" 1. Check GPU target matches: rocminfo | grep gfx") + print(f" 2. Try CPU provider: ./09_run_inference_test.sh {model_dir} CPUExecutionProvider") + print(f"\n Full error:") + import traceback + traceback.print_exc() + raise + +# Verify which provider is actually being used +actual_providers = session.get_providers() +print(f"Session providers: {actual_providers}") + +if provider != "CPUExecutionProvider" and actual_providers == ['CPUExecutionProvider']: + print(f"⚠️ WARNING: Requested {provider} but fell back to CPU!") + print(" This may indicate the model has unsupported operators.") +else: + print(f"✅ Running on: {actual_providers[0]}") + +# Check GPU memory after loading +if provider != "CPUExecutionProvider": + try: + result = subprocess.run(['rocm-smi', '--showmeminfo', 'vram'], + capture_output=True, text=True, timeout=5) + if result.returncode == 0: + print("\nGPU Memory after model load:") + for line in result.stdout.strip().split('\n'): + if 'Used' in line or 'GPU' in line: + print(f" {line.strip()}") + except: + pass + +# Get model input/output info +model_inputs = session.get_inputs() +model_outputs = session.get_outputs() + +print(f"\nModel inputs ({len(model_inputs)}):") +has_kv_cache = False +num_layers = export_info.get('num_layers', 32) +num_kv_heads = export_info.get('num_kv_heads', 8) +head_dim = export_info.get('head_dim', 128) + +# Check for expected inputs +expected_inputs = ['input_ids', 'attention_mask', 'past_seq_len'] +actual_input_names = [inp.name for inp in model_inputs] +has_past_seq_len = 'past_seq_len' in actual_input_names +has_position_ids = 'position_ids' in actual_input_names + +print(f" Expected new signature (past_seq_len): {has_past_seq_len}") +print(f" Old signature (position_ids): {has_position_ids}") + +if not has_past_seq_len and has_position_ids: + print(f"\n⚠️ WARNING: Model still has old signature!") + print(f" You need to RE-EXPORT the model with the updated export_model.py") + print(f" Current model was exported before the past_seq_len changes.") + +for inp in model_inputs[:5]: + shape_str = str(inp.shape) + is_dynamic = any(isinstance(d, str) or d is None or d == -1 for d in inp.shape) + print(f" {inp.name}: {shape_str} {'[dynamic]' if is_dynamic else '[fixed]'}") + if 'past_key' in inp.name or 'cache' in inp.name: + has_kv_cache = True + +if len(model_inputs) > 5: + print(f" ... and {len(model_inputs) - 5} more") + +print(f"\nModel outputs ({len(model_outputs)}):") +for out in model_outputs[:3]: + print(f" {out.name}: {out.shape}") +if len(model_outputs) > 3: + print(f" ... and {len(model_outputs) - 3} more") + +# Load tokenizer +print("\nLoading tokenizer...") +tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) +if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + +# Detect model type from tokenizer/config +model_type = "unknown" +try: + from transformers import AutoConfig + config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True) + model_type = getattr(config, 'model_type', 'unknown') +except: + pass + +# Fallback detection from tokenizer +if model_type == "unknown": + if hasattr(tokenizer, 'name_or_path'): + name_lower = tokenizer.name_or_path.lower() + if 'llama' in name_lower: + model_type = 'llama' + elif 'mistral' in name_lower: + model_type = 'mistral' + elif 'qwen' in name_lower: + model_type = 'qwen2' + elif 'phi' in name_lower: + model_type = 'phi3' + +print(f"Detected model type: {model_type}") + +# Detect model dtype +model_dtype = np.float16 # Default for modern models +for inp in model_inputs: + if "float16" in str(inp.type).lower(): + model_dtype = np.float16 + break + elif "float32" in str(inp.type).lower(): + model_dtype = np.float32 +print(f"Model dtype: {model_dtype}") + +# Format prompt using chat template +print(f"\n{'='*60}") +print("USER PROMPT:") +print(f"{'='*60}") +print(prompt) +print(f"{'='*60}") + +# Apply chat template if available +messages = [{"role": "user", "content": prompt}] +formatted_prompt = None + +if hasattr(tokenizer, 'apply_chat_template') and tokenizer.chat_template is not None: + try: + # Use tokenizer's built-in chat template + formatted_prompt = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + print(f"\nUsing tokenizer chat template") + except Exception as e: + print(f"Chat template failed: {e}, using raw prompt") + +# Fallback: manual templates for common models +if formatted_prompt is None: + if model_type in ['llama', 'llama3']: + # Llama 3.x format + formatted_prompt = ( + f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" + f"{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + ) + print(f"\nUsing Llama 3 chat format") + elif model_type == 'mistral': + # Mistral format + formatted_prompt = f"[INST] {prompt} [/INST]" + print(f"\nUsing Mistral chat format") + elif model_type == 'qwen2': + # Qwen2 format + formatted_prompt = ( + f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" + ) + print(f"\nUsing Qwen2 chat format") + elif model_type == 'phi3': + # Phi-3 format + formatted_prompt = f"<|user|>\n{prompt}<|end|>\n<|assistant|>\n" + print(f"\nUsing Phi-3 chat format") + else: + # Generic fallback + formatted_prompt = prompt + print(f"\nUsing raw prompt (no chat template)") + +print(f"\nFORMATTED PROMPT:") +print("-" * 60) +print(formatted_prompt[:500] + "..." if len(formatted_prompt) > 500 else formatted_prompt) +print("-" * 60) + +# Tokenize formatted prompt +inputs = tokenizer(formatted_prompt, return_tensors="np", add_special_tokens=False) +input_ids = inputs["input_ids"].astype(np.int64) +raw_prompt_len = input_ids.shape[1] +print(f"Formatted prompt tokens: {raw_prompt_len}") + +# Truncate if prompt exceeds max context +if seq_length > 0 and raw_prompt_len > seq_length: + print(f"WARNING: Prompt ({raw_prompt_len}) exceeds max context ({seq_length}), truncating") + input_ids = input_ids[:, -seq_length:] # Keep last seq_length tokens + raw_prompt_len = input_ids.shape[1] + +prompt_len = raw_prompt_len +print(f"Prompt length: {prompt_len}") + +# Sampling function +def sample_token(logits, temperature=0.0): + """Sample next token from logits.""" + if temperature <= 0: + # Greedy + return np.argmax(logits) + else: + # Temperature sampling + logits = logits / temperature + exp_logits = np.exp(logits - np.max(logits)) + probs = exp_logits / np.sum(exp_logits) + return np.random.choice(len(probs), p=probs) + +# ============================================================ +# AUTOREGRESSIVE GENERATION +# ============================================================ +# Use larger batch sizes for prefill to avoid reshape errors with MIGraphX. +# After export fix: attention_mask must have same sequence dimension as input_ids. +# We use a static shape for all operations to avoid recompilation. +# +# filled_kv tracks how many positions contain valid data (0 to KV_LEN). + +print(f"\nGenerating up to {max_tokens} tokens...") +print("-" * 60) + +generated_ids = input_ids[0].tolist() +eos_token_id = tokenizer.eos_token_id + +# Use a larger sequence length for both prefill and decode to avoid reshape errors +# This allows batched prefill processing while maintaining consistent shapes +# CRITICAL: MIGraphX hipHostRegister bug - HIP memory pool minimum is ~2KB +# Using 256 elements (2048 bytes) to avoid the bug +# Note: MIN_SEQ_FOR_MIGRAPHX is defined earlier if using MIGraphX +if 'MIN_SEQ_FOR_MIGRAPHX' not in locals(): + MIN_SEQ_FOR_MIGRAPHX = 256 # Default if not using MIGraphX +PREFILL_SEQ_LEN = max(MIN_SEQ_FOR_MIGRAPHX, min(256, seq_length)) +DECODE_SEQ_LEN = PREFILL_SEQ_LEN # Use same shape for decode (not optimal but avoids recompile) +KV_LEN = seq_length # e.g., 256 - KV cache size + +print(f"Static shapes: prefill_seq={PREFILL_SEQ_LEN}, decode_seq={DECODE_SEQ_LEN}, kv={KV_LEN}") +print(f"Note: decode uses same seq length as prefill to avoid MIGraphX recompilation") + +# Pre-allocate buffers with static shapes +# For prefill: process multiple tokens at once +# For decode: use same shape but only fill first position (inefficient but avoids recompile) +# Note: position_ids is computed internally in ONNX graph from past_seq_len scalar +prefill_input_ids = np.zeros((1, PREFILL_SEQ_LEN), dtype=np.int64) +prefill_attention_mask = np.zeros((1, PREFILL_SEQ_LEN), dtype=np.int64) + +decode_input_ids = np.zeros((1, DECODE_SEQ_LEN), dtype=np.int64) +decode_attention_mask = np.zeros((1, DECODE_SEQ_LEN), dtype=np.int64) + +print(f"Buffers: prefill={prefill_input_ids.shape}, decode={decode_input_ids.shape}") + +# Fixed-size KV cache buffer +kv_cache = {} +for layer_idx in range(num_layers): + kv_cache[layer_idx] = { + 'key': np.zeros((1, num_kv_heads, KV_LEN, head_dim), dtype=model_dtype), + 'value': np.zeros((1, num_kv_heads, KV_LEN, head_dim), dtype=model_dtype), + } + +print(f"KV cache allocated: {num_layers} layers, shape per layer: {kv_cache[0]['key'].shape}") + +# Track how many positions are filled (valid data in the static buffer) +filled_kv = 0 # 0 to KV_LEN + +# Timing +total_start = time.time() +decode_times = [] +new_token_ids = [] +prompt_tokens = generated_ids.copy() + +def run_batch_prefill(tokens, start_position, kv_cache, filled_kv): + """ + Run inference for a batch of tokens during prefill. + + Args: + tokens: List of token IDs (up to PREFILL_SEQ_LEN) + start_position: Starting position index + kv_cache: KV cache dict (will be updated) + filled_kv: Current filled positions in KV cache + + Returns: + logits (for last token), kv_cache, new_filled_kv + """ + batch_size = len(tokens) + assert batch_size <= PREFILL_SEQ_LEN, f"Batch {batch_size} exceeds {PREFILL_SEQ_LEN}" + + # Fill buffers + prefill_input_ids.fill(0) + prefill_attention_mask.fill(0) + + for i, token_id in enumerate(tokens): + prefill_input_ids[0, i] = token_id + prefill_attention_mask[0, i] = 1 # Mark valid positions + + # Create past_seq_len padded tensor (256 elements = 2048 bytes to avoid hipHostRegister failures) + # Only first element is used; rest is padding + past_seq_len_padded = np.zeros(256, dtype=np.int64) + past_seq_len_padded[0] = start_position + + # Build feed dict + feed_dict = { + "input_ids": prefill_input_ids, + "attention_mask": prefill_attention_mask, + "past_seq_len": past_seq_len_padded, + } + + for inp in model_inputs: + if "past_key_values" in inp.name: + layer_idx = int(inp.name.split('.')[1]) + if ".key" in inp.name: + feed_dict[inp.name] = kv_cache[layer_idx]['key'] + elif ".value" in inp.name: + feed_dict[inp.name] = kv_cache[layer_idx]['value'] + + # Run inference + outputs = session.run(None, feed_dict) + + # Extract and store KV cache updates + output_idx = 1 + for layer_idx in range(num_layers): + out_key = outputs[output_idx] + out_value = outputs[output_idx + 1] + + # Copy new KV entries (only valid positions) + for i in range(batch_size): + if filled_kv + i < KV_LEN: + kv_cache[layer_idx]['key'][:, :, filled_kv + i, :] = out_key[:, :, i, :] + kv_cache[layer_idx]['value'][:, :, filled_kv + i, :] = out_value[:, :, i, :] + + output_idx += 2 + + new_filled = min(filled_kv + batch_size, KV_LEN) + + # Return logits for last token + logits = outputs[0] + return logits[0, batch_size - 1, :], kv_cache, new_filled + + +def run_single_decode(token_id, position, kv_cache, filled_kv): + """ + Run inference for single token during decode phase. + Uses same shape as prefill (DECODE_SEQ_LEN) but only fills first position. + + Args: + token_id: Token ID to process + position: Position index + kv_cache: KV cache dict (will be updated) + filled_kv: Current filled positions in KV cache + + Returns: + logits, kv_cache, new_filled_kv + """ + # Fill buffers (only first position used) + decode_input_ids.fill(0) + decode_attention_mask.fill(0) + + decode_input_ids[0, 0] = token_id + decode_attention_mask[0, 0] = 1 # Only current token is valid + + # Create past_seq_len padded tensor (256 elements = 2048 bytes to avoid hipHostRegister failures) + # Only first element is used; rest is padding + past_seq_len_padded = np.zeros(256, dtype=np.int64) + past_seq_len_padded[0] = position + + # Build feed dict + feed_dict = { + "input_ids": decode_input_ids, + "attention_mask": decode_attention_mask, + "past_seq_len": past_seq_len_padded, + } + + for inp in model_inputs: + if "past_key_values" in inp.name: + layer_idx = int(inp.name.split('.')[1]) + if ".key" in inp.name: + feed_dict[inp.name] = kv_cache[layer_idx]['key'] + elif ".value" in inp.name: + feed_dict[inp.name] = kv_cache[layer_idx]['value'] + + # Run inference + outputs = session.run(None, feed_dict) + + # Extract and store KV cache update + output_idx = 1 + for layer_idx in range(num_layers): + out_key = outputs[output_idx] + out_value = outputs[output_idx + 1] + + # Copy new KV entry (only first position is valid) + if filled_kv < KV_LEN: + kv_cache[layer_idx]['key'][:, :, filled_kv, :] = out_key[:, :, 0, :] + kv_cache[layer_idx]['value'][:, :, filled_kv, :] = out_value[:, :, 0, :] + + output_idx += 2 + + new_filled = min(filled_kv + 1, KV_LEN) + + # Return logits for the token + logits = outputs[0] + return logits[0, 0, :], kv_cache, new_filled + + +# ========== PREFILL (BATCHED) ========== +# Process prompt in batches for faster prefill +prefill_start = time.time() + +n_prompt = len(prompt_tokens) +print(f"[Prefill: {n_prompt} tokens in batches of {PREFILL_SEQ_LEN}]") + +position = 0 +for i in range(0, n_prompt, PREFILL_SEQ_LEN): + batch_tokens = prompt_tokens[i:i + PREFILL_SEQ_LEN] + logits, kv_cache, filled_kv = run_batch_prefill(batch_tokens, position, kv_cache, filled_kv) + position += len(batch_tokens) + + if (i + len(batch_tokens)) % (PREFILL_SEQ_LEN * 2) == 0 or i + len(batch_tokens) >= n_prompt: + print(f" [Prefill: {i + len(batch_tokens)}/{n_prompt}, KV: {filled_kv}/{KV_LEN}]", end='\r') + +print() # Newline +prefill_time = time.time() - prefill_start +print(f"[Prefill complete: {len(prompt_tokens)} tokens in {prefill_time*1000:.0f}ms") +print(f" Throughput: {len(prompt_tokens)/prefill_time:.1f} tok/s]") +print(f"[KV filled: {filled_kv}/{KV_LEN}]") +print("\nASSISTANT:") +print("-" * 60) + +# Sample first token from prefill logits +next_token_id = sample_token(logits, temperature) +generated_ids.append(int(next_token_id)) +new_token_ids.append(int(next_token_id)) + +# Print first token +token_str = tokenizer.decode([next_token_id], skip_special_tokens=True) +sys.stdout.write(token_str) +sys.stdout.flush() + +# Track position for decode +current_position = len(prompt_tokens) + +# ========== DECODE ========== +# Each decode step processes one token (uses same shape as prefill for consistency) +for step in range(max_tokens - 1): # -1 because we already generated 1 + # Check stopping conditions + if next_token_id == eos_token_id: + break + if tokenizer.decode([next_token_id]) in ['<|eot_id|>', '<|end|>', '<|im_end|>', '']: + break + + # Check if KV buffer is full + if filled_kv >= KV_LEN: + print(f"\n[KV buffer full at {KV_LEN}, stopping]") + break + + step_start = time.time() + + # Process single token (uses DECODE_SEQ_LEN shape) + logits, kv_cache, filled_kv = run_single_decode( + next_token_id, current_position, kv_cache, filled_kv + ) + + decode_times.append(time.time() - step_start) + current_position += 1 + + # Sample next token + next_token_id = sample_token(logits, temperature) + generated_ids.append(int(next_token_id)) + new_token_ids.append(int(next_token_id)) + + # Print token + token_str = tokenizer.decode([next_token_id], skip_special_tokens=True) + sys.stdout.write(token_str) + sys.stdout.flush() + +print() # New line + +total_time = time.time() - total_start +print() +print("-" * 60) + +# ============================================================ +# RESULTS +# ============================================================ +# Generated tokens count excludes padding +generated_tokens = len(new_token_ids) + +# Decode only the assistant's response (new tokens) +assistant_response = tokenizer.decode(new_token_ids, skip_special_tokens=True).strip() + +print(f"\n{'='*60}") +print("ASSISTANT RESPONSE (clean):") +print(f"{'='*60}") +print(assistant_response) +print(f"{'='*60}") + +# Performance stats +print(f"\n{'='*60}") +print("PERFORMANCE SUMMARY") +print(f"{'='*60}") +print(f"Provider: {actual_providers[0]}") +print(f"Model type: {model_type}") +print(f"Shapes: prefill_seq={PREFILL_SEQ_LEN}, decode_seq={DECODE_SEQ_LEN}, kv={KV_LEN}") +print(f"KV filled: {filled_kv}/{KV_LEN}") +print(f"Prompt tokens: {raw_prompt_len}") +print(f"Generated tokens: {generated_tokens}") +print(f"Total context: {raw_prompt_len + generated_tokens}") +print(f"Temperature: {temperature}") +print(f"-" * 60) +print(f"Model load time: {load_time*1000:.0f} ms") +if prefill_time > 0: + print(f"Prefill time: {prefill_time*1000:.0f} ms ({raw_prompt_len/prefill_time:.1f} tok/s)") +if decode_times: + avg_decode = np.mean(decode_times) * 1000 + print(f"Avg decode time: {avg_decode:.2f} ms/token") + print(f"Decode throughput: {1000/avg_decode:.1f} tokens/sec") +if total_time > 0 and generated_tokens > 0: + print(f"Total gen time: {total_time*1000:.0f} ms") + print(f"Overall tok/sec: {generated_tokens/total_time:.1f}") +print(f"{'='*60}") + +# Check stopping reason +if new_token_ids and new_token_ids[-1] == eos_token_id: + print("\n✅ Generation stopped at EOS token") +elif generated_tokens >= max_tokens: + print(f"\n✅ Generation stopped at max output ({max_tokens} tokens)") +else: + print("\n✅ Generation stopped at model stop token") + +print("\n✅ Text generation complete!") diff --git a/models/py/validate_model.py b/models/py/validate_model.py new file mode 100644 index 0000000..92cfedd --- /dev/null +++ b/models/py/validate_model.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python3 +""" +validate_model.py - Validate ONNX model +""" + +import onnx +from pathlib import Path +import os +import sys + + +def main(): + model_file = os.environ['MODEL_FILE'] + model_path = Path(model_file) + model_dir = model_path.parent + + # Check for external data files + external_data_file = model_dir / (model_path.stem + ".onnx.data") + external_data_file_alt = model_dir / (model_path.stem + ".onnx_data") + + has_external_data = external_data_file.exists() or external_data_file_alt.exists() + + # Calculate total size including external data + file_size = os.path.getsize(model_file) + if external_data_file.exists(): + file_size += os.path.getsize(external_data_file) + print(f"External data file: {external_data_file}") + elif external_data_file_alt.exists(): + file_size += os.path.getsize(external_data_file_alt) + print(f"External data file: {external_data_file_alt}") + + file_size_gb = file_size / (1024**3) + print(f"Total model size: {file_size_gb:.2f} GB") + + # For models with external data or large models, use path-based validation + if has_external_data or file_size_gb > 2.0: + print("Using path-based validation (external data detected)...") + print("Checking model...") + try: + # Use path-based check for models with external data + onnx.checker.check_model(model_file) + print("✅ Model is valid!") + except onnx.checker.ValidationError as e: + print(f"❌ Validation failed: {e}") + sys.exit(1) + except Exception as e: + # Some versions of onnx may not support all checks + print(f"⚠️ Validation warning: {e}") + print(" Continuing with metadata extraction...") + + # Load without external data just to get metadata + print("\nLoading metadata (without weights)...") + model = onnx.load(model_file, load_external_data=False) + else: + print("Loading model...") + try: + model = onnx.load(model_file, load_external_data=True) + except Exception as e: + print("Trying without external data...") + model = onnx.load(model_file, load_external_data=False) + + print("Checking model...") + try: + onnx.checker.check_model(model) + print("✅ Model is valid!") + except onnx.checker.ValidationError as e: + print(f"❌ Validation failed: {e}") + sys.exit(1) + + print("\nModel info:") + print(f" IR version: {model.ir_version}") + print(f" Opset version: {model.opset_import[0].version}") + print(f" Producer: {model.producer_name} {model.producer_version}") + print(f" Graph name: {model.graph.name}") + print(f" Inputs: {len(model.graph.input)}") + for inp in model.graph.input: + try: + dims = [d.dim_value or d.dim_param for d in inp.type.tensor_type.shape.dim] + print(f" - {inp.name}: {dims}") + except: + print(f" - {inp.name}: (unknown shape)") + print(f" Outputs: {len(model.graph.output)}") + for out in model.graph.output: + try: + dims = [d.dim_value or d.dim_param for d in out.type.tensor_type.shape.dim] + print(f" - {out.name}: {dims}") + except: + print(f" - {out.name}: (unknown shape)") + print(f" Nodes: {len(model.graph.node)}") + print(f" Initializers: {len(model.graph.initializer)}") + + +if __name__ == '__main__': + main() diff --git a/models/test_small_fp16.onnx b/models/test_small_fp16.onnx new file mode 100644 index 0000000..55ca750 Binary files /dev/null and b/models/test_small_fp16.onnx differ diff --git a/models/test_small_model.py b/models/test_small_model.py new file mode 100755 index 0000000..0bfea16 --- /dev/null +++ b/models/test_small_model.py @@ -0,0 +1,41 @@ +import onnx +from onnx import helper, TensorProto +import numpy as np + +# Create a tiny test model with some weights (like an LLM would have) +# This simulates the structure without the size + +# Input: [batch, seq, hidden] +X = helper.make_tensor_value_info('input', TensorProto.FLOAT16, [1, 4, 256]) + +# Weight tensors (simulating model weights) +W1_data = np.random.randn(256, 256).astype(np.float16) +W1 = helper.make_tensor('weight1', TensorProto.FLOAT16, [256, 256], W1_data.tobytes(), raw=True) + +W2_data = np.random.randn(256, 256).astype(np.float16) +W2 = helper.make_tensor('weight2', TensorProto.FLOAT16, [256, 256], W2_data.tobytes(), raw=True) + +# Output +Y = helper.make_tensor_value_info('output', TensorProto.FLOAT16, [1, 4, 256]) + +# Nodes: input -> matmul(W1) -> relu -> matmul(W2) -> output +matmul1 = helper.make_node('MatMul', ['input', 'weight1'], ['hidden1']) +relu = helper.make_node('Relu', ['hidden1'], ['hidden2']) +matmul2 = helper.make_node('MatMul', ['hidden2', 'weight2'], ['output']) + +# Graph +graph = helper.make_graph( + [matmul1, relu, matmul2], + 'test_model', + [X], + [Y], + [W1, W2] # Initializers (weights) +) + +# Model +model = helper.make_model(graph, opset_imports=[helper.make_opsetid('', 14)]) +model.ir_version = 8 + +onnx.save(model, 'test_small_fp16.onnx') +print("Created test_small_fp16.onnx with embedded weights") +print(f"Model has {len(graph.initializer)} initializers (weights)") diff --git a/pypi/onnxruntime_migraphx-1.23.2-cp312-cp312-linux_x86_64.whl b/pypi/onnxruntime_migraphx-1.23.2-cp312-cp312-linux_x86_64.whl new file mode 100644 index 0000000..e1abfb8 --- /dev/null +++ b/pypi/onnxruntime_migraphx-1.23.2-cp312-cp312-linux_x86_64.whl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b3601cc0ae8ef18cbda6c233117b53c80ae77e05f0c222d04afa0b3f6500b344 +size 21096841 diff --git a/test_rdna3_compatibility.sh b/test_rdna3_compatibility.sh new file mode 100644 index 0000000..737c621 --- /dev/null +++ b/test_rdna3_compatibility.sh @@ -0,0 +1,72 @@ +#!/bin/bash + +# RDNA3 GPU Compatibility Test Script +# Tests different execution modes to find the best configuration for your system + +echo "🔧 RDNA3 GPU Compatibility Test" +echo "================================" + +# Check ROCm installation +echo "📋 Checking ROCm installation..." +if command -v rocminfo &> /dev/null; then + echo "✅ ROCm found" + rocminfo | grep "Name:" | head -5 +else + echo "❌ ROCm not found or not in PATH" + exit 1 +fi + +# Check GPU visibility +echo "" +echo "📋 Checking GPU visibility..." +if [ -n "$HIP_VISIBLE_DEVICES" ]; then + echo "HIP_VISIBLE_DEVICES: $HIP_VISIBLE_DEVICES" +else + echo "HIP_VISIBLE_DEVICES: not set (all GPUs visible)" +fi + +if [ -n "$ROCR_VISIBLE_DEVICES" ]; then + echo "ROCR_VISIBLE_DEVICES: $ROCR_VISIBLE_DEVICES" +else + echo "ROCR_VISIBLE_DEVICES: not set (all devices visible)" +fi + +# Test with different configurations +echo "" +echo "🧪 Testing execution modes..." + +# Test 1: Environment variable override +echo "" +echo "Test 1: HSA_OVERRIDE_GFX_VERSION=10.3.0" +export HSA_OVERRIDE_GFX_VERSION=10.3.0 +echo "Environment variable set. Try running your application now." + +# Test 2: Check if integrated GPU is enabled +echo "" +echo "Test 2: Checking for integrated GPU interference..." +rocminfo | grep -i "integrated" && echo "⚠️ Warning: Integrated GPU detected. Consider disabling in BIOS or using ROCR_VISIBLE_DEVICES to exclude it." + +# Test 3: Build and run a simple test +echo "" +echo "Test 3: Building and testing with different compatibility modes..." +echo "Building test project..." + +cd "$(dirname "$0")" +if dotnet build OrtForge.AI.Agent.Console/OrtForge.AI.Agent.Console.csproj -c Release -v q; then + echo "✅ Build successful" + echo "" + echo "🚀 Ready to test! Try running with:" + echo " 1. Standard mode (will likely fail on RDNA3)" + echo " 2. RDNA3 compatible mode (recommended)" + echo " 3. CPU-only mode (fallback)" + echo "" + echo "The runtime factory now defaults to RDNA3 compatible mode." +else + echo "❌ Build failed. Check your .NET installation." +fi + +echo "" +echo "✨ Test complete. If you still have issues:" +echo " 1. Try HSA_OVERRIDE_GFX_VERSION=10.3.0" +echo " 2. Use CPU-only mode for testing" +echo " 3. Check the RDNA3_GPU_COMPATIBILITY.md guide"