Skip to content

Commit ec2db3d

Browse files
committed
Update generated requirements for JAX 0.9.2
1 parent 4efab25 commit ec2db3d

19 files changed

Lines changed: 587 additions & 667 deletions

.github/workflows/build_and_test_maxtext.yml

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ concurrency:
3737
permissions:
3838
contents: read
3939
jobs:
40-
doc_only_check:
41-
name: Check for Documentation-Only Changes
40+
analyze_changes:
41+
name: Analyze Changes for Test Orchestration
4242
runs-on: ubuntu-latest
4343
outputs:
4444
run_tests: ${{ steps.check.outputs.run_tests }}
@@ -47,7 +47,7 @@ jobs:
4747
- uses: actions/checkout@v4
4848
with:
4949
fetch-depth: 0
50-
- name: Check if only documentation changed
50+
- name: Analyze Changes for Test Orchestration
5151
id: check
5252
run: |
5353
if [ "${{ github.event_name }}" != "pull_request" ]; then
@@ -70,6 +70,14 @@ jobs:
7070
exit 0
7171
fi
7272
73+
# Check if dependencies are changed
74+
if echo "$CHANGED_FILES" | grep -E '(^|/)src/dependencies/' > /dev/null; then
75+
echo "MaxText dependencies changed, enabling all tests and notebooks."
76+
echo "run_tests=true" >> $GITHUB_OUTPUT
77+
echo "run_notebooks=true" >> $GITHUB_OUTPUT
78+
exit 0
79+
fi
80+
7381
# Check for source code changes (anything not .md and not .ipynb)
7482
if echo "$CHANGED_FILES" | grep -v -E '\.(md|ipynb)$' > /dev/null; then
7583
echo "Source code files changed, enabling unit tests."
@@ -91,11 +99,11 @@ jobs:
9199
exit 0
92100
93101
build_and_upload_maxtext_package:
94-
needs: doc_only_check
102+
needs: analyze_changes
95103
# Run if either tests or notebooks need to run
96104
if: |
97-
needs.doc_only_check.outputs.run_tests == 'true' ||
98-
needs.doc_only_check.outputs.run_notebooks == 'true'
105+
needs.analyze_changes.outputs.run_tests == 'true' ||
106+
needs.analyze_changes.outputs.run_notebooks == 'true'
99107
uses: ./.github/workflows/build_package.yml
100108
with:
101109
device_type: tpu
@@ -104,7 +112,7 @@ jobs:
104112

105113
maxtext_jupyter_notebooks:
106114
needs: build_and_upload_maxtext_package
107-
if: needs.doc_only_check.outputs.run_notebooks == 'true'
115+
if: needs.analyze_changes.outputs.run_notebooks == 'true'
108116
uses: ./.github/workflows/run_jupyter_notebooks.yml
109117
strategy:
110118
fail-fast: false
@@ -120,7 +128,7 @@ jobs:
120128
tpu-tests:
121129
name: ${{ matrix.flavor }} tests
122130
needs: [build_and_upload_maxtext_package]
123-
if: needs.doc_only_check.outputs.run_tests == 'true'
131+
if: needs.analyze_changes.outputs.run_tests == 'true'
124132
uses: ./.github/workflows/run_tests_coordinator.yml
125133
strategy:
126134
fail-fast: false
@@ -135,7 +143,7 @@ jobs:
135143
gpu-tests:
136144
name: ${{ matrix.flavor }} tests
137145
needs: [build_and_upload_maxtext_package]
138-
if: needs.doc_only_check.outputs.run_tests == 'true'
146+
if: needs.analyze_changes.outputs.run_tests == 'true'
139147
strategy:
140148
fail-fast: false
141149
matrix:
@@ -150,7 +158,7 @@ jobs:
150158
cpu-tests:
151159
name: ${{ matrix.flavor }} tests
152160
needs: [build_and_upload_maxtext_package]
153-
if: needs.doc_only_check.outputs.run_tests == 'true'
161+
if: needs.analyze_changes.outputs.run_tests == 'true'
154162
uses: ./.github/workflows/run_tests_coordinator.yml
155163
strategy:
156164
fail-fast: false
@@ -164,7 +172,7 @@ jobs:
164172

165173
maxtext_tpu_pathways_unit_tests:
166174
needs: build_and_upload_maxtext_package
167-
if: needs.doc_only_check.outputs.run_tests == 'true'
175+
if: needs.analyze_changes.outputs.run_tests == 'true'
168176
uses: ./.github/workflows/run_pathways_tests.yml
169177
strategy:
170178
fail-fast: false
@@ -182,7 +190,7 @@ jobs:
182190

