Skip to content

Commit e58739c

Browse files
committed
ensure calls to .sample() are deterministic in CI
1 parent d36180c commit e58739c

2 files changed

Lines changed: 10 additions & 1 deletion

File tree

.github/workflows/main.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ jobs:
2727
# - hw_4.ipynb # has incomplete code
2828
- hw_5.ipynb
2929
- lecture_0.ipynb
30-
# - lecture_1.ipynb # uses sample()
30+
- lecture_1.ipynb
3131
- lecture_2.ipynb
3232
- lecture_3.ipynb
3333
- lecture_4.ipynb

scripts/diffable.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,13 @@ def is_vid(cell):
1616
return text == "<IPython.core.display.Video object>"
1717

1818

19+
def fix_sample(line):
20+
"""Ensure calls to .sample() are deterministic by passing in a seed value
21+
22+
https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.sample.html"""
23+
return re.sub(r"\.sample\((\d+)\)", r".sample(\1, random_state=1)", line)
24+
25+
1926
input_str = sys.stdin.read()
2027
notebook = json.loads(input_str)
2128

@@ -34,6 +41,8 @@ def is_vid(cell):
3441
if cell["source"][0].startswith("!"):
3542
cell["outputs"] = []
3643

44+
cell["source"] = [fix_sample(line) for line in cell["source"]]
45+
3746
# filter out pip upgrade warnings
3847
cell["outputs"] = [line for line in cell["outputs"] if not is_pip_upgrade_msg(line)]
3948

0 commit comments

Comments
 (0)