Skip to content

Commit 59de888

Browse files
committed
Initial FA-2 Triton Windows build support
1 parent 5f9fdc9 commit 59de888

File tree

5 files changed

+775
-328
lines changed

5 files changed

+775
-328
lines changed

aiter/dist/parallel_state.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,14 @@
3333

3434
import torch
3535
import torch.distributed
36-
from torch.distributed import Backend, ProcessGroup
36+
try:
37+
from torch.distributed import Backend, ProcessGroup
38+
except ImportError:
39+
# torch.distributed is not available on all Windows ROCm builds.
40+
# Set to None so the module can be imported; code paths that actually
41+
# use these names are guarded by distributed-availability checks.
42+
Backend = None # type: ignore[assignment]
43+
ProcessGroup = None # type: ignore[assignment]
3744

3845
import os
3946
from aiter import logger

aiter/jit/core.py

Lines changed: 117 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,13 @@ def get_user_jit_dir() -> str:
368368
def validate_and_update_archs():
369369
archs = os.getenv("GPU_ARCHS", "native").split(";")
370370
archs = [arch.strip() for arch in archs]
371+
372+
if "native" in archs and sys.platform == "win32":
373+
# hipcc on Windows does not support --offload-arch=native;
374+
# resolve it to the actual gfx string detected from hipinfo/rocminfo.
375+
resolved = get_gfx()
376+
archs = [resolved if a == "native" else a for a in archs]
377+
371378
# List of allowed architectures
372379
allowed_archs = [
373380
"native",
@@ -398,18 +405,58 @@ def validate_and_update_archs():
398405
@functools.lru_cache()
399406
def hip_flag_checker(flag_hip: str) -> bool:
400407
import subprocess
408+
import tempfile
409+
from cpp_extension import get_cxx_compiler, ROCM_HOME
401410

402-
cmd = (
403-
["hipcc"]
404-
+ flag_hip.split()
405-
+ ["-x", "hip", "-E", "-P", "/dev/null", "-o", "/dev/null"]
406-
)
407-
try:
408-
subprocess.check_output(cmd, stderr=subprocess.DEVNULL)
409-
except subprocess.CalledProcessError:
410-
logger.warning(f"Current hipcc not support: {flag_hip}, skip it.")
411-
return False
412-
return True
411+
hipcc_bin = get_cxx_compiler()
412+
413+
# On Windows hipcc.exe uses the HIP_PATH environment variable to find
414+
# HIP headers and device libs — not --rocm-path compiler flags.
415+
# Inject it into the subprocess environment if not already set.
416+
extra_env = {}
417+
if sys.platform == "win32" and ROCM_HOME and "HIP_PATH" not in os.environ:
418+
extra_env["HIP_PATH"] = ROCM_HOME
419+
420+
if sys.platform == "win32":
421+
tmp_in = None
422+
tmp_out = None
423+
try:
424+
with tempfile.NamedTemporaryFile(
425+
suffix=".cu", delete=False, mode="w"
426+
) as f:
427+
f.write("// flag check\n")
428+
tmp_in = f.name
429+
tmp_out = tmp_in.replace(".cu", ".o")
430+
env = {**os.environ, **extra_env}
431+
cmd = (
432+
[hipcc_bin]
433+
+ flag_hip.split()
434+
+ ["-x", "hip", "-c", tmp_in, "-o", tmp_out]
435+
)
436+
subprocess.check_output(cmd, stderr=subprocess.DEVNULL, env=env)
437+
return True
438+
except subprocess.CalledProcessError:
439+
logger.warning(f"Current hipcc not support: {flag_hip}, skip it.")
440+
return False
441+
finally:
442+
for f in filter(None, [tmp_in, tmp_out]):
443+
try:
444+
os.unlink(f)
445+
except OSError:
446+
pass
447+
else:
448+
null_dev = "/dev/null"
449+
cmd = (
450+
[hipcc_bin]
451+
+ flag_hip.split()
452+
+ ["-x", "hip", "-E", "-P", null_dev, "-o", null_dev]
453+
)
454+
try:
455+
subprocess.check_output(cmd, stderr=subprocess.DEVNULL)
456+
except subprocess.CalledProcessError:
457+
logger.warning(f"Current hipcc not support: {flag_hip}, skip it.")
458+
return False
459+
return True
413460

414461

415462
@functools.lru_cache()
@@ -421,14 +468,43 @@ def check_LLVM_MAIN_REVISION():
421468
#else
422469
#define CK_TILE_HOST_DEVICE_EXTERN"""
423470
import subprocess
424-
425-
cmd = """echo "#include <tuple>
471+
import tempfile
472+
from cpp_extension import get_cxx_compiler
473+
474+
src = '#include <tuple>\n__host__ __device__ void func(){std::tuple<int, int> t = std::tuple(1, 1);}\n'
475+
hipcc_bin = get_cxx_compiler()
476+
if sys.platform == "win32":
477+
# On Windows we can't use bash pipes; write to a temp file instead.
478+
with tempfile.NamedTemporaryFile(suffix=".cu", delete=False, mode="w") as f:
479+
f.write(src)
480+
tmp = f.name
481+
try:
482+
subprocess.check_output(
483+
[hipcc_bin, "-x", "hip", "-P", "-c", "-Wno-unused-command-line-argument", tmp],
484+
stderr=subprocess.STDOUT,
485+
text=True,
486+
)
487+
return 554785 - 1
488+
except subprocess.CalledProcessError:
489+
return 554785
490+
finally:
491+
try:
492+
os.unlink(tmp)
493+
except OSError:
494+
pass
495+
obj = tmp.replace(".cu", ".obj")
496+
try:
497+
os.unlink(obj)
498+
except OSError:
499+
pass
500+
else:
501+
cmd = """echo "#include <tuple>
426502
__host__ __device__ void func(){std::tuple<int, int> t = std::tuple(1, 1);}" | hipcc -x hip -P -c -Wno-unused-command-line-argument -o /dev/null -"""
427-
try:
428-
subprocess.check_output(cmd, shell=True, text=True, stderr=subprocess.STDOUT)
429-
except subprocess.CalledProcessError:
430-
return 554785
431-
return 554785 - 1
503+
try:
504+
subprocess.check_output(cmd, shell=True, text=True, stderr=subprocess.STDOUT)
505+
except subprocess.CalledProcessError:
506+
return 554785
507+
return 554785 - 1
432508

433509

434510
def check_and_set_ninja_worker():
@@ -489,6 +565,8 @@ def do_rename_and_mv(name, src, dst, ret):
489565

490566
@torch_compile_guard()
491567
def check_numa_custom_op() -> None:
568+
if sys.platform == "win32":
569+
return # /proc/sys/kernel/numa_balancing does not exist on Windows
492570
numa_balance_set = os.popen("cat /proc/sys/kernel/numa_balancing").read().strip()
493571
if numa_balance_set == "1":
494572
logger.warning(
@@ -654,11 +732,18 @@ def check_git_version(required_major, required_minor):
654732

655733

656734
def rm_module(md_name):
657-
os.system(f"rm -rf {get_user_jit_dir()}/{md_name}.so")
735+
import glob as _glob
736+
_lib_ext = ".pyd" if sys.platform == "win32" else ".so"
737+
for _f in _glob.glob(f"{get_user_jit_dir()}/{md_name}{_lib_ext}"):
738+
try:
739+
os.remove(_f)
740+
except OSError:
741+
pass
658742

659743

660744
def clear_build(md_name):
661-
os.system(f"rm -rf {bd_dir}/{md_name}")
745+
import shutil as _shutil
746+
_shutil.rmtree(f"{bd_dir}/{md_name}", ignore_errors=True)
662747

663748

664749
def build_module(
@@ -679,7 +764,8 @@ def build_module(
679764
os.makedirs(bd_dir, exist_ok=True)
680765
lock_path = f"{bd_dir}/lock_{md_name}"
681766
startTS = time.perf_counter()
682-
target_name = f"{md_name}.so" if not is_standalone else md_name
767+
_lib_ext = ".pyd" if sys.platform == "win32" else ".so"
768+
target_name = f"{md_name}{_lib_ext}" if not is_standalone else md_name
683769

684770
for tp in third_party:
685771
clone_3rdparty(tp)
@@ -875,9 +961,9 @@ def _get_ck_exclude_modules():
875961
"""Return set of module names that require CK and should be excluded in CK-free builds.
876962
877963
Combines two detection methods:
878-
1. Config pattern matching -- modules whose optCompilerConfig.json entry references
964+
1. Config pattern matching -- modules whose optCompilerConfig.json entry references
879965
CK_DIR, py_itfs_ck, gen_instances, or generate.py
880-
2. Hardcoded list -- modules with deep ck_tile:: source-level dependencies that
966+
2. Hardcoded list -- modules with deep ck_tile:: source-level dependencies that
881967
aren't caught by config pattern matching
882968
883969
V3 ASM modules are exempted because they build with shim headers only.
@@ -897,7 +983,7 @@ def _get_ck_exclude_modules():
897983
if any(p in mod_str for p in ck_patterns):
898984
ck_modules.add(mod_name)
899985

900-
# V3 ASM modules can build with shim headers -- exempt them
986+
# V3 ASM modules can build with shim headers -- exempt them
901987
v3_flags = ["FAV3_ON", "ONLY_FAV3"]
902988
for mod_name, mod_cfg in config_data.items():
903989
flags_str = json.dumps(mod_cfg.get("flags_extra_cc", []))
@@ -1255,6 +1341,7 @@ def wrapper(*args, custom_build_args={}, **kwargs):
12551341
d_args = get_args_of_build(md_name)
12561342
d_args.update(custom_build_args)
12571343

1344+
# update module name if we have a custom build
12581345
md_name = custom_build_args.get("md_name", md_name)
12591346

12601347
srcs = d_args["srcs"]
@@ -1324,6 +1411,7 @@ def check_args():
13241411
doc_str = doc_str.replace("collections.abc.Sequence[", "List[")
13251412
doc_str = doc_str.replace("typing.SupportsInt", "int")
13261413
doc_str = doc_str.replace("typing.SupportsFloat", "float")
1414+
# A|None --> Optional[A]
13271415
pattern = r"([\w\.]+(?:\[[^\]]+\])?)\s*\|\s*None"
13281416
doc_str = re.sub(pattern, r"Optional[\1]", doc_str)
13291417
for el in enum_types:
@@ -1353,14 +1441,18 @@ def check_args():
13531441

13541442
if origin is None:
13551443
if not isinstance(arg, expected_type) and not (
1444+
# aiter_enum can be int
13561445
any(el in str(expected_type) for el in enum_types)
13571446
and isinstance(arg, int)
13581447
):
13591448
raise TypeError(
13601449
f"{loadName}: {el} needs to be {expected_type} but got {got_type}"
13611450
)
13621451
elif origin is list:
1363-
if not isinstance(arg, list):
1452+
if (
1453+
not isinstance(arg, list)
1454+
# or not all(isinstance(i, sub_t) for i in arg)
1455+
):
13641456
raise TypeError(
13651457
f"{loadName}: {el} needs to be List[{sub_t}] but got {arg}"
13661458
)

0 commit comments

Comments
 (0)