diff --git a/StabilityMatrix.Avalonia/ViewModels/PackageManager/PackageCardViewModel.cs b/StabilityMatrix.Avalonia/ViewModels/PackageManager/PackageCardViewModel.cs index 4b5252ac6..2fbc50c6d 100644 --- a/StabilityMatrix.Avalonia/ViewModels/PackageManager/PackageCardViewModel.cs +++ b/StabilityMatrix.Avalonia/ViewModels/PackageManager/PackageCardViewModel.cs @@ -170,7 +170,10 @@ partial void OnPackageChanged(InstalledPackage? value) // Set the extra commands if available from the package var packageExtraCommands = basePackage?.GetExtraCommands(); - ExtraCommands = packageExtraCommands?.Count > 0 ? packageExtraCommands : null; + var visibleExtraCommands = packageExtraCommands + ?.Where(command => command.IsVisible?.Invoke(value) ?? true) + .ToList(); + ExtraCommands = visibleExtraCommands?.Count > 0 ? visibleExtraCommands : null; runningPackageService.RunningPackages.CollectionChanged += RunningPackagesOnCollectionChanged; EventManager.Instance.PackageRelaunchRequested += InstanceOnPackageRelaunchRequested; diff --git a/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs b/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs index 118efa55c..3c031b2cd 100644 --- a/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs +++ b/StabilityMatrix.Core/Helper/Factory/PackageFactory.cs @@ -4,6 +4,7 @@ using StabilityMatrix.Core.Models.Packages; using StabilityMatrix.Core.Python; using StabilityMatrix.Core.Services; +using StabilityMatrix.Core.Services.Rocm; namespace StabilityMatrix.Core.Helper.Factory; @@ -18,6 +19,7 @@ public class PackageFactory : IPackageFactory private readonly IUvManager uvManager; private readonly IPyInstallationManager pyInstallationManager; private readonly IPipWheelService pipWheelService; + private readonly IRocmPackageHelper rocmPackageHelper; /// /// Mapping of package.Name to package @@ -32,7 +34,9 @@ public PackageFactory( IPrerequisiteHelper prerequisiteHelper, IPyInstallationManager pyInstallationManager, IPyRunner pyRunner, - IPipWheelService pipWheelService + IUvManager uvManager, + IPipWheelService pipWheelService, + IRocmPackageHelper rocmPackageHelper ) { this.githubApiCache = githubApiCache; @@ -40,8 +44,10 @@ IPipWheelService pipWheelService this.downloadService = downloadService; this.prerequisiteHelper = prerequisiteHelper; this.pyRunner = pyRunner; + this.uvManager = uvManager; this.pyInstallationManager = pyInstallationManager; this.pipWheelService = pipWheelService; + this.rocmPackageHelper = rocmPackageHelper; this.basePackages = basePackages.ToDictionary(x => x.Name); } @@ -55,7 +61,8 @@ public BasePackage GetNewBasePackage(InstalledPackage installedPackage) downloadService, prerequisiteHelper, pyInstallationManager, - pipWheelService + pipWheelService, + rocmPackageHelper ), "Fooocus" => new Fooocus( githubApiCache, @@ -152,7 +159,8 @@ public BasePackage GetNewBasePackage(InstalledPackage installedPackage) downloadService, prerequisiteHelper, pyInstallationManager, - pipWheelService + pipWheelService, + rocmPackageHelper ), "automatic" => new VladAutomatic( githubApiCache, @@ -280,7 +288,8 @@ public BasePackage GetNewBasePackage(InstalledPackage installedPackage) downloadService, prerequisiteHelper, pyInstallationManager, - pipWheelService + pipWheelService, + rocmPackageHelper ), _ => throw new ArgumentOutOfRangeException(nameof(installedPackage)), }; diff --git a/StabilityMatrix.Core/Helper/HardwareInfo/GpuInfo.cs b/StabilityMatrix.Core/Helper/HardwareInfo/GpuInfo.cs index eedcb556c..6a791e7a9 100644 --- a/StabilityMatrix.Core/Helper/HardwareInfo/GpuInfo.cs +++ b/StabilityMatrix.Core/Helper/HardwareInfo/GpuInfo.cs @@ -1,4 +1,6 @@ -namespace StabilityMatrix.Core.Helper.HardwareInfo; +using StabilityMatrix.Core.Models.Rocm; + +namespace StabilityMatrix.Core.Helper.HardwareInfo; public record GpuInfo { @@ -62,11 +64,7 @@ public bool IsLegacyNvidiaGpu() public bool IsWindowsRocmSupportedGpu() { - var gfx = GetAmdGfxArch(); - if (gfx is null) - return false; - - return gfx.StartsWith("gfx110") || gfx.StartsWith("gfx120") || gfx.Equals("gfx1151"); + return WindowsRocmSupport.IsSupportedGpu(this); } public bool IsAmd => Name?.Contains("amd", StringComparison.OrdinalIgnoreCase) ?? false; @@ -84,7 +82,7 @@ public bool IsWindowsRocmSupportedGpu() return name switch { // RDNA4 - _ when Has("R9700") || Has("9070") => "gfx1201", + _ when Has("R9700") || Has("R9600") || Has("9070") => "gfx1201", _ when Has("9060") => "gfx1200", // RDNA3.5 APUs @@ -100,11 +98,14 @@ _ when Has("740M") || Has("760M") || Has("780M") || Has("Z1") || Has("Z2") => "g _ when Has("7400") || Has("7500") || Has("7600") || Has("7650") || Has("7700S") => "gfx1102", // RDNA3 dGPU Navi32 - _ when Has("7700") || Has("RX 7800") || HasNoSpace("RX7800") => "gfx1101", + _ when Has("7700") || Has("RX 7800") || Has("v710)") || HasNoSpace("RX7800") => "gfx1101", // RDNA3 dGPU Navi31 (incl. Pro) _ when Has("W7800") || Has("7900") || Has("7950") || Has("7990") => "gfx1100", + // RDNA2 Raphael APUs + _ when Has("Raphael") || Has("Radeon Graphics") || Has("AMD Radeon Graphics") => "gfx1036", + // RDNA2 APUs (Rembrandt) _ when Has("660M") || Has("680M") => "gfx1035", @@ -112,6 +113,9 @@ _ when Has("660M") || Has("680M") => "gfx1035", _ when Has("6300") || Has("6400") || Has("6450") || Has("6500") || Has("6550") || Has("6500M") => "gfx1034", + // RDNA2 Steam Deck APU + _ when Has("Van Gogh") || Has("Sephiroth") => "gfx1033", + // RDNA2 Navi23 _ when Has("6600") || Has("6650") || Has("6700S") || Has("6800S") || Has("6600M") => "gfx1032", @@ -119,8 +123,21 @@ _ when Has("6600") || Has("6650") || Has("6700S") || Has("6800S") || Has("6600M" _ when Has("6700") || Has("6750") || Has("6800M") || Has("6850M") => "gfx1031", // RDNA2 Navi21 (big die) - _ when Has("6800") || Has("6900") || Has("6950") => "gfx1030", + _ when Has("6800") || Has("6900") || Has("6950") || Has("v620") => "gfx1030", + + // RDNA1 Navi10 XTX + _ when Has("5500") => "gfx1012", + + //RDNA1 Pro Card + _ when Has("v520") => "gfx1011", + + // RDNA1 Navi10 XT + _ when Has("5600") || Has("5700") => "gfx1010", + // Vega/GCN5 Dedicated GPUs + _ when Has("rx vega") || Has("vega 64") || Has("vega 56") || Has("vega frontier") => "gfx900", + _ when Has("radeon vii") || HasNoSpace("radeonvii") || Has("pro vii") || HasNoSpace("provii") => + "gfx906", _ => null, }; diff --git a/StabilityMatrix.Core/Helper/HardwareInfo/HardwareHelper.cs b/StabilityMatrix.Core/Helper/HardwareInfo/HardwareHelper.cs index 8458c730b..93f093d41 100644 --- a/StabilityMatrix.Core/Helper/HardwareInfo/HardwareHelper.cs +++ b/StabilityMatrix.Core/Helper/HardwareInfo/HardwareHelper.cs @@ -7,6 +7,7 @@ using Microsoft.Win32; using NLog; using StabilityMatrix.Core.Extensions; +using StabilityMatrix.Core.Models.Rocm; namespace StabilityMatrix.Core.Helper.HardwareInfo; @@ -316,12 +317,11 @@ public static bool HasAmdGpu() return IterGpuInfo().Any(gpu => gpu.IsAmd); } - public static bool HasWindowsRocmSupportedGpu() => - IterGpuInfo().Any(gpu => gpu is { IsAmd: true, Name: not null } && gpu.IsWindowsRocmSupportedGpu()); + public static bool HasWindowsRocmSupportedGpu() => IterGpuInfo().Any(WindowsRocmSupport.IsSupportedGpu); public static GpuInfo? GetWindowsRocmSupportedGpu() { - return IterGpuInfo().FirstOrDefault(gpu => gpu.IsWindowsRocmSupportedGpu()); + return IterGpuInfo().FirstOrDefault(WindowsRocmSupport.IsSupportedGpu); } public static bool HasIntelGpu() => IterGpuInfo().Any(gpu => gpu.IsIntel); diff --git a/StabilityMatrix.Core/Models/ExtraPackageCommand.cs b/StabilityMatrix.Core/Models/ExtraPackageCommand.cs index 87398a8c3..fa7219dbc 100644 --- a/StabilityMatrix.Core/Models/ExtraPackageCommand.cs +++ b/StabilityMatrix.Core/Models/ExtraPackageCommand.cs @@ -4,4 +4,5 @@ public class ExtraPackageCommand { public required string CommandName { get; set; } public required Func Command { get; set; } + public Func? IsVisible { get; set; } } diff --git a/StabilityMatrix.Core/Models/PackageModification/InstallWindowsRocmPackageCommandStep.cs b/StabilityMatrix.Core/Models/PackageModification/InstallWindowsRocmPackageCommandStep.cs new file mode 100644 index 000000000..a59f2bc53 --- /dev/null +++ b/StabilityMatrix.Core/Models/PackageModification/InstallWindowsRocmPackageCommandStep.cs @@ -0,0 +1,293 @@ +using StabilityMatrix.Core.Helper; +using StabilityMatrix.Core.Models.FileInterfaces; +using StabilityMatrix.Core.Models.Progress; +using StabilityMatrix.Core.Models.Rocm; +using StabilityMatrix.Core.Processes; +using StabilityMatrix.Core.Python; +using StabilityMatrix.Core.Services; +using StabilityMatrix.Core.Services.Rocm; + +namespace StabilityMatrix.Core.Models.PackageModification; + +public enum WindowsRocmPackageCommandType +{ + SageAttention, + DevelopmentSdk, + BitsAndBytes, + FlashAttention, +} + +public class InstallWindowsRocmPackageCommandStep( + IDownloadService downloadService, + IPyInstallationManager pyInstallationManager, + IPrerequisiteHelper prerequisiteHelper, + IRocmPackageHelper rocmPackageHelper +) : IPackageStep +{ + private const string BitsAndBytesWheelUrl = + "https://github.com/0xDELUXA/bitsandbytes_win_rocm/releases/download/0.50.0.dev0-py3-rocm7-win_amd64_all/bitsandbytes-0.50.0.dev0-cp312-cp312-win_amd64.whl"; + private const string AmdAiterWheelUrl = + "https://github.com/0xDELUXA/flash-attention/releases/download/v2.8.4_win-rocm/amd_aiter-0.0.0-py3-none-win_amd64.whl"; + private const string FlashAttentionWheelUrl = + "https://github.com/0xDELUXA/flash-attention/releases/download/v2.8.4_win-rocm/flash_attn-2.8.4-py3-none-win_amd64.whl"; + private const string TritonWindowsVersion = "3.6.0.post25"; + private const string SageAttentionVersion = "1.0.6"; + + private const string AttnQkInt8PerBlockUrl = + "https://raw.githubusercontent.com/patientx/ComfyUI-Zluda/refs/heads/master/comfy/customzluda/sa/attn_qk_int8_per_block.py"; + + private const string AttnQkInt8PerBlockCausalUrl = + "https://raw.githubusercontent.com/patientx/ComfyUI-Zluda/refs/heads/master/comfy/customzluda/sa/attn_qk_int8_per_block_causal.py"; + + private const string QuantPerBlockUrl = + "https://raw.githubusercontent.com/patientx/ComfyUI-Zluda/refs/heads/master/comfy/customzluda/sa/quant_per_block.py"; + + public required InstalledPackage InstalledPackage { get; init; } + public required DirectoryPath WorkingDirectory { get; init; } + public required WindowsRocmPackageCommandType CommandType { get; init; } + public IReadOnlyDictionary? EnvironmentVariables { get; init; } + + public string ProgressTitle => CommandType switch + { + WindowsRocmPackageCommandType.SageAttention => "Installing Windows ROCm SageAttention", + WindowsRocmPackageCommandType.DevelopmentSdk => "Installing Windows ROCm Development SDK", + WindowsRocmPackageCommandType.BitsAndBytes => "Installing Windows ROCm bitsandbytes", + WindowsRocmPackageCommandType.FlashAttention => "Installing Windows ROCm Flash Attention", + _ => "Running Windows ROCm package command", + }; + + public async Task ExecuteAsync(IProgress? progress = null) + { + if (!global::System.OperatingSystem.IsWindows()) + { + throw new PlatformNotSupportedException( + "Windows ROCm package commands are only supported on Windows." + ); + } + + if (InstalledPackage.FullPath is null) + { + throw new InvalidOperationException("Installed package path is not available."); + } + + var venvDir = WorkingDirectory.JoinDir("venv"); + if (!venvDir.Exists) + { + throw new DirectoryNotFoundException($"ComfyUI venv was not found at '{venvDir.FullPath}'."); + } + + var pyVersion = PyVersion.Parse(InstalledPackage.PythonVersion); + if (pyVersion.StringValue == "0.0.0") + { + pyVersion = PyInstallationManager.Python_3_10_11; + } + + var baseInstall = !string.IsNullOrWhiteSpace(InstalledPackage.PythonVersion) + ? new PyBaseInstall( + await pyInstallationManager.GetInstallationAsync(pyVersion).ConfigureAwait(false) + ) + : PyBaseInstall.Default; + + await using var venvRunner = baseInstall.CreateVenvRunner( + venvDir, + workingDirectory: WorkingDirectory, + environmentVariables: EnvironmentVariables + ); + + switch (CommandType) + { + case WindowsRocmPackageCommandType.SageAttention: + await ExecuteSageAttentionAsync(venvRunner, progress).ConfigureAwait(false); + break; + case WindowsRocmPackageCommandType.DevelopmentSdk: + await ExecuteDevelopmentSdkAsync(venvRunner, progress).ConfigureAwait(false); + break; + case WindowsRocmPackageCommandType.BitsAndBytes: + await ExecuteBitsAndBytesAsync(venvRunner, pyVersion, progress).ConfigureAwait(false); + break; + case WindowsRocmPackageCommandType.FlashAttention: + await ExecuteFlashAttentionAsync(venvRunner, progress).ConfigureAwait(false); + break; + default: + throw new InvalidOperationException( + $"Unsupported Windows ROCm package command type: {CommandType}." + ); + } + } + + private void EnsureRocmCompatibility() + { + var compatibility = rocmPackageHelper.GetCompatibility(ComfyWindowsRocmProfile.Profile); + if (!compatibility.IsCompatible) + { + throw new InvalidOperationException( + compatibility.FailureReason + ?? "Windows ROCm package commands require a supported Windows ROCm machine state." + ); + } + } + + private async Task EnsureVcBuildToolsAsync(IProgress? progress) + { + if (!prerequisiteHelper.IsVcBuildToolsInstalled) + { + await prerequisiteHelper + .InstallPackageRequirements([PackagePrerequisite.VcBuildTools], progress: progress) + .ConfigureAwait(false); + } + } + + private async Task ExecuteDevelopmentSdkAsync( + IPyVenvRunner venvRunner, + IProgress? progress + ) + { + EnsureRocmCompatibility(); + await rocmPackageHelper.EnsureWindowsSdkDevelAsync(venvRunner, progress).ConfigureAwait(false); + } + + private async Task ExecuteSageAttentionAsync( + IPyVenvRunner venvRunner, + IProgress? progress + ) + { + EnsureRocmCompatibility(); + await EnsureVcBuildToolsAsync(progress).ConfigureAwait(false); + await rocmPackageHelper.EnsureWindowsSdkDevelAsync(venvRunner, progress).ConfigureAwait(false); + + progress?.Report( + new ProgressReport( + -1f, + "Installing triton-windows for Windows ROCm SageAttention...", + isIndeterminate: true + ) + ); + await venvRunner.PipInstall($"triton-windows=={TritonWindowsVersion}").ConfigureAwait(false); + + progress?.Report( + new ProgressReport(-1f, "Installing SageAttention for Windows ROCm...", isIndeterminate: true) + ); + await venvRunner.PipInstall($"--no-deps sageattention=={SageAttentionVersion}").ConfigureAwait(false); + + var sageAttentionDir = WorkingDirectory.JoinDir("venv", "Lib", "site-packages", "sageattention"); + if (!sageAttentionDir.Exists) + { + throw new DirectoryNotFoundException( + $"Installed SageAttention package path was not found at '{sageAttentionDir.FullPath}'." + ); + } + + progress?.Report( + new ProgressReport(-1f, "Patching SageAttention for Windows ROCm...", isIndeterminate: true) + ); + + await DownloadAndReplaceFileAsync( + sageAttentionDir, + "attn_qk_int8_per_block.py", + AttnQkInt8PerBlockUrl, + progress + ) + .ConfigureAwait(false); + await DownloadAndReplaceFileAsync( + sageAttentionDir, + "attn_qk_int8_per_block_causal.py", + AttnQkInt8PerBlockCausalUrl, + progress + ) + .ConfigureAwait(false); + await DownloadAndReplaceFileAsync(sageAttentionDir, "quant_per_block.py", QuantPerBlockUrl, progress) + .ConfigureAwait(false); + } + + private async Task ExecuteBitsAndBytesAsync( + IPyVenvRunner venvRunner, + PyVersion pyVersion, + IProgress? progress + ) + { + EnsureRocmCompatibility(); + + if (pyVersion.Major != 3 || pyVersion.Minor != 12) + { + throw new InvalidOperationException( + $"Windows ROCm bitsandbytes is only supported on Python 3.12.x (detected version: {pyVersion})." + ); + } + + progress?.Report( + new ProgressReport( + -1f, + "Installing bitsandbytes for Windows ROCm...", + isIndeterminate: true + ) + ); + await venvRunner.PipInstall(BitsAndBytesWheelUrl).ConfigureAwait(false); + } + + private async Task ExecuteFlashAttentionAsync( + IPyVenvRunner venvRunner, + IProgress? progress + ) + { + EnsureRocmCompatibility(); + + progress?.Report( + new ProgressReport( + -1f, + "Installing Flash Attention dependencies for Windows ROCm...", + isIndeterminate: true + ) + ); + await venvRunner.PipInstall(AmdAiterWheelUrl).ConfigureAwait(false); + + progress?.Report( + new ProgressReport( + -1f, + "Installing Flash Attention for Windows ROCm...", + isIndeterminate: true + ) + ); + await venvRunner.PipInstall(FlashAttentionWheelUrl).ConfigureAwait(false); + } + + private async Task DownloadAndReplaceFileAsync( + DirectoryPath sageAttentionDir, + string fileName, + string sourceUrl, + IProgress? progress + ) + { + var targetFile = sageAttentionDir.JoinFile(fileName); + if (!targetFile.Exists) + { + throw new FileNotFoundException( + $"Expected SageAttention file '{fileName}' was not found.", + targetFile.FullPath + ); + } + + var backupFile = sageAttentionDir.JoinFile($"{fileName}.bak"); + if (!backupFile.Exists) + { + await backupFile + .WriteAllTextAsync(await targetFile.ReadAllTextAsync().ConfigureAwait(false)) + .ConfigureAwait(false); + } + + var tempFile = WorkingDirectory.JoinFile($"sm-rocm-sage-{fileName}.tmp"); + await downloadService.DownloadToFileAsync(sourceUrl, tempFile, progress).ConfigureAwait(false); + + try + { + var replacementContent = await tempFile.ReadAllTextAsync().ConfigureAwait(false); + await targetFile.WriteAllTextAsync(replacementContent).ConfigureAwait(false); + } + finally + { + if (tempFile.Exists) + { + await tempFile.DeleteAsync().ConfigureAwait(false); + } + } + } +} \ No newline at end of file diff --git a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs index a4c34649d..6c2db1212 100644 --- a/StabilityMatrix.Core/Models/Packages/ComfyUI.cs +++ b/StabilityMatrix.Core/Models/Packages/ComfyUI.cs @@ -13,9 +13,11 @@ using StabilityMatrix.Core.Models.Packages.Config; using StabilityMatrix.Core.Models.Packages.Extensions; using StabilityMatrix.Core.Models.Progress; +using StabilityMatrix.Core.Models.Rocm; using StabilityMatrix.Core.Processes; using StabilityMatrix.Core.Python; using StabilityMatrix.Core.Services; +using StabilityMatrix.Core.Services.Rocm; namespace StabilityMatrix.Core.Models.Packages; @@ -26,7 +28,8 @@ public class ComfyUI( IDownloadService downloadService, IPrerequisiteHelper prerequisiteHelper, IPyInstallationManager pyInstallationManager, - IPipWheelService pipWheelService + IPipWheelService pipWheelService, + IRocmPackageHelper? rocmPackageHelper = null ) : BaseGitPackage( githubApi, @@ -38,6 +41,7 @@ IPipWheelService pipWheelService ) { private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); + public override string Name => "ComfyUI"; public override string DisplayName { get; set; } = "ComfyUI"; public override string Author => "comfyanonymous"; @@ -247,7 +251,7 @@ IPipWheelService pipWheelService Name = "Enable DirectML", Type = LaunchOptionType.Bool, InitialValue = - !HardwareHelper.HasWindowsRocmSupportedGpu() + !HasWindowsRocmSupport() && HardwareHelper.PreferDirectMLOrZluda() && this is not ComfyZluda, Options = ["--directml"], @@ -264,7 +268,9 @@ IPipWheelService pipWheelService { Name = "Cross Attention Method", Type = LaunchOptionType.Bool, - InitialValue = "--use-pytorch-cross-attention", + InitialValue = DefaultToQuadCrossAttention() + ? "--use-quad-cross-attention" // For Legacy AMD GPUs. + : "--use-pytorch-cross-attention", Options = [ "--use-split-cross-attention", @@ -335,6 +341,52 @@ public override List GetExtraCommands() ); } + if (Compat.IsWindows && HasWindowsRocmSupport()) + { + commands.Add( + new ExtraPackageCommand + { + CommandName = "Install Triton and SageAttention (ROCm)", + Command = InstallWindowsRocmSageAttention, + } + ); + + commands.Add( + new ExtraPackageCommand + { + CommandName = "Install Flash Attention (ROCm)", + Command = InstallWindowsRocmFlashAttention, + IsVisible = _ => + WindowsRocmSupport.IsLegacyArchitecture( + GetWindowsRocmCompatibility().ResolvedGfxArch + ), + } + ); + + commands.Add( + new ExtraPackageCommand + { + CommandName = "Install ROCm Development SDK", + Command = InstallWindowsRocmDevelopmentSdk, + } + ); + + commands.Add( + new ExtraPackageCommand + { + CommandName = "Install bitsandbytes (ROCm)", + Command = InstallWindowsRocmBitsAndBytes, + IsVisible = installedPackage => + { + if (!PyVersion.TryParse(installedPackage.PythonVersion, out var pyVersion)) + return false; + + return pyVersion.Major == 3 && pyVersion.Minor == 12; + }, + } + ); + } + if (!Compat.IsMacOS && SettingsManager.Settings.PreferredGpu?.ComputeCapabilityValue is >= 7.5m) { commands.Add( @@ -362,69 +414,39 @@ public override async Task InstallPackage( .ConfigureAwait(false); var torchIndex = options.PythonOptions.TorchIndex ?? GetRecommendedTorchVersion(); - var gfxArch = - SettingsManager.Settings.PreferredGpu?.GetAmdGfxArch() - ?? HardwareHelper.GetWindowsRocmSupportedGpu()?.GetAmdGfxArch(); - - // Special case for Windows ROCm Nightly builds - if ( - Compat.IsWindows - && !string.IsNullOrWhiteSpace(gfxArch) - && torchIndex is TorchIndex.Rocm - && options.PythonOptions.PythonVersion >= PyVersion.Parse("3.11.0") - ) + var isLegacyNvidia = + torchIndex == TorchIndex.Cuda + && ( + SettingsManager.Settings.PreferredGpu?.IsLegacyNvidiaGpu() + ?? HardwareHelper.HasLegacyNvidiaGpu() + ); + + if (Compat.IsWindows && torchIndex == TorchIndex.Rocm && HasWindowsRocmSupport()) { - var config = new PipInstallConfig + // This is an internal guard for a wiring/configuration failure. + // It can only trigger when Windows ROCm support was detected, but this ComfyUI instance was created + // without the shared ROCm helper (for example via a manual construction path that omitted the dependency). + if (rocmPackageHelper is null) { - RequirementsFilePaths = ["requirements.txt"], - ExtraPipArgs = ["numpy<2"], - SkipTorchInstall = true, - PostInstallPipArgs = ["typing-extensions>=4.15.0"], - }; - await StandardPipInstallProcessAsync( + throw new InvalidOperationException( + "Windows ROCm installation encountered an internal configuration error [rocmPackageHelper is null]. Please restart Stability Matrix and try again. If the issue persists, please report it to Stability Matrix." + ); + } + + await rocmPackageHelper + .InstallWindowsNativePackageAsync( venvRunner, - options, + installLocation, installedPackage, - config, - onConsoleOutput, + ComfyWindowsRocmProfile.Profile, progress, + onConsoleOutput, cancellationToken ) .ConfigureAwait(false); - - progress?.Report( - new ProgressReport(-1f, "Installing ROCm nightly torch...", isIndeterminate: true) - ); - var indexUrl = gfxArch switch - { - "gfx1150" => "https://rocm.nightlies.amd.com/v2-staging/gfx1150", // Strix/Gorgon Point - "gfx1151" => "https://rocm.nightlies.amd.com/v2/gfx1151", // Strix Halo - _ when gfxArch.StartsWith("gfx110") => "https://rocm.nightlies.amd.com/v2/gfx110X-all", - _ when gfxArch.StartsWith("gfx120") => "https://rocm.nightlies.amd.com/v2/gfx120X-all", - _ => throw new ArgumentOutOfRangeException( - nameof(gfxArch), - $"Unsupported GFX Arch: {gfxArch}" - ), - }; - - var torchPipArgs = new PipInstallArgs() - .AddArgs("--pre", "--upgrade") - .WithTorch() - .WithTorchVision() - .WithTorchAudio() - .AddArgs("--index-url", indexUrl); - - await venvRunner.PipInstall(torchPipArgs, onConsoleOutput).ConfigureAwait(false); } - else // Standard installation path for all other cases + else { - var isLegacyNvidia = - torchIndex == TorchIndex.Cuda - && ( - SettingsManager.Settings.PreferredGpu?.IsLegacyNvidiaGpu() - ?? HardwareHelper.HasLegacyNvidiaGpu() - ); - var config = new PipInstallConfig { RequirementsFilePaths = ["requirements.txt"], @@ -448,47 +470,55 @@ await StandardPipInstallProcessAsync( .ConfigureAwait(false); } - try + if (!(Compat.IsWindows && torchIndex == TorchIndex.Rocm && HasWindowsRocmSupport())) { - var sageVersion = await venvRunner.PipShow("sageattention").ConfigureAwait(false); - var torchVersion = await venvRunner.PipShow("torch").ConfigureAwait(false); - - if (torchVersion is not null && sageVersion is not null) + try { - var version = torchVersion.Version; - var plusPos = version.IndexOf('+'); - var index = plusPos >= 0 ? version[(plusPos + 1)..] : string.Empty; - var versionWithoutIndex = plusPos >= 0 ? version[..plusPos] : version; + var sageVersion = await venvRunner.PipShow("sageattention").ConfigureAwait(false); + var torchVersion = await venvRunner.PipShow("torch").ConfigureAwait(false); - if ( - !sageVersion.Version.Contains(index) || !sageVersion.Version.Contains(versionWithoutIndex) - ) + if (torchVersion is not null && sageVersion is not null) { - progress?.Report( - new ProgressReport(-1f, "Updating SageAttention...", isIndeterminate: true) - ); + var version = torchVersion.Version; + var plusPos = version.IndexOf('+'); + var index = plusPos >= 0 ? version[(plusPos + 1)..] : string.Empty; + var versionWithoutIndex = plusPos >= 0 ? version[..plusPos] : version; - var step = new InstallSageAttentionStep( - downloadService, - prerequisiteHelper, - pyInstallationManager + if ( + !sageVersion.Version.Contains(index) + || !sageVersion.Version.Contains(versionWithoutIndex) ) { - InstalledPackage = installedPackage, - IsBlackwellGpu = - SettingsManager.Settings.PreferredGpu?.IsBlackwellGpu() - ?? HardwareHelper.HasBlackwellGpu(), - WorkingDirectory = installLocation, - EnvironmentVariables = GetEnvVars(venvRunner.EnvironmentVariables), - }; - - await step.ExecuteAsync(progress).ConfigureAwait(false); + progress?.Report( + new ProgressReport(-1f, "Updating SageAttention...", isIndeterminate: true) + ); + + var step = new InstallSageAttentionStep( + downloadService, + prerequisiteHelper, + pyInstallationManager + ) + { + InstalledPackage = installedPackage, + IsBlackwellGpu = + SettingsManager.Settings.PreferredGpu?.IsBlackwellGpu() + ?? HardwareHelper.HasBlackwellGpu(), + WorkingDirectory = installLocation, + EnvironmentVariables = GetEnvVars( + venvRunner.EnvironmentVariables, + installLocation, + installedPackage + ), + }; + + await step.ExecuteAsync(progress).ConfigureAwait(false); + } } } - } - catch (Exception e) - { - Logger.Error(e, "Failed to verify/update SageAttention after installation"); + catch (Exception e) + { + Logger.Error(e, "Failed to verify/update SageAttention after installation"); + } } // Install Comfy Manager (built-in to ComfyUI) @@ -529,7 +559,7 @@ public override async Task RunPackage( await SetupVenv(installLocation, pythonVersion: PyVersion.Parse(installedPackage.PythonVersion)) .ConfigureAwait(false); - VenvRunner.UpdateEnvironmentVariables(GetEnvVars); + VenvRunner.UpdateEnvironmentVariables(env => GetEnvVars(env, installLocation, installedPackage)); // Check for old NVIDIA driver version with cu130 installations var isNvidia = SettingsManager.Settings.PreferredGpu?.IsNvidia ?? HardwareHelper.HasNvidiaGpu(); @@ -584,6 +614,8 @@ older torch index (e.g. cu128) } } + var handledFirstConsoleOutput = false; + VenvRunner.RunDetached( [Path.Combine(installLocation, options.Command ?? LaunchCommand), .. options.Arguments], HandleConsoleOutput, @@ -596,6 +628,12 @@ void HandleConsoleOutput(ProcessOutput s) { onConsoleOutput?.Invoke(s); + if (!handledFirstConsoleOutput) + { + handledFirstConsoleOutput = true; + EmitWindowsRocmLaunchNotice(installedPackage, onConsoleOutput); + } + if (!s.Text.Contains("To see the GUI go to", StringComparison.OrdinalIgnoreCase)) return; @@ -609,17 +647,37 @@ void HandleConsoleOutput(ProcessOutput s) } } + private void EmitWindowsRocmLaunchNotice( + InstalledPackage installedPackage, + Action? onConsoleOutput + ) + { + if (rocmPackageHelper is null) + return; + + if (!ShouldShowWindowsRocmLaunchNotice(installedPackage)) + return; + + foreach (var line in rocmPackageHelper.GetWindowsLaunchNoticeLines()) + { + onConsoleOutput?.Invoke(ProcessOutput.FromStdOutLine($"{line}{Environment.NewLine}")); + } + } + + private bool ShouldShowWindowsRocmLaunchNotice(InstalledPackage installedPackage) + { + if (!Compat.IsWindows || !HasWindowsRocmSupport()) + return false; + + var torchIndex = installedPackage.PreferredTorchIndex ?? GetRecommendedTorchVersion(); + return torchIndex == TorchIndex.Rocm; + } + public override TorchIndex GetRecommendedTorchVersion() { var preferRocm = (Compat.IsLinux && (SettingsManager.Settings.PreferredGpu?.IsAmd ?? HardwareHelper.PreferRocm())) - || ( - Compat.IsWindows - && ( - SettingsManager.Settings.PreferredGpu?.IsWindowsRocmSupportedGpu() - ?? HardwareHelper.HasWindowsRocmSupportedGpu() - ) - ); + || HasWindowsRocmSupport(); if (AvailableTorchIndices.Contains(TorchIndex.Rocm) && preferRocm) { @@ -629,6 +687,33 @@ public override TorchIndex GetRecommendedTorchVersion() return base.GetRecommendedTorchVersion(); } + /// Uses the shared ROCm helper for Windows ROCm eligibility checks so ComfyUI does not maintain its own support matrix. + private bool HasWindowsRocmSupport() + { + return GetWindowsRocmCompatibility().IsCompatible; + } + + private RocmCompatibilityResult GetWindowsRocmCompatibility() + { + if (!Compat.IsWindows || rocmPackageHelper is null) + { + return new RocmCompatibilityResult { IsCompatible = false }; + } + + return rocmPackageHelper.GetCompatibility(ComfyWindowsRocmProfile.Profile); + } + + /// Defaults legacy Windows ROCm GPUs to quad cross-attention because PyTorch cross-attention is considerably slower + /// and not as supported on older AMD architectures. + private bool DefaultToQuadCrossAttention() + { + var compatibility = GetWindowsRocmCompatibility(); + if (!compatibility.IsCompatible) + return false; + + return WindowsRocmSupport.PreferLegacyAttentionFallback(compatibility.ResolvedGfxArch); + } + public override IPackageExtensionManager ExtensionManager => new ComfyExtensionManager(this, settingsManager); @@ -904,11 +989,170 @@ await PipWheelService if (runner.Failed) return; + await EnableSageAttentionAsync(installedPackage).ConfigureAwait(false); + } + + private async Task InstallWindowsRocmSageAttention(InstalledPackage? installedPackage) + { + if (installedPackage?.FullPath is null) + return; + + var runner = new PackageModificationRunner + { + ShowDialogOnStart = true, + ModificationCompleteMessage = "Windows ROCm SageAttention installed successfully", + }; + EventManager.Instance.OnPackageInstallProgressAdded(runner); + + var baseEnvironment = ImmutableDictionary.CreateRange(SettingsManager.Settings.EnvironmentVariables); + var environmentVariables = GetEnvVars(baseEnvironment, installedPackage.FullPath, installedPackage); + + await runner + .ExecuteSteps( + [ + new InstallWindowsRocmPackageCommandStep( + downloadService, + pyInstallationManager, + prerequisiteHelper, + rocmPackageHelper + ?? throw new InvalidOperationException( + "Windows ROCm SageAttention installation encountered an internal configuration error [rocmPackageHelper is null]." + ) + ) + { + CommandType = WindowsRocmPackageCommandType.SageAttention, + InstalledPackage = installedPackage, + WorkingDirectory = new DirectoryPath(installedPackage.FullPath), + EnvironmentVariables = environmentVariables, + }, + ] + ) + .ConfigureAwait(false); + + if (runner.Failed) + return; + + await EnableSageAttentionAsync(installedPackage).ConfigureAwait(false); + } + + private async Task InstallWindowsRocmDevelopmentSdk(InstalledPackage? installedPackage) + { + if (installedPackage?.FullPath is null) + return; + + var runner = new PackageModificationRunner + { + ShowDialogOnStart = true, + ModificationCompleteMessage = "Windows ROCm Development SDK installed successfully", + }; + EventManager.Instance.OnPackageInstallProgressAdded(runner); + + await runner + .ExecuteSteps( + [ + new InstallWindowsRocmPackageCommandStep( + downloadService, + pyInstallationManager, + prerequisiteHelper, + rocmPackageHelper + ?? throw new InvalidOperationException( + "Windows ROCm SDK installation encountered an internal configuration error [rocmPackageHelper is null]." + ) + ) + { + CommandType = WindowsRocmPackageCommandType.DevelopmentSdk, + InstalledPackage = installedPackage, + WorkingDirectory = new DirectoryPath(installedPackage.FullPath), + }, + ] + ) + .ConfigureAwait(false); + } + + private async Task InstallWindowsRocmBitsAndBytes(InstalledPackage? installedPackage) + { + if (installedPackage?.FullPath is null) + return; + + var runner = new PackageModificationRunner + { + ShowDialogOnStart = true, + ModificationCompleteMessage = "Windows ROCm bitsandbytes installed successfully", + }; + EventManager.Instance.OnPackageInstallProgressAdded(runner); + + var baseEnvironment = ImmutableDictionary.CreateRange(SettingsManager.Settings.EnvironmentVariables); + var environmentVariables = GetEnvVars(baseEnvironment, installedPackage.FullPath, installedPackage); + + await runner + .ExecuteSteps( + [ + new InstallWindowsRocmPackageCommandStep( + downloadService, + pyInstallationManager, + prerequisiteHelper, + rocmPackageHelper + ?? throw new InvalidOperationException( + "Windows ROCm bitsandbytes installation encountered an internal configuration error [rocmPackageHelper is null]." + ) + ) + { + CommandType = WindowsRocmPackageCommandType.BitsAndBytes, + InstalledPackage = installedPackage, + WorkingDirectory = new DirectoryPath(installedPackage.FullPath), + EnvironmentVariables = environmentVariables, + }, + ] + ) + .ConfigureAwait(false); + } + + private async Task InstallWindowsRocmFlashAttention(InstalledPackage? installedPackage) + { + if (installedPackage?.FullPath is null) + return; + + var runner = new PackageModificationRunner + { + ShowDialogOnStart = true, + ModificationCompleteMessage = "Windows ROCm Flash Attention installed successfully", + }; + EventManager.Instance.OnPackageInstallProgressAdded(runner); + + var baseEnvironment = ImmutableDictionary.CreateRange(SettingsManager.Settings.EnvironmentVariables); + var environmentVariables = GetEnvVars(baseEnvironment, installedPackage.FullPath, installedPackage); + + await runner + .ExecuteSteps( + [ + new InstallWindowsRocmPackageCommandStep( + downloadService, + pyInstallationManager, + prerequisiteHelper, + rocmPackageHelper + ?? throw new InvalidOperationException( + "Windows ROCm Flash Attention installation encountered an internal configuration error [rocmPackageHelper is null]." + ) + ) + { + CommandType = WindowsRocmPackageCommandType.FlashAttention, + InstalledPackage = installedPackage, + WorkingDirectory = new DirectoryPath(installedPackage.FullPath), + EnvironmentVariables = environmentVariables, + }, + ] + ) + .ConfigureAwait(false); + } + + private async Task EnableSageAttentionAsync(InstalledPackage installedPackage) + { await using var transaction = settingsManager.BeginTransaction(); - var attentionOptions = transaction - .Settings.InstalledPackages.First(x => x.Id == installedPackage.Id) - .LaunchArgs?.Where(opt => opt.Name.Contains("attention")); + var packageInSettings = transaction.Settings.InstalledPackages.First(x => + x.Id == installedPackage.Id + ); + var attentionOptions = packageInSettings.LaunchArgs?.Where(opt => opt.Name.Contains("attention")); if (attentionOptions is not null) { foreach (var option in attentionOptions) @@ -917,9 +1161,9 @@ await PipWheelService } } - var sageAttention = transaction - .Settings.InstalledPackages.First(x => x.Id == installedPackage.Id) - .LaunchArgs?.FirstOrDefault(opt => opt.Name.Contains("sage-attention")); + var sageAttention = packageInSettings.LaunchArgs?.FirstOrDefault(opt => + opt.Name.Contains("sage-attention") + ); if (sageAttention is not null) { @@ -927,16 +1171,14 @@ await PipWheelService } else { - transaction - .Settings.InstalledPackages.First(x => x.Id == installedPackage.Id) - .LaunchArgs?.Add( - new LaunchOption - { - Name = "--use-sage-attention", - Type = LaunchOptionType.Bool, - OptionValue = true, - } - ); + packageInSettings.LaunchArgs?.Add( + new LaunchOption + { + Name = "--use-sage-attention", + Type = LaunchOptionType.Bool, + OptionValue = true, + } + ); } } @@ -979,21 +1221,23 @@ await PipWheelService .ConfigureAwait(false); } - private ImmutableDictionary GetEnvVars(ImmutableDictionary env) + private ImmutableDictionary GetEnvVars( + ImmutableDictionary env, + string installLocation, + InstalledPackage installedPackage + ) { // if we're not on windows or we don't have a windows rocm gpu, return original env - var hasRocmGpu = - SettingsManager.Settings.PreferredGpu?.IsWindowsRocmSupportedGpu() - ?? HardwareHelper.HasWindowsRocmSupportedGpu(); + var hasRocmGpu = HasWindowsRocmSupport(); if (!Compat.IsWindows || !hasRocmGpu) return env; - // set some experimental speed improving env vars for Windows ROCm - return env.SetItem("PYTORCH_TUNABLEOP_ENABLED", "1") - .SetItem("MIOPEN_FIND_MODE", "2") - .SetItem("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "1") - .SetItem("PYTORCH_ALLOC_CONF", "max_split_size_mb:6144,garbage_collection_threshold:0.8") // greatly helps prevent GPU OOM and instability/driver timeouts/OS hard locks and decreases dependency on Tiled VAE at standard res's - .SetItem("COMFYUI_ENABLE_MIOPEN", "1"); // re-enables "cudnn" in ComfyUI as it's needed for MiOpen to function properly + if (rocmPackageHelper is null) + return env; + + var rocmEnvironment = rocmPackageHelper.BuildLaunchEnvironment(ComfyWindowsRocmProfile.Profile); + + return env.SetItems(rocmEnvironment); } } diff --git a/StabilityMatrix.Core/Models/Packages/StableSwarm.cs b/StabilityMatrix.Core/Models/Packages/StableSwarm.cs index d71e52f5c..b76f2a561 100644 --- a/StabilityMatrix.Core/Models/Packages/StableSwarm.cs +++ b/StabilityMatrix.Core/Models/Packages/StableSwarm.cs @@ -10,9 +10,11 @@ using StabilityMatrix.Core.Models.FileInterfaces; using StabilityMatrix.Core.Models.Packages.Config; using StabilityMatrix.Core.Models.Progress; +using StabilityMatrix.Core.Models.Rocm; using StabilityMatrix.Core.Processes; using StabilityMatrix.Core.Python; using StabilityMatrix.Core.Services; +using StabilityMatrix.Core.Services.Rocm; namespace StabilityMatrix.Core.Models.Packages; @@ -23,7 +25,8 @@ public class StableSwarm( IDownloadService downloadService, IPrerequisiteHelper prerequisiteHelper, IPyInstallationManager pyInstallationManager, - IPipWheelService pipWheelService + IPipWheelService pipWheelService, + IRocmPackageHelper rocmPackageHelper ) : BaseGitPackage( githubApi, @@ -407,6 +410,7 @@ public override async Task RunPackage( } aspEnvVars.Update(settingsManager.Settings.EnvironmentVariables); + aspEnvVars.Update(BuildLinkedComfyLaunchEnvironment()); // Windows ROCm ComfyUI env var pass-through void HandleConsoleOutput(ProcessOutput s) { @@ -563,6 +567,50 @@ await prerequisiteHelper .ConfigureAwait(false); } + /// + /// Resolves the Comfy backend that Swarm is expected to self-launch. + /// + private InstalledPackage? TryResolveLinkedComfyBackend() + { + return settingsManager.Settings.InstalledPackages.FirstOrDefault(x => + x.PackageName is nameof(ComfyUI) or "ComfyUI-Zluda" + ); + } + + /// + /// Builds ROCm launch environment variables for Swarm so they flow through to its self-launched Comfy backend. + /// + private IReadOnlyDictionary BuildLinkedComfyLaunchEnvironment() + { + var comfyPackage = TryResolveLinkedComfyBackend(); + if (comfyPackage is null || !ShouldInjectLinkedComfyRocmEnvironment(comfyPackage)) + { + return new Dictionary(); + } + + return rocmPackageHelper.BuildLaunchEnvironment(ComfyWindowsRocmProfile.Profile); + } + + /// + /// Returns true only when the linked backend is standard ComfyUI on a supported Windows ROCm path. + /// + private bool ShouldInjectLinkedComfyRocmEnvironment(InstalledPackage comfyPackage) + { + if (!Compat.IsWindows || comfyPackage.PackageName != nameof(ComfyUI)) + { + return false; + } + + var compatibility = rocmPackageHelper.GetCompatibility(ComfyWindowsRocmProfile.Profile); + if (!compatibility.IsCompatible) + { + return false; + } + + var selectedTorchIndex = comfyPackage.PreferredTorchIndex ?? TorchIndex.Rocm; + return selectedTorchIndex == TorchIndex.Rocm; + } + private Task SetupModelFoldersConfig(DirectoryPath installDirectory) { var settingsPath = GetSettingsPath(installDirectory); diff --git a/StabilityMatrix.Core/Models/Packages/Wan2GP.cs b/StabilityMatrix.Core/Models/Packages/Wan2GP.cs index 2a00a6262..9150a61a6 100644 --- a/StabilityMatrix.Core/Models/Packages/Wan2GP.cs +++ b/StabilityMatrix.Core/Models/Packages/Wan2GP.cs @@ -6,9 +6,11 @@ using StabilityMatrix.Core.Helper.HardwareInfo; using StabilityMatrix.Core.Models.FileInterfaces; using StabilityMatrix.Core.Models.Progress; +using StabilityMatrix.Core.Models.Rocm; using StabilityMatrix.Core.Processes; using StabilityMatrix.Core.Python; using StabilityMatrix.Core.Services; +using StabilityMatrix.Core.Services.Rocm; namespace StabilityMatrix.Core.Models.Packages; @@ -30,7 +32,8 @@ public class Wan2GP( IDownloadService downloadService, IPrerequisiteHelper prerequisiteHelper, IPyInstallationManager pyInstallationManager, - IPipWheelService pipWheelService + IPipWheelService pipWheelService, + IRocmPackageHelper rocmPackageHelper ) : BaseGitPackage( githubApi, @@ -41,6 +44,12 @@ IPipWheelService pipWheelService pipWheelService ) { + private static readonly RocmPackageProfile WindowsRocmProfile = new() + { + UpgradePackages = true, + PostInstallPipArgs = ["hf-xet", "setuptools", "numpy==1.26.4"], + }; + public override string Name => "Wan2GP"; public override string DisplayName { get; set; } = "Wan2GP"; public override string Author => "deepbeepmeep"; @@ -64,7 +73,7 @@ IPipWheelService pipWheelService public override bool IsCompatible => HardwareHelper.HasNvidiaGpu() - || (Compat.IsWindows ? HardwareHelper.HasWindowsRocmSupportedGpu() : HardwareHelper.HasAmdGpu()); + || (Compat.IsWindows ? HasWindowsRocmSupport() : HardwareHelper.HasAmdGpu()); public override string MainBranch => "main"; public override bool ShouldIgnoreReleases => true; @@ -72,13 +81,13 @@ IPipWheelService pipWheelService public override Dictionary> SharedOutputFolders => new() { [SharedOutputType.Img2Vid] = ["outputs"] }; - // AMD ROCm requires Python 3.11, NVIDIA uses 3.10 + // Wan2GP currently uses Python 3.11 for ROCm and 3.10 for CUDA. public override PyVersion RecommendedPythonVersion => IsAmdRocm ? Python.PyInstallationManager.Python_3_11_13 : Python.PyInstallationManager.Python_3_10_17; public override string Disclaimer => IsAmdRocm && Compat.IsWindows - ? "AMD GPU support on Windows requires RX 7000 series or newer GPU" + ? "Windows AMD ROCm support is experimental. Please report any issues to Stability Matrix first so it can be determined whether the issue is package-specific.\nBecause this setup may not be officially supported by package developers, only contact upstream support for issues clearly caused by the package itself." : string.Empty; /// @@ -86,6 +95,21 @@ IPipWheelService pipWheelService /// private bool IsAmdRocm => GetRecommendedTorchVersion() == TorchIndex.Rocm; + private bool HasWindowsRocmSupport() + { + return GetWindowsRocmCompatibility().IsCompatible; + } + + private RocmCompatibilityResult GetWindowsRocmCompatibility() + { + if (!Compat.IsWindows) + { + return new RocmCompatibilityResult { IsCompatible = false }; + } + + return rocmPackageHelper.GetCompatibility(WindowsRocmProfile); + } + /// /// Python wrapper script that patches logging to also print to stdout/stderr, so /// StabilityMatrix can capture the output. Wan2GP logs through Gradio UI notifications @@ -210,13 +234,7 @@ public override TorchIndex GetRecommendedTorchVersion() { // Check for AMD ROCm support (Windows or Linux) var preferRocm = - ( - Compat.IsWindows - && ( - SettingsManager.Settings.PreferredGpu?.IsWindowsRocmSupportedGpu() - ?? HardwareHelper.HasWindowsRocmSupportedGpu() - ) - ) + (Compat.IsWindows && HasWindowsRocmSupport()) || ( Compat.IsLinux && (SettingsManager.Settings.PreferredGpu?.IsAmd ?? HardwareHelper.PreferRocm()) @@ -256,7 +274,15 @@ public override async Task InstallPackage( if (torchIndex == TorchIndex.Rocm) { - await InstallAmdRocmAsync(venvRunner, progress, onConsoleOutput).ConfigureAwait(false); + await InstallAmdRocmAsync( + venvRunner, + installLocation, + installedPackage, + progress, + onConsoleOutput, + cancellationToken + ) + .ConfigureAwait(false); } else { @@ -359,68 +385,46 @@ await venvRunner private async Task InstallAmdRocmAsync( IPyVenvRunner venvRunner, + string installLocation, + InstalledPackage installedPackage, IProgress? progress, - Action? onConsoleOutput + Action? onConsoleOutput, + CancellationToken cancellationToken ) { - progress?.Report(new ProgressReport(-1f, "Upgrading pip...", isIndeterminate: true)); - await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false); - if (Compat.IsWindows) { - // Windows AMD ROCm - special TheRock wheels - progress?.Report( - new ProgressReport(-1f, "Installing PyTorch ROCm wheels...", isIndeterminate: true) - ); - - // Set environment variable for wheel filename check bypass - venvRunner.UpdateEnvironmentVariables(env => env.SetItem("UV_SKIP_WHEEL_FILENAME_CHECK", "1")); - - // Install PyTorch ROCm wheels from TheRock releases (Python 3.11) - await venvRunner - .PipInstall( - "https://github.com/scottt/rocm-TheRock/releases/download/v6.5.0rc-pytorch-gfx110x/torch-2.7.0a0+rocm_git3f903c3-cp311-cp311-win_amd64.whl", - onConsoleOutput + await rocmPackageHelper + .InstallWindowsNativePackageAsync( + venvRunner, + installLocation, + installedPackage, + WindowsRocmProfile, + progress, + onConsoleOutput, + cancellationToken ) .ConfigureAwait(false); - await venvRunner - .PipInstall( - "https://github.com/scottt/rocm-TheRock/releases/download/v6.5.0rc-pytorch-gfx110x/torchaudio-2.7.0a0+52638ef-cp311-cp311-win_amd64.whl", - onConsoleOutput - ) - .ConfigureAwait(false); + return; + } - await venvRunner - .PipInstall( - "https://github.com/scottt/rocm-TheRock/releases/download/v6.5.0rc-pytorch-gfx110x/torchvision-0.22.0+9eb57cd-cp311-cp311-win_amd64.whl", - onConsoleOutput - ) - .ConfigureAwait(false); + progress?.Report(new ProgressReport(-1f, "Upgrading pip...", isIndeterminate: true)); + await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false); - // Install requirements directly using -r flag (handles @ URL syntax properly) - progress?.Report(new ProgressReport(-1f, "Installing requirements...", isIndeterminate: true)); - await venvRunner.PipInstall("-r requirements.txt", onConsoleOutput).ConfigureAwait(false); - } - else - { - // Linux AMD ROCm - standard PyTorch ROCm - // Install requirements directly using -r flag (handles @ URL syntax properly) - progress?.Report(new ProgressReport(-1f, "Installing requirements...", isIndeterminate: true)); - await venvRunner.PipInstall("-r requirements.txt", onConsoleOutput).ConfigureAwait(false); - - // Install torch with ROCm index (force reinstall to ensure correct version) - progress?.Report(new ProgressReport(-1f, "Installing PyTorch ROCm...", isIndeterminate: true)); - var torchArgs = new PipInstallArgs() - .WithTorch("==2.7.0") - .WithTorchVision("==0.22.0") - .WithTorchAudio("==2.7.0") - .WithTorchExtraIndex("rocm6.3") - .AddArg("--force-reinstall") - .AddArg("--no-deps"); - - await venvRunner.PipInstall(torchArgs, onConsoleOutput).ConfigureAwait(false); - } + progress?.Report(new ProgressReport(-1f, "Installing requirements...", isIndeterminate: true)); + await venvRunner.PipInstall("-r requirements.txt", onConsoleOutput).ConfigureAwait(false); + + progress?.Report(new ProgressReport(-1f, "Installing PyTorch ROCm...", isIndeterminate: true)); + var torchArgs = new PipInstallArgs() + .WithTorch() + .WithTorchVision() + .WithTorchAudio() + .WithTorchExtraIndex("rocm7.2") + .AddArg("--force-reinstall") + .AddArg("--no-deps"); + + await venvRunner.PipInstall(torchArgs, onConsoleOutput).ConfigureAwait(false); // Install additional packages await venvRunner.PipInstall("hf-xet setuptools numpy==1.26.4", onConsoleOutput).ConfigureAwait(false); @@ -437,6 +441,12 @@ public override async Task RunPackage( await SetupVenv(installLocation, pythonVersion: PyVersion.Parse(installedPackage.PythonVersion)) .ConfigureAwait(false); + if (Compat.IsWindows && HasWindowsRocmSupport()) + { + var rocmEnvironment = rocmPackageHelper.BuildLaunchEnvironment(WindowsRocmProfile); + VenvRunner.UpdateEnvironmentVariables(env => env.SetItems(rocmEnvironment)); + } + // Fix for distutils compatibility issue with Python 3.10 and setuptools VenvRunner.UpdateEnvironmentVariables(env => env.SetItem("SETUPTOOLS_USE_DISTUTILS", "stdlib")); diff --git a/StabilityMatrix.Core/Models/Rocm/ComfyWindowsRocmProfile.cs b/StabilityMatrix.Core/Models/Rocm/ComfyWindowsRocmProfile.cs new file mode 100644 index 000000000..ea4035ec6 --- /dev/null +++ b/StabilityMatrix.Core/Models/Rocm/ComfyWindowsRocmProfile.cs @@ -0,0 +1,23 @@ +namespace StabilityMatrix.Core.Models.Rocm; + +/// +/// Shared Windows ROCm profile for Comfy backends launched either directly by Stability Matrix or indirectly via SwarmUI. +/// +public static class ComfyWindowsRocmProfile +{ + public static RocmPackageProfile Profile { get; } = + new() + { + ExtraInstallPipArgs = ["numpy<2"], + PostInstallPipArgs = ["typing-extensions>=4.15.0"], + UpgradePackages = true, + ExtraEnvironmentFactory = BuildEnvironment, + }; + + private static IReadOnlyDictionary BuildEnvironment(RocmRuntimeContext runtimeContext) + { + return WindowsRocmSupport.IsModernArchitecture(runtimeContext.RuntimeGfxArch) + ? new Dictionary { ["COMFYUI_ENABLE_MIOPEN"] = "1" } + : new Dictionary(); + } +} diff --git a/StabilityMatrix.Core/Models/Rocm/RocmCompatibilityResult.cs b/StabilityMatrix.Core/Models/Rocm/RocmCompatibilityResult.cs new file mode 100644 index 000000000..401f3ada4 --- /dev/null +++ b/StabilityMatrix.Core/Models/Rocm/RocmCompatibilityResult.cs @@ -0,0 +1,17 @@ +using StabilityMatrix.Core.Helper.HardwareInfo; + +namespace StabilityMatrix.Core.Models.Rocm; + +/// +/// Describes whether a package/profile is currently compatible with ROCm on the active machine. +/// +public class RocmCompatibilityResult +{ + public bool IsCompatible { get; init; } + + public string? FailureReason { get; init; } + + public GpuInfo? SelectedGpu { get; init; } + + public string? ResolvedGfxArch { get; init; } +} diff --git a/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs b/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs new file mode 100644 index 000000000..88216ec12 --- /dev/null +++ b/StabilityMatrix.Core/Models/Rocm/RocmEnvironmentOptions.cs @@ -0,0 +1,52 @@ +namespace StabilityMatrix.Core.Models.Rocm; + +/// +/// Controls how ROCm helper defaults, package-specific variables, and user overrides are layered at launch. +/// +public class RocmEnvironmentOptions +{ + /// + /// When true, package-specific environment additions may be merged on top of helper defaults. + /// + public bool IncludePackageOverrides { get; init; } = true; + + /// + /// When true, user-defined Stability Matrix environment variables may override helper/package defaults last. + /// + public bool IncludeUserOverrides { get; init; } = true; + + /// + /// When set, overrides the default PyTorch allocator tuning string added by the ROCm helper. + /// + public string? PyTorchAllocConf { get; init; } = "max_split_size_mb:512,garbage_collection_threshold:0.8"; + + /// + /// When set, configures MIOpen find mode for helper-managed ROCm defaults. + /// + public string? MiopenFindMode { get; init; } = "2"; + + /// + /// When set, configures MIOpen search cutoff for helper-managed ROCm defaults. + /// + public string? MiopenSearchCutoff { get; init; } = "1"; + + /// + /// When set, configures MIOpen find enforcement behavior for helper-managed ROCm defaults. + /// + public string? MiopenFindEnforce { get; init; } = "1"; + + /// + /// When set, controls whether AMD Triton-backed flash attention is enabled by helper defaults. + /// + public string? FlashAttentionTritonAmdEnable { get; init; } = "TRUE"; + + /// + /// When true, helper-managed defaults will enable ROCm AOTriton on modern Windows ROCm architectures. + /// + public bool ApplyAotritonExperimental { get; init; } = true; + + /// + /// When true, helper-managed defaults will force math SDP on legacy ROCm architectures. + /// + public bool ApplyLegacySdpFallback { get; init; } = true; +} diff --git a/StabilityMatrix.Core/Models/Rocm/RocmInstallContext.cs b/StabilityMatrix.Core/Models/Rocm/RocmInstallContext.cs new file mode 100644 index 000000000..daaeb8499 --- /dev/null +++ b/StabilityMatrix.Core/Models/Rocm/RocmInstallContext.cs @@ -0,0 +1,11 @@ +namespace StabilityMatrix.Core.Models.Rocm; + +/// +/// Captures ROCm-related facts needed during package install or update flows. +/// +public class RocmInstallContext +{ + public string? RuntimeGfxArch { get; init; } + + public string? MultiArchDeviceExtra { get; init; } +} diff --git a/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs b/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs new file mode 100644 index 000000000..cd28db7a8 --- /dev/null +++ b/StabilityMatrix.Core/Models/Rocm/RocmPackageProfile.cs @@ -0,0 +1,53 @@ +using StabilityMatrix.Core.Models.Progress; + +namespace StabilityMatrix.Core.Models.Rocm; + +/// +/// Declares what a package expects from the ROCm helper. +/// Package classes should describe intent here rather than hardcoding ROCm decisions inline. +/// +public class RocmPackageProfile +{ + /// + /// Requirement files to install after helper-owned ROCm torch installation completes. + /// + public IEnumerable RequirementsFilePaths { get; init; } = ["requirements.txt"]; + + /// + /// Package requirement entries to exclude because the helper installs them from the ROCm multi-arch feed. + /// + public string RequirementsExcludePattern { get; init; } = @"(torch(vision|audio)?|xformers)([^a-z].*)?"; + + /// + /// Extra package-specific pip arguments to include when installing requirements before the helper-managed torch step. + /// + public IEnumerable ExtraInstallPipArgs { get; init; } = []; + + /// + /// Extra package-specific pip arguments to install after requirements and torch are complete. + /// + public IEnumerable PostInstallPipArgs { get; init; } = []; + + /// + /// When true, helper-managed requirements installs should use --upgrade. + /// + public bool UpgradePackages { get; init; } + + /// + /// When true, helper-managed torch installs should force reinstall the selected ROCm wheel set. + /// + public bool ForceReinstallTorch { get; init; } = true; + + /// + /// Optional callback for package-specific environment variables derived from a resolved ROCm context. + /// + public Func< + RocmRuntimeContext, + IReadOnlyDictionary + >? ExtraEnvironmentFactory { get; init; } + + /// + /// Controls whether package-specific environment variables should be layered on top of helper defaults. + /// + public RocmEnvironmentOptions EnvironmentOptions { get; init; } = new(); +} diff --git a/StabilityMatrix.Core/Models/Rocm/RocmRuntimeContext.cs b/StabilityMatrix.Core/Models/Rocm/RocmRuntimeContext.cs new file mode 100644 index 000000000..1fdda7914 --- /dev/null +++ b/StabilityMatrix.Core/Models/Rocm/RocmRuntimeContext.cs @@ -0,0 +1,18 @@ +using StabilityMatrix.Core.Helper.HardwareInfo; + +namespace StabilityMatrix.Core.Models.Rocm; + +/// +/// Captures resolved ROCm facts for a package launch or runtime decision. +/// This model is intended to separate hardware/runtime facts from package policy. +/// +public class RocmRuntimeContext +{ + public bool IsSupported { get; init; } + + public string? FailureReason { get; init; } + + public GpuInfo? SelectedGpu { get; init; } + + public string? RuntimeGfxArch { get; init; } +} diff --git a/StabilityMatrix.Core/Models/Rocm/WindowsRocmSupport.cs b/StabilityMatrix.Core/Models/Rocm/WindowsRocmSupport.cs new file mode 100644 index 000000000..b6533cab1 --- /dev/null +++ b/StabilityMatrix.Core/Models/Rocm/WindowsRocmSupport.cs @@ -0,0 +1,84 @@ +using StabilityMatrix.Core.Helper.HardwareInfo; + +namespace StabilityMatrix.Core.Models.Rocm; + +/// +/// Centralizes Windows ROCm support and architecture policy so hardware detection, package selection, +/// installation, and shared launch decisions use the same support map. +/// +public static class WindowsRocmSupport +{ + public const string MultiArchPythonPackageIndexUrl = + "https://rocm.nightlies.amd.com/whl-staging-multi-arch/"; + + // Used to exclude modern gfxarches from AOTriton activation EnVar as AOTriton does not currently support them. + // This is a temporary measure until AOTriton adds support for these architectures. + private static readonly HashSet AotritonExperimentalExcludedArchitectures = + [ + "gfx1152", + "gfx1153", + ]; + + public static bool IsSupportedGpu(GpuInfo? gpu) + { + if (gpu is null || !gpu.IsAmd || string.IsNullOrWhiteSpace(gpu.Name)) + return false; + + return IsSupportedArchitecture(gpu.GetAmdGfxArch()); + } + + public static bool IsSupportedArchitecture(string? gfxArch) + { + return TryGetCanonicalArchitecture(gfxArch) is not null; + } + + public static bool IsModernArchitecture(string? gfxArch) + { + return gfxArch?.StartsWith("gfx110", StringComparison.OrdinalIgnoreCase) == true + || gfxArch?.StartsWith("gfx115", StringComparison.OrdinalIgnoreCase) == true + || gfxArch?.StartsWith("gfx120", StringComparison.OrdinalIgnoreCase) == true; + } + + public static bool SupportsAotritonExperimental(string? gfxArch) + { + var canonicalArch = TryGetCanonicalArchitecture(gfxArch); + return canonicalArch is not null + && IsModernArchitecture(canonicalArch) + && !AotritonExperimentalExcludedArchitectures.Contains(canonicalArch); + } + + public static bool IsLegacyArchitecture(string? gfxArch) + { + return IsSupportedArchitecture(gfxArch) && !IsModernArchitecture(gfxArch); + } + + public static bool PreferLegacyAttentionFallback(string? gfxArch) + { + return IsLegacyArchitecture(gfxArch); + } + + public static string? TryGetCanonicalArchitecture(string? gfxArch) + { + if (string.IsNullOrWhiteSpace(gfxArch)) + return null; + + var normalizedArch = gfxArch.ToLowerInvariant(); + + return normalizedArch switch + { + "gfx900" or "gfx906" or "gfx1150" or "gfx1151" or "gfx1152" or "gfx1153" => normalizedArch, + var s + when s.StartsWith("gfx101", StringComparison.Ordinal) + || s.StartsWith("gfx103", StringComparison.Ordinal) + || s.StartsWith("gfx110", StringComparison.Ordinal) + || s.StartsWith("gfx120", StringComparison.Ordinal) => normalizedArch, + _ => null, + }; + } + + public static string? TryGetMultiArchDeviceExtra(string? gfxArch) + { + var canonicalArch = TryGetCanonicalArchitecture(gfxArch); + return canonicalArch is null ? null : $"device-{canonicalArch}"; + } +} diff --git a/StabilityMatrix.Core/Python/IPyVenvRunner.cs b/StabilityMatrix.Core/Python/IPyVenvRunner.cs index 6c1600b77..1b36418b3 100644 --- a/StabilityMatrix.Core/Python/IPyVenvRunner.cs +++ b/StabilityMatrix.Core/Python/IPyVenvRunner.cs @@ -91,7 +91,11 @@ Task Setup( /// /// Run a pip index command, return result as PipIndexResult. /// - Task PipIndex(string packageName, string? indexUrl = null); + Task PipIndex( + string packageName, + string? indexUrl = null, + bool includePrerelease = false + ); /// /// Run a custom install command. Waits for the process to exit. diff --git a/StabilityMatrix.Core/Python/PyVenvRunner.cs b/StabilityMatrix.Core/Python/PyVenvRunner.cs index 85fa06973..ff283fd99 100644 --- a/StabilityMatrix.Core/Python/PyVenvRunner.cs +++ b/StabilityMatrix.Core/Python/PyVenvRunner.cs @@ -332,10 +332,9 @@ public async Task> PipList() StringSplitOptions.TrimEntries | StringSplitOptions.RemoveEmptyEntries ) .Select(line => line.Trim()) - .FirstOrDefault( - line => - line.StartsWith("[", StringComparison.OrdinalIgnoreCase) - && line.EndsWith("]", StringComparison.OrdinalIgnoreCase) + .FirstOrDefault(line => + line.StartsWith("[", StringComparison.OrdinalIgnoreCase) + && line.EndsWith("]", StringComparison.OrdinalIgnoreCase) ); if (jsonLine is null) @@ -370,6 +369,17 @@ public async Task> PipList() ) .ConfigureAwait(false); + var packageNotFound = + result.StandardOutput?.Contains("Package(s) not found", StringComparison.OrdinalIgnoreCase) + == true + || result.StandardError?.Contains("Package(s) not found", StringComparison.OrdinalIgnoreCase) + == true; + + if (packageNotFound) + { + return null; + } + // Check return code if (result.ExitCode != 0) { @@ -378,9 +388,11 @@ public async Task> PipList() ); } - if (result.StandardOutput!.StartsWith("WARNING: Package(s) not found:")) + if (string.IsNullOrWhiteSpace(result.StandardOutput)) { - return null; + throw new ProcessException( + $"pip show returned no output for package '{packageName}': {result.StandardError}" + ); } return PipShowResult.Parse(result.StandardOutput); @@ -389,7 +401,11 @@ public async Task> PipList() /// /// Run a pip index command, return result as PipIndexResult. /// - public async Task PipIndex(string packageName, string? indexUrl = null) + public async Task PipIndex( + string packageName, + string? indexUrl = null, + bool includePrerelease = false + ) { if (!File.Exists(PipPath)) { @@ -413,10 +429,30 @@ public async Task> PipList() args = args.AddKeyedArgs("--index-url", ["--index-url", indexUrl]); } + if (includePrerelease) + { + args = args.AddArg("--pre"); + } + var result = await ProcessRunner .GetProcessResultAsync(PythonPath, args, WorkingDirectory?.FullPath, EnvironmentVariables) .ConfigureAwait(false); + var noMatchingDistribution = + result.StandardOutput?.Contains( + "No matching distribution found", + StringComparison.OrdinalIgnoreCase + ) == true + || result.StandardError?.Contains( + "No matching distribution found", + StringComparison.OrdinalIgnoreCase + ) == true; + + if (noMatchingDistribution || string.IsNullOrWhiteSpace(result.StandardOutput)) + { + return null; + } + // Check return code if (result.ExitCode != 0) { @@ -425,16 +461,6 @@ public async Task> PipList() ); } - if ( - string.IsNullOrEmpty(result.StandardOutput) - || result - .StandardOutput!.SplitLines() - .Any(l => l.StartsWith("ERROR: No matching distribution found")) - ) - { - return null; - } - return PipIndexResult.Parse(result.StandardOutput); } @@ -617,11 +643,11 @@ public void RunDetached( { // ReSharper disable once StringLiteralTypo var code = $""" - from importlib.metadata import entry_points - - results = entry_points(group='console_scripts', name='{entryPointName}') - print(tuple(results)[0].value, end='') - """; + from importlib.metadata import entry_points + + results = entry_points(group='console_scripts', name='{entryPointName}') + print(tuple(results)[0].value, end='') + """; var result = await Run($"-c \"{code}\"").ConfigureAwait(false); if (result.ExitCode == 0 && !string.IsNullOrWhiteSpace(result.StandardOutput)) diff --git a/StabilityMatrix.Core/Python/UvVenvRunner.cs b/StabilityMatrix.Core/Python/UvVenvRunner.cs index 53a295ab5..6fa69fd6e 100644 --- a/StabilityMatrix.Core/Python/UvVenvRunner.cs +++ b/StabilityMatrix.Core/Python/UvVenvRunner.cs @@ -386,6 +386,17 @@ public async Task> PipList() ) .ConfigureAwait(false); + var packageNotFound = + result.StandardOutput?.Contains("Package(s) not found", StringComparison.OrdinalIgnoreCase) + == true + || result.StandardError?.Contains("Package(s) not found", StringComparison.OrdinalIgnoreCase) + == true; + + if (packageNotFound) + { + return null; + } + // Check return code if (result.ExitCode != 0) { @@ -394,9 +405,11 @@ public async Task> PipList() ); } - if (result.StandardOutput!.StartsWith("WARNING: Package(s) not found:")) + if (string.IsNullOrWhiteSpace(result.StandardOutput)) { - return null; + throw new ProcessException( + $"pip show returned no output for package '{packageName}': {result.StandardError}" + ); } return PipShowResult.Parse(result.StandardOutput); @@ -405,7 +418,11 @@ public async Task> PipList() /// /// Run a pip index command, return result as PipIndexResult. /// - public async Task PipIndex(string packageName, string? indexUrl = null) + public async Task PipIndex( + string packageName, + string? indexUrl = null, + bool includePrerelease = false + ) { if (!File.Exists(PipPath)) { @@ -429,10 +446,30 @@ public async Task> PipList() args = args.AddKeyedArgs("--index-url", ["--index-url", indexUrl]); } + if (includePrerelease) + { + args = args.AddArg("--pre"); + } + var result = await ProcessRunner .GetProcessResultAsync(PythonPath, args, WorkingDirectory?.FullPath, EnvironmentVariables) .ConfigureAwait(false); + var noMatchingDistribution = + result.StandardOutput?.Contains( + "No matching distribution found", + StringComparison.OrdinalIgnoreCase + ) == true + || result.StandardError?.Contains( + "No matching distribution found", + StringComparison.OrdinalIgnoreCase + ) == true; + + if (noMatchingDistribution || string.IsNullOrWhiteSpace(result.StandardOutput)) + { + return null; + } + // Check return code if (result.ExitCode != 0) { @@ -441,16 +478,6 @@ public async Task> PipList() ); } - if ( - string.IsNullOrEmpty(result.StandardOutput) - || result - .StandardOutput!.SplitLines() - .Any(l => l.StartsWith("ERROR: No matching distribution found")) - ) - { - return null; - } - return PipIndexResult.Parse(result.StandardOutput); } diff --git a/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs b/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs new file mode 100644 index 000000000..167bac9ac --- /dev/null +++ b/StabilityMatrix.Core/Services/Rocm/IRocmPackageHelper.cs @@ -0,0 +1,52 @@ +using StabilityMatrix.Core.Models; +using StabilityMatrix.Core.Models.Progress; +using StabilityMatrix.Core.Models.Rocm; +using StabilityMatrix.Core.Processes; +using StabilityMatrix.Core.Python; + +namespace StabilityMatrix.Core.Services.Rocm; + +/// +/// Defines the ROCm helper surface area shared by ROCm-capable packages. +/// +public interface IRocmPackageHelper +{ + /// + /// Evaluates whether the current machine and package profile are compatible with ROCm. + /// + RocmCompatibilityResult GetCompatibility(RocmPackageProfile profile); + + /// + /// Builds a launch-time environment dictionary from resolved ROCm runtime data. + /// + IReadOnlyDictionary BuildLaunchEnvironment(RocmPackageProfile profile); + + /// + /// Returns shared Windows ROCm launch notice lines for helper-managed packages. + /// + IReadOnlyList GetWindowsLaunchNoticeLines(); + + /// + /// Ensures a usable Windows ROCm SDK devel package is installed from the ROCm multi-arch index, + /// preferring the same nightly build date as the installed torch build and falling back to the latest available build. + /// + Task EnsureWindowsSdkDevelAsync( + IPyVenvRunner venvRunner, + IProgress? progress = null, + Action? onConsoleOutput = null, + CancellationToken cancellationToken = default + ); + + /// + /// Performs the Windows-native ROCm install flow for a package using helper-resolved multi-arch device extras. + /// + Task InstallWindowsNativePackageAsync( + IPyVenvRunner venvRunner, + string installLocation, + InstalledPackage installedPackage, + RocmPackageProfile profile, + IProgress? progress = null, + Action? onConsoleOutput = null, + CancellationToken cancellationToken = default + ); +} diff --git a/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs new file mode 100644 index 000000000..ba4c1e241 --- /dev/null +++ b/StabilityMatrix.Core/Services/Rocm/RocmPackageHelper.cs @@ -0,0 +1,702 @@ +using System.Text.Json; +using Injectio.Attributes; +using NLog; +using StabilityMatrix.Core.Exceptions; +using StabilityMatrix.Core.Helper; +using StabilityMatrix.Core.Helper.HardwareInfo; +using StabilityMatrix.Core.Models; +using StabilityMatrix.Core.Models.FileInterfaces; +using StabilityMatrix.Core.Models.Progress; +using StabilityMatrix.Core.Models.Rocm; +using StabilityMatrix.Core.Processes; +using StabilityMatrix.Core.Python; +using StabilityMatrix.Core.Services; + +namespace StabilityMatrix.Core.Services.Rocm; + +/// +/// Provides the shared ROCm helper surface area used by ROCm-capable packages. +/// +[RegisterSingleton] +public class RocmPackageHelper(ISettingsManager settingsManager) : IRocmPackageHelper +{ + private static readonly Logger Logger = LogManager.GetCurrentClassLogger(); + private static readonly StringComparer EnvComparer = StringComparer.OrdinalIgnoreCase; + private const string RocmSdkDevelPackageName = "rocm-sdk-devel"; + private static readonly string[] WindowsLaunchNoticeLines = + [ + "Stability Matrix Windows ROCm Notice: Windows AMD ROCm support is experimental. Please report any issues to Stability Matrix first so it can be determined whether the issue is package-specific.", + "Because this setup may not be officially supported by package developers, only contact upstream support for issues clearly caused by the package itself.", + ]; + + /// + /// Evaluates the current Windows machine state for the given package profile and returns the resolved ROCm compatibility result. + /// + public RocmCompatibilityResult GetCompatibility(RocmPackageProfile profile) + { + _ = profile; + return BuildCompatibilityResult(profile); + } + + /// + /// Resolves launch-time ROCm runtime details from the current Windows machine state. + /// This is used to build helper-managed environment variables for package launch. + /// + private RocmRuntimeContext ResolveRuntimeContext(RocmPackageProfile profile) + { + _ = profile; + + var state = ResolveWindowsMachineState(); + if (!state.IsCompatible) + { + return new RocmRuntimeContext + { + IsSupported = false, + FailureReason = state.FailureReason, + SelectedGpu = state.SelectedGpu, + RuntimeGfxArch = state.RuntimeGfxArch, + }; + } + + return new RocmRuntimeContext + { + IsSupported = true, + SelectedGpu = state.SelectedGpu, + RuntimeGfxArch = state.RuntimeGfxArch, + }; + } + + /// + /// Resolves install-time ROCm package selection details from the current Windows machine state. + /// This includes the canonical runtime GFX architecture and the matching multi-arch device extra. + /// + private RocmInstallContext ResolveInstallContext(RocmPackageProfile profile) + { + _ = profile; + + var state = ResolveWindowsMachineState(); + + return new RocmInstallContext + { + RuntimeGfxArch = state.RuntimeGfxArch, + MultiArchDeviceExtra = state.MultiArchDeviceExtra, + }; + } + + /// + /// Builds the final launch environment for a ROCm-capable package by combining helper defaults, + /// package-specific environment values, and optional user overrides. + /// + public IReadOnlyDictionary BuildLaunchEnvironment(RocmPackageProfile profile) + { + var runtimeContext = ResolveRuntimeContext(profile); + + if (!runtimeContext.IsSupported) + return new Dictionary(); + + var helperEnvironment = BuildHelperLaunchEnvironment(runtimeContext, profile); + var packageEnvironment = + profile.ExtraEnvironmentFactory?.Invoke(runtimeContext) ?? new Dictionary(); + + var mergedEnvironment = MergeLaunchEnvironment( + helperEnvironment, + packageEnvironment, + profile.EnvironmentOptions + ); + + return mergedEnvironment; + } + + /// + /// Returns the shared informational notice lines shown when launching Windows ROCm packages. + /// + public IReadOnlyList GetWindowsLaunchNoticeLines() + { + return WindowsLaunchNoticeLines; + } + + /// + /// Ensures rocm-sdk-devel is installed from the ROCm multi-arch index. + /// It prefers a build whose nightly date matches the installed ROCm torch build and falls back to the latest available build when no exact match is available. + /// + public async Task EnsureWindowsSdkDevelAsync( + IPyVenvRunner venvRunner, + IProgress? progress = null, + Action? onConsoleOutput = null, + CancellationToken cancellationToken = default + ) + { + var torchInfo = await venvRunner.PipShow("torch").ConfigureAwait(false); + if (torchInfo is null) + { + throw new InvalidOperationException( + "torch is not installed in this environment. Install the Windows ROCm torch build first." + ); + } + + if (!IsUsableWindowsNativeTorchBuild(torchInfo.Version, null)) + { + throw new InvalidOperationException( + $"Installed torch is not a usable Windows ROCm build (detected version: {torchInfo.Version})." + ); + } + + var nightlyBuildDateToken = TryGetNightlyBuildDateToken(torchInfo.Version); + var installedRocmSdkDevel = await venvRunner.PipShow(RocmSdkDevelPackageName).ConfigureAwait(false); + if ( + !string.IsNullOrWhiteSpace(nightlyBuildDateToken) + && HasNightlyBuildDateToken(installedRocmSdkDevel?.Version, nightlyBuildDateToken) + ) + { + return; + } + + var indexResult = await venvRunner + .PipIndex( + RocmSdkDevelPackageName, + WindowsRocmSupport.MultiArchPythonPackageIndexUrl, + includePrerelease: true + ) + .ConfigureAwait(false); + + var latestVersion = indexResult?.AvailableVersions.FirstOrDefault(); + var matchingVersion = string.IsNullOrWhiteSpace(nightlyBuildDateToken) + ? null + : indexResult?.AvailableVersions.FirstOrDefault(version => + HasNightlyBuildDateToken(version, nightlyBuildDateToken) + ); + var versionToInstall = matchingVersion ?? latestVersion; + + if (string.IsNullOrWhiteSpace(versionToInstall)) + { + throw new InvalidOperationException( + $"No {RocmSdkDevelPackageName} builds were found on the ROCm multi-arch index." + ); + } + + if (!string.IsNullOrWhiteSpace(matchingVersion)) + { + progress?.Report( + new ProgressReport( + -1f, + $"Installing {RocmSdkDevelPackageName} {matchingVersion} for Windows ROCm...", + isIndeterminate: true + ) + ); + } + else + { + progress?.Report( + new ProgressReport( + -1f, + $"Falling back to latest available {RocmSdkDevelPackageName} build {versionToInstall} for Windows ROCm...", + isIndeterminate: true + ) + ); + } + + await venvRunner + .PipInstall( + new PipInstallArgs() + .AddArg("--upgrade") + .AddKeyedArgs( + "--index-url", + ["--index-url", WindowsRocmSupport.MultiArchPythonPackageIndexUrl] + ) + .AddArg($"{RocmSdkDevelPackageName}=={versionToInstall}"), + onConsoleOutput + ) + .ConfigureAwait(false); + + _ = cancellationToken; + } + + /// + /// Performs the shared Windows-native ROCm install flow for helper-managed packages. + /// This installs package requirements, the ROCm torch wheel set from the multi-arch index, + /// and then verifies that the resulting torch installation reports usable ROCm metadata. + /// + public async Task InstallWindowsNativePackageAsync( + IPyVenvRunner venvRunner, + string installLocation, + InstalledPackage installedPackage, + RocmPackageProfile profile, + IProgress? progress = null, + Action? onConsoleOutput = null, + CancellationToken cancellationToken = default + ) + { + var compatibility = GetCompatibility(profile); + if (!compatibility.IsCompatible) + { + throw new InvalidOperationException( + compatibility.FailureReason + ?? "Windows ROCm installation is not supported for the current machine." + ); + } + + var installContext = ResolveInstallContext(profile); + + var multiArchDeviceExtra = installContext.MultiArchDeviceExtra; + + if (string.IsNullOrWhiteSpace(multiArchDeviceExtra)) + { + throw new ApplicationException( + $"No Windows ROCm multi-arch device extra is available for '{installContext.RuntimeGfxArch ?? "unknown"}'." + ); + } + + progress?.Report(new ProgressReport(-1f, "Upgrading pip...", isIndeterminate: true)); + await venvRunner.PipInstall("--upgrade pip wheel", onConsoleOutput).ConfigureAwait(false); + + progress?.Report( + new ProgressReport(-1f, "Installing package requirements...", isIndeterminate: true) + ); + + var requirementsPipArgs = new PipInstallArgs([.. profile.ExtraInstallPipArgs]); + if (profile.UpgradePackages) + { + requirementsPipArgs = requirementsPipArgs.AddArg("--upgrade"); + } + + foreach (var relativePath in profile.RequirementsFilePaths) + { + var requirementsFile = new FilePath(venvRunner.WorkingDirectory ?? installLocation, relativePath); + if (!requirementsFile.Exists) + continue; + + var requirementsContent = await requirementsFile + .ReadAllTextAsync(cancellationToken) + .ConfigureAwait(false); + + requirementsPipArgs = requirementsPipArgs.WithParsedFromRequirementsTxt( + requirementsContent, + profile.RequirementsExcludePattern + ); + } + + if (installedPackage.PipOverrides != null) + { + requirementsPipArgs = requirementsPipArgs.WithUserOverrides(installedPackage.PipOverrides); + } + + await venvRunner.PipInstall(requirementsPipArgs, onConsoleOutput).ConfigureAwait(false); + + progress?.Report(new ProgressReport(-1f, "Installing ROCm torch...", isIndeterminate: true)); + + var torchArgs = new PipInstallArgs() + .AddArg("--upgrade") + .AddKeyedArgs("--index-url", ["--index-url", WindowsRocmSupport.MultiArchPythonPackageIndexUrl]) + .AddArgs( + new Argument($"torch[{multiArchDeviceExtra}]"), + new Argument($"torchvision[{multiArchDeviceExtra}]"), + new Argument("torchaudio") + ); + + if (profile.ForceReinstallTorch) + { + torchArgs = torchArgs.AddArg("--force-reinstall"); + } + + if (installedPackage.PipOverrides != null) + { + torchArgs = torchArgs.WithUserOverrides(installedPackage.PipOverrides); + } + + await venvRunner.PipInstall(torchArgs, onConsoleOutput).ConfigureAwait(false); + if (profile.PostInstallPipArgs.Any()) + { + var postInstallPipArgs = new PipInstallArgs([.. profile.PostInstallPipArgs]); + if (installedPackage.PipOverrides != null) + { + postInstallPipArgs = postInstallPipArgs.WithUserOverrides(installedPackage.PipOverrides); + } + + await venvRunner.PipInstall(postInstallPipArgs, onConsoleOutput).ConfigureAwait(false); + } + + await VerifyWindowsNativeTorchInstallAsync(venvRunner, onConsoleOutput, cancellationToken) + .ConfigureAwait(false); + } + + /// + /// Builds a compatibility result from the current machine state and package profile. + /// This keeps the first ROCm helper slice focused on hardware capability and GPU selection only. + /// + private RocmCompatibilityResult BuildCompatibilityResult(RocmPackageProfile profile) + { + _ = profile; + var state = ResolveWindowsMachineState(); + + return new RocmCompatibilityResult + { + IsCompatible = state.IsCompatible, + FailureReason = state.FailureReason, + SelectedGpu = state.SelectedGpu, + ResolvedGfxArch = state.RuntimeGfxArch, + }; + } + + private ResolvedWindowsRocmState ResolveWindowsMachineState() + { + var amdGpus = GetAmdGpuCandidates(forceRefresh: true).ToList(); + if (amdGpus.Count == 0) + { + return new ResolvedWindowsRocmState + { + IsCompatible = false, + FailureReason = "No AMD GPU was detected for ROCm evaluation.", + }; + } + + var supportedAmdGpus = amdGpus.Where(IsSupportedWindowsRocmGpu).ToList(); + if (supportedAmdGpus.Count == 0) + { + return new ResolvedWindowsRocmState + { + IsCompatible = false, + FailureReason = GetUnsupportedGpuReason(amdGpus), + }; + } + + var selectedGpu = + TryResolvePreferredAmdGpu(supportedAmdGpus, settingsManager.Settings.PreferredGpu) + ?? supportedAmdGpus.First(); + var runtimeGfxArch = + WindowsRocmSupport.TryGetCanonicalArchitecture(selectedGpu.GetAmdGfxArch()) + ?? GetSupportedFallbackGfxArch(supportedAmdGpus); + var isCompatible = !string.IsNullOrWhiteSpace(runtimeGfxArch); + + return new ResolvedWindowsRocmState + { + IsCompatible = isCompatible, + FailureReason = isCompatible + ? null + : "No supported AMD GFX architecture could be resolved for ROCm.", + SelectedGpu = selectedGpu, + RuntimeGfxArch = runtimeGfxArch, + MultiArchDeviceExtra = WindowsRocmSupport.TryGetMultiArchDeviceExtra(runtimeGfxArch), + }; + } + + /// + /// Returns AMD GPUs from Stability Matrix's internal hardware model. + /// This is the canonical GPU source for the ROCm helper and intentionally avoids package-local probing. + /// + private static IReadOnlyList GetAmdGpuCandidates(bool forceRefresh = false) + { + return HardwareHelper.IterGpuInfo(forceRefresh).Where(gpu => gpu.IsAmd).ToList(); + } + + /// + /// Resolves the preferred AMD GPU when the configured preference is still present in the current hardware list. + /// + private static GpuInfo? TryResolvePreferredAmdGpu( + IEnumerable availableGpus, + GpuInfo? preferredGpu + ) + { + if (preferredGpu is null || !preferredGpu.IsAmd) + return null; + + var preferredMatch = availableGpus.FirstOrDefault(gpu => gpu.Equals(preferredGpu)); + if (preferredMatch is not null) + return preferredMatch; + + if (!string.IsNullOrWhiteSpace(preferredGpu.Name)) + { + Logger.Info( + "Preferred GPU {PreferredGpuName} was ignored for ROCm detection because it is not present in current hardware enumeration.", + preferredGpu.Name + ); + } + + return null; + } + + /// + /// Resolves the preferred AMD GFX architecture when the configured GPU is supported and currently present. + /// + private static string? TryResolvePreferredAmdGfxArch( + IEnumerable availableGpus, + GpuInfo? preferredGpu + ) + { + var resolvedPreferredGpu = TryResolvePreferredAmdGpu(availableGpus, preferredGpu); + return resolvedPreferredGpu is not null && IsSupportedWindowsRocmGpu(resolvedPreferredGpu) + ? WindowsRocmSupport.TryGetCanonicalArchitecture(resolvedPreferredGpu.GetAmdGfxArch()) + : null; + } + + /// + /// Resolves the first supported AMD GFX architecture from the current machine state when no preferred GPU applies. + /// + private static string? GetSupportedFallbackGfxArch(IEnumerable availableGpus) + { + return availableGpus + .Where(IsSupportedWindowsRocmGpu) + .Select(gpu => WindowsRocmSupport.TryGetCanonicalArchitecture(gpu.GetAmdGfxArch())) + .FirstOrDefault(IsSupportedWindowsRocmArchitecture); + } + + /// + /// Determines whether a GPU is supported by the Windows ROCm install flow currently modeled by the helper. + /// + private static bool IsSupportedWindowsRocmGpu(GpuInfo gpu) + { + return WindowsRocmSupport.IsSupportedGpu(gpu); + } + + /// + /// Determines whether a resolved AMD GFX architecture falls inside the Windows ROCm support set currently modeled by the helper. + /// + private static bool IsSupportedWindowsRocmArchitecture(string? gfxArch) + { + return WindowsRocmSupport.IsSupportedArchitecture(gfxArch); + } + + /// + /// Produces a readable incompatibility reason when AMD hardware is present but not usable for Windows ROCm. + /// + private static string GetUnsupportedGpuReason(IReadOnlyList amdGpus) + { + _ = amdGpus; + return "No AMD GPU with a supported Windows ROCm architecture was detected."; + } + + /// + /// Verifies that the installed torch build still reports usable ROCm metadata after helper-managed installs complete. + /// + private static async Task VerifyWindowsNativeTorchInstallAsync( + IPyVenvRunner venvRunner, + Action? onConsoleOutput, + CancellationToken cancellationToken + ) + { + var torchInfo = await venvRunner.PipShow("torch").ConfigureAwait(false); + if (torchInfo is null) + { + throw new ApplicationException("torch was not installed after Windows ROCm setup."); + } + + var verificationResult = await venvRunner + .Run( + "-c \"import json, torch; print(json.dumps({'version': torch.__version__, 'hip': torch.version.hip, 'cuda': torch.cuda.is_available()}))\"" + ) + .ConfigureAwait(false); + + var verificationOutput = (verificationResult.StandardOutput ?? string.Empty).Trim(); + if (string.IsNullOrWhiteSpace(verificationOutput)) + { + throw new ApplicationException("Torch verification produced no output."); + } + + var verificationJson = TryExtractJsonObject(verificationOutput); + if (string.IsNullOrWhiteSpace(verificationJson)) + { + throw new ApplicationException($"Unexpected torch verification output: {verificationOutput}"); + } + + JsonDocument verificationDocument; + try + { + verificationDocument = JsonDocument.Parse(verificationJson); + } + catch (Exception exception) + { + throw new ApplicationException( + $"Unexpected torch verification output: {verificationOutput}", + exception + ); + } + + using (verificationDocument) + { + var root = verificationDocument.RootElement; + var version = root.TryGetProperty("version", out var versionElement) + ? versionElement.GetString() + : null; + var hipVersion = root.TryGetProperty("hip", out var hipElement) ? hipElement.GetString() : null; + var cudaAvailable = root.TryGetProperty("cuda", out var cudaElement) && cudaElement.GetBoolean(); + + if (!IsUsableWindowsNativeTorchBuild(version, hipVersion)) + { + throw new ApplicationException( + $"Installed torch is not a usable ROCm build. Verification output: {verificationOutput}" + ); + } + + if (!cudaAvailable) + { + onConsoleOutput?.Invoke( + ProcessOutput.FromStdErrLine( + $"Torch verification warning: installed ROCm torch build reported cuda={cudaAvailable}; continuing because ROCm metadata was detected (version={version}, hip={hipVersion})." + ) + ); + } + + onConsoleOutput?.Invoke( + ProcessOutput.FromStdOutLine( + $"Torch verification: version={version}, hip={hipVersion}, cuda={cudaAvailable}" + ) + ); + } + + _ = cancellationToken; + } + + internal static bool IsUsableWindowsNativeTorchBuild(string? version, string? hipVersion) + { + if (!string.IsNullOrWhiteSpace(hipVersion)) + return true; + + return !string.IsNullOrWhiteSpace(version) + && version.Contains("rocm", StringComparison.OrdinalIgnoreCase); + } + + private static string? TryGetNightlyBuildDateToken(string? version) + { + if (string.IsNullOrWhiteSpace(version)) + return null; + + var devIndex = version.IndexOf("dev", StringComparison.OrdinalIgnoreCase); + if (devIndex < 0) + return null; + + var startIndex = devIndex + 3; + if (version.Length < startIndex + 8) + return null; + + var token = version.Substring(startIndex, 8); + return token.All(char.IsDigit) ? token : null; + } + + private static bool HasNightlyBuildDateToken(string? version, string nightlyBuildDateToken) + { + return !string.IsNullOrWhiteSpace(version) + && !string.IsNullOrWhiteSpace(nightlyBuildDateToken) + && version.Contains($"dev{nightlyBuildDateToken}", StringComparison.OrdinalIgnoreCase); + } + + internal static string? TryExtractJsonObject(string output) + { + if (string.IsNullOrWhiteSpace(output)) + return null; + + var trimmedOutput = output.Trim(); + + for (var index = 0; index < trimmedOutput.Length; index++) + { + if (trimmedOutput[index] != '{') + continue; + + try + { + using var document = JsonDocument.Parse(trimmedOutput[index..]); + return document.RootElement.GetRawText(); + } + catch (JsonException) { } + } + + return null; + } + + /// + /// Builds helper-owned ROCm launch variables from the resolved runtime context and package profile. + /// + private IReadOnlyDictionary BuildHelperLaunchEnvironment( + RocmRuntimeContext runtimeContext, + RocmPackageProfile profile + ) + { + var environment = new Dictionary(EnvComparer); + var options = profile.EnvironmentOptions; + var gfxArch = runtimeContext.RuntimeGfxArch; + + ApplyDefaultLaunchEnvironment(environment, gfxArch, options); + + return environment; + } + + private void ApplyDefaultLaunchEnvironment( + IDictionary environment, + string? gfxArch, + RocmEnvironmentOptions options + ) + { + SetIfNotNull(environment, "FLASH_ATTENTION_TRITON_AMD_ENABLE", options.FlashAttentionTritonAmdEnable); + SetIfNotNull(environment, "MIOPEN_FIND_MODE", options.MiopenFindMode); + SetIfNotNull(environment, "MIOPEN_SEARCH_CUTOFF", options.MiopenSearchCutoff); + SetIfNotNull(environment, "MIOPEN_FIND_ENFORCE", options.MiopenFindEnforce); + SetIfNotNull(environment, "PYTORCH_ALLOC_CONF", options.PyTorchAllocConf); + + if (options.ApplyAotritonExperimental && WindowsRocmSupport.SupportsAotritonExperimental(gfxArch)) + { + environment["TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL"] = "1"; + } + + if (options.ApplyLegacySdpFallback && WindowsRocmSupport.IsLegacyArchitecture(gfxArch)) + { + environment["TORCH_BACKENDS_CUDA_FLASH_SDP_ENABLED"] = "0"; + environment["TORCH_BACKENDS_CUDA_MEM_EFF_SDP_ENABLED"] = "0"; + environment["TORCH_BACKENDS_CUDA_MATH_SDP_ENABLED"] = "1"; + } + } + + private static void SetIfNotNull(IDictionary environment, string key, string? value) + { + if (!string.IsNullOrWhiteSpace(value)) + { + environment[key] = value; + } + } + + /// + /// Merges helper-owned and package-specific launch environment variables. + /// + private IReadOnlyDictionary MergeLaunchEnvironment( + IReadOnlyDictionary helperEnvironment, + IReadOnlyDictionary packageEnvironment, + RocmEnvironmentOptions options + ) + { + var merged = new Dictionary(EnvComparer); + + foreach (var source in new[] { helperEnvironment, packageEnvironment }) + { + if (ReferenceEquals(source, packageEnvironment) && !options.IncludePackageOverrides) + continue; + + foreach (var pair in source) + { + merged[pair.Key] = pair.Value; + } + } + + if ( + options.IncludeUserOverrides + && settingsManager.Settings.EnvironmentVariables is { Count: > 0 } userOverrides + ) + { + foreach (var pair in userOverrides) + { + merged[pair.Key] = pair.Value; + } + } + + return merged; + } + + private sealed class ResolvedWindowsRocmState + { + public bool IsCompatible { get; init; } + + public string? FailureReason { get; init; } + + public GpuInfo? SelectedGpu { get; init; } + + public string? RuntimeGfxArch { get; init; } + + public string? MultiArchDeviceExtra { get; init; } + } +} diff --git a/StabilityMatrix.Tests/Core/RocmPackageHelperTests.cs b/StabilityMatrix.Tests/Core/RocmPackageHelperTests.cs new file mode 100644 index 000000000..ad79a23b0 --- /dev/null +++ b/StabilityMatrix.Tests/Core/RocmPackageHelperTests.cs @@ -0,0 +1,116 @@ +using System.Text.Json; +using StabilityMatrix.Core.Helper; +using StabilityMatrix.Core.Helper.HardwareInfo; +using StabilityMatrix.Core.Models.Rocm; +using StabilityMatrix.Core.Services.Rocm; + +namespace StabilityMatrix.Tests.Core; + +[TestClass] +public class RocmPackageHelperTests +{ + [TestMethod] + public void WindowsRocmSupport_TryGetMultiArchDeviceExtra_ReturnsExpectedExtra_ForSupportedArch() + { + var deviceExtra = WindowsRocmSupport.TryGetMultiArchDeviceExtra("gfx1201"); + + Assert.AreEqual("device-gfx1201", deviceExtra); + } + + [TestMethod] + public void WindowsRocmSupport_TryGetMultiArchDeviceExtra_ReturnsExpectedExtra_ForCanonicalVega20Arch() + { + var deviceExtra = WindowsRocmSupport.TryGetMultiArchDeviceExtra("gfx906"); + + Assert.AreEqual("device-gfx906", deviceExtra); + } + + [TestMethod] + public void WindowsRocmSupport_TryGetCanonicalArchitecture_ReturnsCanonicalArch_WhenAlreadyCanonical() + { + var canonicalArch = WindowsRocmSupport.TryGetCanonicalArchitecture("gfx906"); + + Assert.AreEqual("gfx906", canonicalArch); + } + + [TestMethod] + public void IsUsableWindowsNativeTorchBuild_ReturnsTrue_WhenHipMetadataExists() + { + var isUsable = RocmPackageHelper.IsUsableWindowsNativeTorchBuild( + version: "test-version", + hipVersion: "test-hip-version" + ); + + Assert.IsTrue(isUsable); + } + + [TestMethod] + public void IsUsableWindowsNativeTorchBuild_ReturnsTrue_WhenVersionContainsRocm() + { + var isUsable = RocmPackageHelper.IsUsableWindowsNativeTorchBuild( + version: "test-version+rocm", + hipVersion: null + ); + + Assert.IsTrue(isUsable); + } + + [TestMethod] + public void IsUsableWindowsNativeTorchBuild_ReturnsFalse_WhenNoRocmMetadataExists() + { + var isUsable = RocmPackageHelper.IsUsableWindowsNativeTorchBuild( + version: "test-version", + hipVersion: null + ); + + Assert.IsFalse(isUsable); + } + + [TestMethod] + public void TryExtractJsonObject_ReturnsJson_WhenOutputContainsDiagnosticPrefix() + { + const string output = + "warning: ROCm topology probe emitted diagnostic output" + + "\nwarning: continuing with torch verification" + + "\n{\"version\": \"test-version\", \"hip\": \"test-hip-version\", \"cuda\": false}"; + + var json = RocmPackageHelper.TryExtractJsonObject(output); + + Assert.IsNotNull(json); + + using var document = JsonDocument.Parse(json); + var root = document.RootElement; + + Assert.AreEqual("test-version", root.GetProperty("version").GetString()); + Assert.AreEqual("test-hip-version", root.GetProperty("hip").GetString()); + Assert.IsFalse(root.GetProperty("cuda").GetBoolean()); + } + + [TestMethod] + public void TryExtractJsonObject_ReturnsNull_WhenOutputContainsNoJson() + { + const string output = + "warning: ROCm topology probe emitted diagnostic output\n" + + "warning: no JSON payload was produced"; + + var json = RocmPackageHelper.TryExtractJsonObject(output); + + Assert.IsNull(json); + } + + [TestMethod] + public void WindowsRocmSupport_TryGetMultiArchDeviceExtra_ReturnsExpectedExtra_ForKrakenPoint() + { + var deviceExtra = WindowsRocmSupport.TryGetMultiArchDeviceExtra("gfx1152"); + + Assert.AreEqual("device-gfx1152", deviceExtra); + } + + [TestMethod] + public void WindowsRocmSupport_IsSupportedGpu_ReturnsTrue_ForSupportedAmdGpu() + { + var gpu = new GpuInfo { Name = "AMD Radeon RX 9070 XT", MemoryBytes = 16UL * Size.GiB }; + + Assert.IsTrue(WindowsRocmSupport.IsSupportedGpu(gpu)); + } +} diff --git a/StabilityMatrix.Tests/Helper/PackageFactoryTests.cs b/StabilityMatrix.Tests/Helper/PackageFactoryTests.cs index f78027039..bf45d60c4 100644 --- a/StabilityMatrix.Tests/Helper/PackageFactoryTests.cs +++ b/StabilityMatrix.Tests/Helper/PackageFactoryTests.cs @@ -24,6 +24,8 @@ public void Setup() null!, null!, null!, + null!, + null!, null! ); } diff --git a/StabilityMatrix.Tests/Models/Packages/PackageHelper.cs b/StabilityMatrix.Tests/Models/Packages/PackageHelper.cs index b165031d7..8e2106908 100644 --- a/StabilityMatrix.Tests/Models/Packages/PackageHelper.cs +++ b/StabilityMatrix.Tests/Models/Packages/PackageHelper.cs @@ -6,6 +6,7 @@ using StabilityMatrix.Core.Models.Packages; using StabilityMatrix.Core.Python; using StabilityMatrix.Core.Services; +using StabilityMatrix.Core.Services.Rocm; namespace StabilityMatrix.Tests.Models.Packages; @@ -24,7 +25,8 @@ public static IEnumerable GetPackages() .AddSingleton(Substitute.For()) .AddSingleton(Substitute.For()) .AddSingleton(Substitute.For()) - .AddSingleton(Substitute.For()); + .AddSingleton(Substitute.For()) + .AddSingleton(Substitute.For()); var assembly = typeof(BasePackage).Assembly; var packageTypes = assembly