183191
maxtext_tpu_pathways_integration_tests:
184192
needs: build_and_upload_maxtext_package
185-
if: needs.doc_only_check.outputs.run_tests == 'true'
193+
if: needs.analyze_changes.outputs.run_tests == 'true'
186194
uses: ./.github/workflows/run_pathways_tests.yml
187195
strategy:
188196
fail-fast: false
@@ -206,9 +214,8 @@ jobs:
206214
steps:
207215
- name: Check test results
208216
run: |
209-
# If doc-only, all tests should be skipped
210-
if [ "${NEEDS_DOC_ONLY_CHECK_OUTPUTS_RUN_TESTS}" == "false" ]; then
211-
echo "Documentation-only changes detected, tests were skipped"
217+
if [ "${NEEDS_ANALYZE_CHANGES_OUTPUTS_RUN_TESTS}" == "false" ]; then
218+
echo "Tests were skipped"
212219
exit 0
213220
fi
214221
@@ -228,7 +235,7 @@ jobs:
228235
229236
echo "All required tests passed successfully"
230237
env:
231-
NEEDS_DOC_ONLY_CHECK_OUTPUTS_RUN_TESTS: ${{ needs.doc_only_check.outputs.run_tests }}
238+
NEEDS_ANALYZE_CHANGES_OUTPUTS_RUN_TESTS: ${{ needs.analyze_changes.outputs.run_tests }}
232239
NEEDS_BUILD_AND_UPLOAD_MAXTEXT_PACKAGE_RESULT: ${{ needs.build_and_upload_maxtext_package.result }}
233240
NEEDS_CPU_TESTS_RESULT: ${{ needs.cpu-tests.result }}
234241
NEEDS_TPU_TESTS_RESULT: ${{ needs.tpu-tests.result }}
@@ -238,14 +245,14 @@ jobs:
238245

239246
all_notebooks_passed:
240247
name: All Notebooks Passed
241-
needs: [doc_only_check, build_and_upload_maxtext_package, maxtext_jupyter_notebooks]
248+
needs: [analyze_changes, build_and_upload_maxtext_package, maxtext_jupyter_notebooks]
242249
if: always()
243250
runs-on: ubuntu-latest
244251
steps:
245252
- name: Check notebooks results
246253
run: |
247-
if [ "${NEEDS_DOC_ONLY_CHECK_OUTPUTS_RUN_NOTEBOOKS}" == "false" ]; then
248-
echo "Non-notebook changes detected, runs were skipped"
254+
if [ "${NEEDS_ANALYZE_CHANGES_OUTPUTS_RUN_NOTEBOOKS}" == "false" ]; then
255+
echo "Notebooks were skipped"
249256
exit 0
250257
fi
251258
@@ -261,7 +268,7 @@ jobs:
261268
262269
echo "All required notebooks passed successfully"
263270
env:
264-
NEEDS_DOC_ONLY_CHECK_OUTPUTS_RUN_NOTEBOOKS: ${{ needs.doc_only_check.outputs.run_notebooks }}
271+
NEEDS_ANALYZE_CHANGES_OUTPUTS_RUN_NOTEBOOKS: ${{ needs.analyze_changes.outputs.run_notebooks }}
265272
NEEDS_BUILD_AND_UPLOAD_MAXTEXT_PACKAGE_RESULT: ${{ needs.build_and_upload_maxtext_package.result }}
266273
NEEDS_MAXTEXT_JUPYTER_NOTEBOOKS_RESULT: ${{ needs.maxtext_jupyter_notebooks.result }}
267274

.github/workflows/run_tests_against_package.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,14 @@ jobs:
152152
# omit this libtpu init args for gpu tests
153153
if [ "${INPUTS_DEVICE_TYPE}" != "cuda12" ]; then
154154
export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=65536'
155+
else
156+
# For cuda12, explicitly point to the pip-installed CUDA libraries
157+
# to avoid conflicts with system-level installations on the runner.
158+
if [ -d ".venv/lib/python3.12/site-packages/nvidia" ]; then
159+
export LD_LIBRARY_PATH=$(pwd)/.venv/lib/python3.12/site-packages/nvidia/cudnn/lib:${LD_LIBRARY_PATH}
160+
else
161+
echo "Warning: Could not find pinned nvidia libraries in .venv."
162+
fi
155163
fi
156164
if [ "${INPUTS_TOTAL_WORKERS}" -gt 1 ]; then
157165
$PYTHON_EXE -m pip install --quiet pytest-split pytest-xdist

