|
6 | 6 | import time |
7 | 7 | import shutil |
8 | 8 | from tqdm import tqdm |
| 9 | +import argparse |
| 10 | +import yaml |
| 11 | +from typing import Dict, Any |
9 | 12 |
|
10 | 13 | # Container: insert parent GSAS-II directory and import package |
11 | 14 | GSAS_II_PARENT = os.environ.get("GSAS_II_PATH", "/opt/conda/envs/sim/GSAS-II") |
@@ -101,7 +104,7 @@ def run(self, n_sims_per_file, master_seed=12345, cleanup_worker_dirs=True, **kw |
101 | 104 | simulation_start_time = time.time() |
102 | 105 | success_count = 0 |
103 | 106 |
|
104 | | - with Pool(self.n_parallel_sims, maxtasksperchild=1000) as p: |
| 107 | + with Pool(self.n_parallel_sims, maxtasksperchild=1000, initializer=simulation_worker.suppress_worker_stdout) as p: |
105 | 108 | results = tqdm(p.imap_unordered(simulation_worker.run_single_simulation, tasks), total=total_jobs_to_run) |
106 | 109 | for result in results: |
107 | 110 | if result: success_count += 1 |
@@ -129,3 +132,91 @@ def run(self, n_sims_per_file, master_seed=12345, cleanup_worker_dirs=True, **kw |
129 | 132 | print(f"Throughput: {sims_per_sec:.2f} simulations/sec") |
130 | 133 | print(f"Avg. Time/Simulation: {avg_time_per_sim:.3f} seconds") |
131 | 134 | print("---------------------------") |
| 135 | + |
| 136 | + |
def load_config(path: Path) -> Dict[str, Any]:
    """Load the simulator configuration from a YAML file.

    Args:
        path: Location of the YAML configuration file.

    Returns:
        The parsed top-level mapping. An empty (or otherwise falsy)
        document yields an empty dict.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the document's top level is not a mapping.
    """
    if not path.exists():
        raise FileNotFoundError(f"Config file not found: {path}")
    # Read as UTF-8 explicitly so parsing does not depend on the host locale.
    with open(path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f) or {}
    # Callers index the result by key; reject lists/scalars here with a clear
    # message instead of failing later with an opaque indexing error.
    if not isinstance(cfg, dict):
        raise ValueError(
            f"Config file {path} must contain a YAML mapping at the top level, "
            f"got {type(cfg).__name__}"
        )
    return cfg
| 143 | + |
| 144 | + |
def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the simulator runner's command-line options.

    Args:
        argv: Argument strings to parse. When ``None`` (the default),
            argparse reads them from ``sys.argv[1:]``.

    Returns:
        Namespace with a single ``config`` attribute holding the path to the
        YAML configuration file.
    """
    parser = argparse.ArgumentParser(description="AlphaDiffract simulator runner")
    parser.add_argument(
        "--config",
        type=str,
        default="/app/configs/simulator.yaml",
        help="Path to YAML config file",
    )
    return parser.parse_args(argv)
| 149 | + |
| 150 | + |
def _coerce_bool(value: Any, key: str) -> bool:
    """Interpret a config value as a boolean.

    Plain ``bool()`` treats every non-empty string as true, so a quoted
    ``"false"`` in the YAML file would silently enable the option. Common
    textual spellings are mapped explicitly; other strings are rejected.
    """
    if isinstance(value, str):
        normalized = value.strip().lower()
        if normalized in ("true", "yes", "on", "1"):
            return True
        if normalized in ("false", "no", "off", "0", ""):
            return False
        raise ValueError(f"Config key '{key}' must be a boolean, got {value!r}")
    return bool(value)


def main() -> None:
    """Run the simulator as described by the YAML config given on the CLI.

    Reads the config path from the command line, converts the config entries
    to the types passed to ``DiffractionGenerator`` and its ``run`` method,
    prints the planned operation, then executes the run.

    Exits with status 1 (after printing a ``FATAL`` message to stderr) when
    required config keys are missing or when the run raises.
    """
    args = parse_args()
    cfg = load_config(Path(args.config))

    path_keys = (
        "input_directory",
        "output_directory",
        "instprm_file",
        "error_directory",
        "worker_base_dir",
    )
    control_keys = (
        "parallel_jobs",
        "sims_per_file",
        "master_seed",
        "cleanup_worker_dirs",
    )
    # (config key, element type) pairs; order is preserved in the printed
    # summary and in the keyword arguments forwarded to run().
    range_specs = (
        ("strain_range", float),
        ("size_range", float),
        ("U_range", float),
        ("V_range", float),
        ("W_range", float),
        ("st_range", float),
        ("en_range", float),
        ("Npoints_range", int),
        ("scaler_range", float),
        ("wl_range", float),
        ("proportional_noise_range", float),
        ("constant_noise_range", float),
    )

    # Report every missing key at once rather than one KeyError per attempt.
    required_keys = path_keys + control_keys + tuple(name for name, _ in range_specs)
    missing_keys = [key for key in required_keys if key not in cfg]
    if missing_keys:
        print(
            f"FATAL: Missing required config keys: {', '.join(missing_keys)}",
            file=sys.stderr,
        )
        sys.exit(1)

    # Paths
    input_directory = str(Path(cfg["input_directory"]))
    output_directory = str(Path(cfg["output_directory"]))
    instprm_file = str(Path(cfg["instprm_file"]))
    error_directory = str(Path(cfg["error_directory"]))
    worker_base_dir = str(Path(cfg["worker_base_dir"]))

    # Execution controls
    parallel_jobs = int(cfg["parallel_jobs"])
    sims_per_file = int(cfg["sims_per_file"])
    master_seed = int(cfg["master_seed"])
    cleanup_worker_dirs = _coerce_bool(cfg["cleanup_worker_dirs"], "cleanup_worker_dirs")

    # Parameter ranges
    ranges = {name: tuple(map(cast, cfg[name])) for name, cast in range_specs}

    # Log planned operation
    print("\n--- Simulator configuration ---")
    print(f"Input directory: {input_directory}")
    print(f"Output directory: {output_directory}")
    print(f"Error directory: {error_directory}")
    print(f"Worker base dir: {worker_base_dir}")
    print(f"Instrument file: {instprm_file}")
    print(f"Parallel jobs: {parallel_jobs}")
    print(f"Sims per file: {sims_per_file}")
    print(f"Master seed: {master_seed}")
    print(f"Cleanup worker dirs: {cleanup_worker_dirs}")
    print("Parameter ranges:")
    for k, v in ranges.items():
        print(f" {k}: {v}")
    print("--------------------------------\n")

    # Run simulations. This is the top-level boundary of the script, so any
    # failure is reported on stderr and turned into a non-zero exit status.
    try:
        gen = DiffractionGenerator(
            input_dir=input_directory,
            output_dir=output_directory,
            instprm_file=instprm_file,
            n_parallel_sims=parallel_jobs,
            error_dir=error_directory,
            worker_base_dir=worker_base_dir,
        )
        gen.run(
            n_sims_per_file=sims_per_file,
            master_seed=master_seed,
            cleanup_worker_dirs=cleanup_worker_dirs,
            **ranges,
        )
    except Exception as e:
        print(f"FATAL: Simulator run failed: {e}", file=sys.stderr)
        sys.exit(1)
| 219 | + |
| 220 | + |
# Start a run only when this file is executed as a script; importing the
# module must not trigger a simulation.
# NOTE(review): this file uses multiprocessing.Pool; under the "spawn" start
# method worker processes re-import the main module, so this guard presumably
# also keeps workers from re-entering main() — confirm the start method in use.
if __name__ == "__main__":
    main()
0 commit comments