|
21 | 21 | from functools import partial |
22 | 22 | import os |
23 | 23 | import socket |
| 24 | +import re |
24 | 25 | import subprocess |
25 | 26 | import time |
26 | 27 | from typing import Any |
|
52 | 53 | # pylint: disable=too-many-positional-arguments |
53 | 54 |
|
54 | 55 |
|
| 56 | +def parse_libtpu_flags_to_dict(flags_str: str) -> dict: |
| 57 | + """ |
| 58 | + Parses a string of LIBTPU flags into a dictionary of compilation options. |
| 59 | + This function is only for internal compilation usage. |
| 60 | + """ |
| 61 | + if not flags_str or not flags_str.strip(): |
| 62 | + return {} |
| 63 | + |
| 64 | + # Clean the string by removing line-continuation backslashes |
| 65 | + cleaned_str = flags_str.replace("\\", " ") |
| 66 | + |
| 67 | + # Split by any whitespace (handles single spaces, multiple spaces, newlines) |
| 68 | + tokens = cleaned_str.split() |
| 69 | + |
| 70 | + options_dict = {} |
| 71 | + |
| 72 | + # Regex to strictly match '--key=value' for an isolated token |
| 73 | + # Key assumes alphanumeric + underscores. Value is anything after the '='. |
| 74 | + token_pattern = re.compile(r"^--([a-zA-Z0-9_]+)=(.+)$") |
| 75 | + |
| 76 | + for token in tokens: |
| 77 | + match = token_pattern.match(token) |
| 78 | + if not match: |
| 79 | + # Throw an error immediately if any token fails the strict format |
| 80 | + raise ValueError(f"Invalid flag format detected: '{token}'. Expected format: '--key=value'") |
| 81 | + |
| 82 | + key, value = match.groups() |
| 83 | + |
| 84 | + # Optional: Catch duplicate flags |
| 85 | + if key in options_dict: |
| 86 | + raise ValueError(f"Duplicate flag detected: '--{key}'") |
| 87 | + |
| 88 | + options_dict[key] = value |
| 89 | + |
| 90 | + return options_dict |
| 91 | + |
| 92 | + |
55 | 93 | def with_memory_kind(t, memory_kind): |
56 | 94 | return jax.tree_util.tree_map(lambda x: x.with_memory_kind(kind=memory_kind), t) |
57 | 95 |
|
|
0 commit comments