docs/install_maxtext.md

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,10 @@ Please keep dependencies updated throughout development. This will allow each co
118118

119119
To update dependencies, you will follow these general steps:
120120

121-
1. **Modify Base Requirements**: Update the desired dependencies in `base_requirements/requirements.txt` or the hardware-specific files (`base_requirements/tpu-base-requirements.txt`, `base_requirements/gpu-base-requirements.txt`).
121+
1. **Modify Base Requirements**: Update the desired dependencies in `src/dependencies/requirements/base_requirements/requirements.txt` or the hardware-specific pre-training files (`base_requirements/tpu-requirements.txt`, `base_requirements/gpu-requirements.txt`) or post-training requirements.
122122
2. **Generate New Files**: Run the `seed-env` CLI tool to generate new, fully-pinned requirements files based on your changes.
123123
3. **Update Project Files**: Copy the newly generated files into the `generated_requirements/` directory.
124-
4. **Handle GitHub Dependencies**: Move any dependencies that are installed directly from GitHub from the generated files to `src/dependencies/github_deps/pre_train_deps.txt`.
125-
5. **Verify**: Test the new dependencies to ensure the project installs and runs correctly.
124+
4. **Verify**: Test the new dependencies to ensure the project installs and runs correctly.
126125

127126
The following sections provide detailed instructions for each step.
128127

@@ -133,59 +132,70 @@ First, you need to install the `seed-env` command-line tool by running `pip inst
133132

134133
## Step 2: Find the JAX Build Commit Hash
135134

136-
The dependency generation process is pinned to a specific nightly build of JAX. You need to find the commit hash for the desired JAX build.
137-
138-
You can find the latest commit hashes in the [JAX `build/` folder](https://github.com/jax-ml/jax/commits/main/build). Choose a recent, successful build and copy its full commit hash.
135+
The dependency generation process is pinned to a specific nightly build of JAX. You need to find the commit hash for the desired JAX build from [JAX `build/` folder](https://github.com/jax-ml/jax/commits/main/build).
139136

140137
## Step 3: Generate the Requirements Files
141138

142139
Next, run the `seed-env` CLI to generate the new requirements files. You will need to do this separately for the TPU and GPU environments. The generated files will be placed in a directory specified by `--output-dir`.
143140

144-
### For TPU
141+
> **Note:** The current `src/dependencies/requirements/generated_requirements/` in the repository were generated using JAX build commit hash: [e0d2967b50abbefd651d563dbcd7afbcb963d08c](https://github.com/jax-ml/jax/commit/e0d2967b50abbefd651d563dbcd7afbcb963d08c).
142+
143+
### TPU Pre-Training
145144

146-
Run the following command, replacing `<jax-build-commit-hash>` with the hash you copied in the previous step.
145+
If you have made changes to TPU pre-training dependencies in `src/dependencies/requirements/base_requirements/tpu-requirements.txt`, you need to regenerate the pinned pre-training requirements in `generated_requirements/` directory. Run the following command, replacing `<jax-build-commit-hash>` with the hash you copied in the previous step:
147146

148147
```bash
149148
seed-env \
150-
--local-requirements=src/dependencies/requirements/base_requirements/tpu-base-requirements.txt \
149+
--local-requirements=src/dependencies/requirements/base_requirements/tpu-requirements.txt \
151150
--host-name=MaxText \
152151
--seed-commit=<jax-build-commit-hash> \
153152
--python-version=3.12 \
154153
--requirements-txt=tpu-requirements.txt \
155154
--output-dir=generated_tpu_artifacts
155+
156+
# Copy generated requirements to src/dependencies/requirements/generated_requirements
157+
mv generated_tpu_artifacts/tpu-requirements.txt src/dependencies/requirements/generated_requirements/tpu-requirements.txt
156158
```
157159

158-
### For GPU
160+
#### TPU Post-Training
159161

160-
Similarly, run the command for the GPU requirements.
162+
If you have made changes to the post-training dependencies in `src/dependencies/requirements/base_requirements/tpu-post-train-requirements.txt`, you need to regenerate the pinned post-training requirements in `generated_requirements/` directory. Run the following command, replacing `<jax-build-commit-hash>` with the hash you copied in the previous step:
161163

162164
```bash
163165
seed-env \
164-
--local-requirements=src/dependencies/requirements/base_requirements/cuda12-base-requirements.txt \
166+
--local-requirements=src/dependencies/requirements/base_requirements/tpu-post-train-requirements.txt \
165167
--host-name=MaxText \
166168
--seed-commit=<jax-build-commit-hash> \
167169
--python-version=3.12 \
168-
--requirements-txt=cuda12-requirements.txt \
169-
--hardware=cuda12 \
170-
--output-dir=generated_gpu_artifacts
171-
```
170+
--requirements-txt=tpu-post-train-requirements.txt \
171+
--output-dir=generated_tpu_post_train_artifacts
172172

173-
## Step 4: Update Project Files
173+
# Copy generated requirements to src/dependencies/requirements/generated_requirements
174+
mv generated_tpu_post_train_artifacts/tpu-post-train-requirements.txt src/dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt
175+
```
174176

