@@ -368,6 +368,13 @@ def get_user_jit_dir() -> str:
368368def validate_and_update_archs ():
369369 archs = os .getenv ("GPU_ARCHS" , "native" ).split (";" )
370370 archs = [arch .strip () for arch in archs ]
371+
372+ if "native" in archs and sys .platform == "win32" :
373+ # hipcc on Windows does not support --offload-arch=native;
374+ # resolve it to the actual gfx string detected from hipinfo/rocminfo.
375+ resolved = get_gfx ()
376+ archs = [resolved if a == "native" else a for a in archs ]
377+
371378 # List of allowed architectures
372379 allowed_archs = [
373380 "native" ,
@@ -398,18 +405,58 @@ def validate_and_update_archs():
398405@functools .lru_cache ()
399406def hip_flag_checker (flag_hip : str ) -> bool :
400407 import subprocess
408+ import tempfile
409+ from cpp_extension import get_cxx_compiler , ROCM_HOME
401410
402- cmd = (
403- ["hipcc" ]
404- + flag_hip .split ()
405- + ["-x" , "hip" , "-E" , "-P" , "/dev/null" , "-o" , "/dev/null" ]
406- )
407- try :
408- subprocess .check_output (cmd , stderr = subprocess .DEVNULL )
409- except subprocess .CalledProcessError :
410- logger .warning (f"Current hipcc not support: { flag_hip } , skip it." )
411- return False
412- return True
411+ hipcc_bin = get_cxx_compiler ()
412+
413+ # On Windows hipcc.exe uses the HIP_PATH environment variable to find
414+ # HIP headers and device libs — not --rocm-path compiler flags.
415+ # Inject it into the subprocess environment if not already set.
416+ extra_env = {}
417+ if sys .platform == "win32" and ROCM_HOME and "HIP_PATH" not in os .environ :
418+ extra_env ["HIP_PATH" ] = ROCM_HOME
419+
420+ if sys .platform == "win32" :
421+ tmp_in = None
422+ tmp_out = None
423+ try :
424+ with tempfile .NamedTemporaryFile (
425+ suffix = ".cu" , delete = False , mode = "w"
426+ ) as f :
427+ f .write ("// flag check\n " )
428+ tmp_in = f .name
429+ tmp_out = tmp_in .replace (".cu" , ".o" )
430+ env = {** os .environ , ** extra_env }
431+ cmd = (
432+ [hipcc_bin ]
433+ + flag_hip .split ()
434+ + ["-x" , "hip" , "-c" , tmp_in , "-o" , tmp_out ]
435+ )
436+ subprocess .check_output (cmd , stderr = subprocess .DEVNULL , env = env )
437+ return True
438+ except subprocess .CalledProcessError :
439+ logger .warning (f"Current hipcc not support: { flag_hip } , skip it." )
440+ return False
441+ finally :
442+ for f in filter (None , [tmp_in , tmp_out ]):
443+ try :
444+ os .unlink (f )
445+ except OSError :
446+ pass
447+ else :
448+ null_dev = "/dev/null"
449+ cmd = (
450+ [hipcc_bin ]
451+ + flag_hip .split ()
452+ + ["-x" , "hip" , "-E" , "-P" , null_dev , "-o" , null_dev ]
453+ )
454+ try :
455+ subprocess .check_output (cmd , stderr = subprocess .DEVNULL )
456+ except subprocess .CalledProcessError :
457+ logger .warning (f"Current hipcc not support: { flag_hip } , skip it." )
458+ return False
459+ return True
413460
414461
415462@functools .lru_cache ()
@@ -421,14 +468,43 @@ def check_LLVM_MAIN_REVISION():
421468 #else
422469 #define CK_TILE_HOST_DEVICE_EXTERN"""
423470 import subprocess
424-
425- cmd = """echo "#include <tuple>
471+ import tempfile
472+ from cpp_extension import get_cxx_compiler
473+
474+ src = '#include <tuple>\n __host__ __device__ void func(){std::tuple<int, int> t = std::tuple(1, 1);}\n '
475+ hipcc_bin = get_cxx_compiler ()
476+ if sys .platform == "win32" :
477+ # On Windows we can't use bash pipes; write to a temp file instead.
478+ with tempfile .NamedTemporaryFile (suffix = ".cu" , delete = False , mode = "w" ) as f :
479+ f .write (src )
480+ tmp = f .name
481+ try :
482+ subprocess .check_output (
483+ [hipcc_bin , "-x" , "hip" , "-P" , "-c" , "-Wno-unused-command-line-argument" , tmp ],
484+ stderr = subprocess .STDOUT ,
485+ text = True ,
486+ )
487+ return 554785 - 1
488+ except subprocess .CalledProcessError :
489+ return 554785
490+ finally :
491+ try :
492+ os .unlink (tmp )
493+ except OSError :
494+ pass
495+ obj = tmp .replace (".cu" , ".obj" )
496+ try :
497+ os .unlink (obj )
498+ except OSError :
499+ pass
500+ else :
501+ cmd = """echo "#include <tuple>
426502__host__ __device__ void func(){std::tuple<int, int> t = std::tuple(1, 1);}" | hipcc -x hip -P -c -Wno-unused-command-line-argument -o /dev/null -"""
427- try :
428- subprocess .check_output (cmd , shell = True , text = True , stderr = subprocess .STDOUT )
429- except subprocess .CalledProcessError :
430- return 554785
431- return 554785 - 1
503+ try :
504+ subprocess .check_output (cmd , shell = True , text = True , stderr = subprocess .STDOUT )
505+ except subprocess .CalledProcessError :
506+ return 554785
507+ return 554785 - 1
432508
433509
434510def check_and_set_ninja_worker ():
@@ -489,6 +565,8 @@ def do_rename_and_mv(name, src, dst, ret):
489565
490566@torch_compile_guard ()
491567def check_numa_custom_op () -> None :
568+ if sys .platform == "win32" :
569+ return # /proc/sys/kernel/numa_balancing does not exist on Windows
492570 numa_balance_set = os .popen ("cat /proc/sys/kernel/numa_balancing" ).read ().strip ()
493571 if numa_balance_set == "1" :
494572 logger .warning (
@@ -654,11 +732,18 @@ def check_git_version(required_major, required_minor):
654732
655733
656734def rm_module (md_name ):
657- os .system (f"rm -rf { get_user_jit_dir ()} /{ md_name } .so" )
735+ import glob as _glob
736+ _lib_ext = ".pyd" if sys .platform == "win32" else ".so"
737+ for _f in _glob .glob (f"{ get_user_jit_dir ()} /{ md_name } { _lib_ext } " ):
738+ try :
739+ os .remove (_f )
740+ except OSError :
741+ pass
658742
659743
660744def clear_build (md_name ):
661- os .system (f"rm -rf { bd_dir } /{ md_name } " )
745+ import shutil as _shutil
746+ _shutil .rmtree (f"{ bd_dir } /{ md_name } " , ignore_errors = True )
662747
663748
664749def build_module (
@@ -679,7 +764,8 @@ def build_module(
679764 os .makedirs (bd_dir , exist_ok = True )
680765 lock_path = f"{ bd_dir } /lock_{ md_name } "
681766 startTS = time .perf_counter ()
682- target_name = f"{ md_name } .so" if not is_standalone else md_name
767+ _lib_ext = ".pyd" if sys .platform == "win32" else ".so"
768+ target_name = f"{ md_name } { _lib_ext } " if not is_standalone else md_name
683769
684770 for tp in third_party :
685771 clone_3rdparty (tp )
@@ -875,9 +961,9 @@ def _get_ck_exclude_modules():
875961 """Return set of module names that require CK and should be excluded in CK-free builds.
876962
877963 Combines two detection methods:
878- 1. Config pattern matching -- modules whose optCompilerConfig.json entry references
964+ 1. Config pattern matching — modules whose optCompilerConfig.json entry references
879965 CK_DIR, py_itfs_ck, gen_instances, or generate.py
880- 2. Hardcoded list -- modules with deep ck_tile:: source-level dependencies that
966+ 2. Hardcoded list — modules with deep ck_tile:: source-level dependencies that
881967 aren't caught by config pattern matching
882968
883969 V3 ASM modules are exempted because they build with shim headers only.
@@ -897,7 +983,7 @@ def _get_ck_exclude_modules():
897983 if any (p in mod_str for p in ck_patterns ):
898984 ck_modules .add (mod_name )
899985
900- # V3 ASM modules can build with shim headers -- exempt them
986+ # V3 ASM modules can build with shim headers — exempt them
901987 v3_flags = ["FAV3_ON" , "ONLY_FAV3" ]
902988 for mod_name , mod_cfg in config_data .items ():
903989 flags_str = json .dumps (mod_cfg .get ("flags_extra_cc" , []))
@@ -1255,6 +1341,7 @@ def wrapper(*args, custom_build_args={}, **kwargs):
12551341 d_args = get_args_of_build (md_name )
12561342 d_args .update (custom_build_args )
12571343
1344+ # update module name if we have a custom build
12581345 md_name = custom_build_args .get ("md_name" , md_name )
12591346
12601347 srcs = d_args ["srcs" ]
@@ -1324,6 +1411,7 @@ def check_args():
13241411 doc_str = doc_str .replace ("collections.abc.Sequence[" , "List[" )
13251412 doc_str = doc_str .replace ("typing.SupportsInt" , "int" )
13261413 doc_str = doc_str .replace ("typing.SupportsFloat" , "float" )
1414+ # A|None --> Optional[A]
13271415 pattern = r"([\w\.]+(?:\[[^\]]+\])?)\s*\|\s*None"
13281416 doc_str = re .sub (pattern , r"Optional[\1]" , doc_str )
13291417 for el in enum_types :
@@ -1353,14 +1441,18 @@ def check_args():
13531441
13541442 if origin is None :
13551443 if not isinstance (arg , expected_type ) and not (
1444+ # aiter_enum can be int
13561445 any (el in str (expected_type ) for el in enum_types )
13571446 and isinstance (arg , int )
13581447 ):
13591448 raise TypeError (
13601449 f"{ loadName } : { el } needs to be { expected_type } but got { got_type } "
13611450 )
13621451 elif origin is list :
1363- if not isinstance (arg , list ):
1452+ if (
1453+ not isinstance (arg , list )
1454+ # or not all(isinstance(i, sub_t) for i in arg)
1455+ ):
13641456 raise TypeError (
13651457 f"{ loadName } : { el } needs to be List[{ sub_t } ] but got { arg } "
13661458 )
0 commit comments