|
1 | | -using System.Text.Json; |
| 1 | +using System.Collections.Immutable; |
| 2 | +using System.Text.Json; |
2 | 3 | using System.Text.RegularExpressions; |
3 | 4 | using Injectio.Attributes; |
4 | 5 | using NLog; |
@@ -443,6 +444,8 @@ public override async Task RunPackage( |
443 | 444 | await SetupVenv(installLocation, pythonVersion: PyVersion.Parse(installedPackage.PythonVersion)) |
444 | 445 | .ConfigureAwait(false); |
445 | 446 |
|
| 447 | + VenvRunner.UpdateEnvironmentVariables(GetEnvVars); |
| 448 | + |
446 | 449 | VenvRunner.RunDetached( |
447 | 450 | [Path.Combine(installLocation, options.Command ?? LaunchCommand), .. options.Arguments], |
448 | 451 | HandleConsoleOutput, |
@@ -808,4 +811,20 @@ private async Task InstallNunchaku(InstalledPackage? installedPackage) |
808 | 811 | EventManager.Instance.OnPackageInstallProgressAdded(runner); |
809 | 812 | await runner.ExecuteSteps([installNunchaku]).ConfigureAwait(false); |
810 | 813 | } |
| 814 | + |
| 815 | + private ImmutableDictionary<string, string> GetEnvVars(ImmutableDictionary<string, string> env) |
| 816 | + { |
| 817 | + // if we're not on windows or we don't have a windows rocm gpu, return original env |
| 818 | + var hasRocmGpu = |
| 819 | + SettingsManager.Settings.PreferredGpu?.IsWindowsRocmSupportedGpu() |
| 820 | + ?? HardwareHelper.HasWindowsRocmSupportedGpu(); |
| 821 | + |
| 822 | + if (!Compat.IsWindows || !hasRocmGpu) |
| 823 | + return env; |
| 824 | + |
| 825 | + // set some experimental speed improving env vars for Windows ROCm |
| 826 | + return env.SetItem("PYTORCH_TUNABLEOP_ENABLED", "1") |
| 827 | + .SetItem("MIOPEN_FIND_MODE", "2") |
| 828 | + .SetItem("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "1"); |
| 829 | + } |
811 | 830 | } |
0 commit comments