175-
After generating the new requirements, you need to update the files in the MaxText repository.
177+
### GPU Pre-Training
176178

177-
1. **Copy the generated files:**
179+
If you have made changes to the GPU pre-training dependencies in `src/dependencies/requirements/base_requirements/gpu-requirements.txt`, you need to regenerate the pinned pre-training requirements in `generated_requirements/` directory. Run the following command, replacing `<jax-build-commit-hash>` with the hash you copied in the previous step:
178180

179-
- Move `generated_tpu_artifacts/tpu-requirements.txt` to `generated_requirements/tpu-requirements.txt`.
180-
- Move `generated_gpu_artifacts/cuda12-requirements.txt` to `generated_requirements/cuda12-requirements.txt`.
181+
```bash
182+
seed-env \
183+
--local-requirements=src/dependencies/requirements/base_requirements/cuda12-requirements.txt \
184+
--host-name=MaxText \
185+
--seed-commit=<jax-build-commit-hash> \
186+
--python-version=3.12 \
187+
--requirements-txt=cuda12-requirements.txt \
188+
--hardware=cuda12 \
189+
--output-dir=generated_gpu_artifacts
181190

182-
2. **Update `pre_train_deps.txt` (if necessary):**
183-
Currently, MaxText uses a few dependencies, such as `mlperf-logging` and `google-jetstream`, that are installed directly from GitHub source. These are defined in `base_requirements/requirements.txt`, and the `seed-env` tool will carry them over to the generated requirements files.
191+
# Copy generated requirements to src/dependencies/requirements/generated_requirements
192+
mv generated_gpu_artifacts/cuda12-requirements.txt.txt src/dependencies/requirements/generated_requirements/cuda12-requirements.txt.txt
193+
```
184194

185-
## Step 5: Verify the New Dependencies
195+
## Step 4: Verify the New Dependencies
186196

187197
Finally, test that the new dependencies install correctly and that MaxText runs as expected.
188198

189-
1. **Install MaxText and dependencies**: For instructions on installing MaxText on your VM, please refer to the [official documentation](https://maxtext.readthedocs.io/en/maxtext-v0.2.0/install_maxtext.html#from-source).
199+
1. **Install MaxText and dependencies**: For instructions on installing MaxText on your VM, please refer to the [official documentation](https://maxtext.readthedocs.io/en/latest/install_maxtext.html#from-source).
190200

191201
2. **Verify the installation**: Run MaxText tests to ensure everything is working as expected with the newly installed dependencies and there are no regressions.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
google-metrax>=0.2.3

src/dependencies/requirements/base_requirements/gpu-base-requirements.txt renamed to src/dependencies/requirements/base_requirements/cuda12-requirements.txt

File renamed without changes.

src/dependencies/requirements/base_requirements/requirements.txt

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
absl-py
22
aqtp
33
array-record
4+
chex
45
cloud-accelerator-diagnostics
5-
cloud-tpu-diagnostics
6+
cloud-tpu-diagnostics!=1.1.14
67
datasets
78
drjax
89
flax
@@ -24,6 +25,7 @@ numpy
2425
omegaconf
2526
optax
2627
orbax-checkpoint
28+
parameterized
2729
pathwaysutils
2830
pillow
2931
pre-commit
@@ -34,15 +36,14 @@ pylint
3436
pytest
3537
pytype
3638
sentencepiece
39+
seqio
3740
tensorboard-plugin-profile
3841
tensorboardx
3942
tensorflow-datasets
4043
tensorflow-text
4144
tensorflow
4245
tiktoken
43-
tokamax
46+
tokamax!=0.1.0
4447
transformers
4548
uvloop
4649
qwix
47-
google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
48-
mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
-r requirements.txt
2+
google-metrax
3+
ipykernel
4+
kagglehub
5+
papermill
6+
perfetto

src/dependencies/requirements/base_requirements/tpu-base-requirements.txt renamed to src/dependencies/requirements/base_requirements/tpu-requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1 @@
11
-r requirements.txt
2-
google-tunix

0 commit comments

Comments
 (0)