Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ert/analysis/_enif_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def enif_update(
data=smoother_snapshot.csv,
extra=smoother_snapshot.extra,
),
posterior_id=str(posterior_storage.id),
ensemble_id=str(posterior_storage.id),
)
)
return smoother_snapshot
Expand Down
2 changes: 1 addition & 1 deletion src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ def log_warning(
data=smoother_snapshot.csv,
extra=smoother_snapshot.extra,
),
posterior_id=str(posterior_storage.id),
ensemble_id=str(posterior_storage.id),
)
)
return smoother_snapshot
2 changes: 1 addition & 1 deletion src/ert/analysis/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,4 @@ class AnalysisErrorEvent(AnalysisEvent):
class AnalysisCompleteEvent(AnalysisEvent):
    """Event emitted when an analysis (smoother/ENIF update) finishes successfully."""

    # Discriminator used when (de)serializing the event union.
    event_type: Literal["AnalysisCompleteEvent"] = "AnalysisCompleteEvent"
    # Tabular update report (header + rows) produced by the update step.
    data: DataSection
    # Id of the ensemble the analysis wrote to (the posterior ensemble's id
    # at the emitting call sites in _enif_update / _es_update).
    ensemble_id: str
20 changes: 17 additions & 3 deletions src/ert/run_models/update_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import functools
import uuid

import polars as pl

from ert.analysis import build_strategy_map, smoother_update
from ert.analysis._update_commons import ErtAnalysisError
from ert.analysis.event import (
Expand Down Expand Up @@ -30,6 +32,7 @@
)
from ert.run_models.run_model import ErtRunError, RunModel, RunModelConfig
from ert.storage import Ensemble, LocalExperiment
from ert.storage.blob_data import BlobStorageData


class UpdateRunModelConfig(RunModelConfig):
Expand Down Expand Up @@ -192,10 +195,21 @@ def send_smoother_event(
)
)
case AnalysisCompleteEvent():
self._storage.get_ensemble(event.posterior_id).save_transition_data(
f"{AnalysisCompleteEvent.__name__}_{uuid.uuid4().hex[:8]}.json",
event.model_dump_json(),
ensemble = self._storage.get_ensemble(event.ensemble_id)
report_id = uuid.uuid4().hex[:8]
file_name = f"observation_report_{report_id}.parquet"
df = pl.DataFrame(
event.data.data,
schema=event.data.header,
orient="row",
)
blob_data = BlobStorageData(
blob_type="observation_report",
uri=file_name,
file_size=0,
ensemble_id=str(ensemble.id),
)
ensemble.save_blob_data(blob_data, df)
self.send_event(
RunModelUpdateEndEvent(
iteration=iteration, run_id=run_id, data=event.data
Expand Down
17 changes: 17 additions & 0 deletions src/ert/storage/blob_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from __future__ import annotations

from pydantic import BaseModel, ConfigDict


class BlobStorageData(BaseModel):
    """Metadata describing one blob stored under an ensemble's blob directory."""

    # Reject unknown fields so persisted metadata files stay a closed schema.
    model_config = ConfigDict(extra="forbid")
    # Category of the blob, e.g. "observation_report" or "matrix".
    blob_type: str
    # File name of the blob, relative to the ensemble's blob directory.
    uri: str
    # Size of the blob in bytes. NOTE(review): current callers always pass 0
    # and it is never updated on save — confirm whether it is meant to be
    # filled in with the actual file size.
    file_size: int
    # Id of the ensemble this blob belongs to (stringified UUID).
    ensemble_id: str


class MatrixStorageData(BlobStorageData):
    """Blob metadata specialized for matrix-shaped payloads."""

    # Fixed blob category for matrix blobs.
    blob_type: str = "matrix"
    # Whether the stored matrix is sparse. NOTE(review): no caller in this
    # change writes this type yet — confirm intended use.
    sparse: bool = False
    # (rows, cols) of the matrix; (0, 0) when unknown.
    shape: tuple[int, int] = (0, 0)
15 changes: 10 additions & 5 deletions src/ert/storage/local_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
if TYPE_CHECKING:
import numpy.typing as npt

from .blob_data import BlobStorageData
from .local_experiment import LocalExperiment
from .local_storage import LocalStorage

Expand All @@ -50,7 +51,7 @@ class EverestRealizationInfo(TypedDict):


SCALAR_FILENAME = "SCALAR"
TRANSITION_DATA_DIR = "transition"
BLOB_DATA_DIR = "blobs"


class BatchDataframes(TypedDict, total=False):
Expand Down Expand Up @@ -1298,10 +1299,14 @@ def save_everest_realization_info(
)

@require_write
def save_transition_data(self, file_name: str, data: str) -> None:
path = self._path / TRANSITION_DATA_DIR / file_name
Path.mkdir(path.parent, exist_ok=True)
self._storage._write_transaction(path, data.encode("utf-8"))
def save_blob_data(self, blob_data: BlobStorageData, data: pl.DataFrame) -> None:
    """Persist ``data`` as a parquet blob with a JSON metadata sidecar.

    The DataFrame is written to ``<ensemble>/blobs/<blob_data.uri>`` and the
    ``BlobStorageData`` metadata to ``<blob_data.uri>.json`` next to it.

    Args:
        blob_data: Descriptor for the blob; its ``uri`` is used as the file name.
        data: Tabular payload to store as parquet.
    """
    blob_dir = self._path / BLOB_DATA_DIR
    # exist_ok so repeated saves into the same ensemble do not fail.
    Path.mkdir(blob_dir, exist_ok=True)
    self._storage._to_parquet_transaction(blob_dir / blob_data.uri, data)
    # NOTE(review): blob_data is serialized as provided — file_size stays at
    # whatever the caller set (currently 0); confirm whether it should be
    # updated to the actual parquet size here.
    self._storage._write_transaction(
        blob_dir / f"{blob_data.uri}.json",
        blob_data.model_dump_json(indent=2).encode("utf-8"),
    )

@require_write
def save_batch_dataframes(self, dataframes: BatchDataframes) -> None:
Expand Down
32 changes: 18 additions & 14 deletions tests/ert/unit_tests/run_models/test_update_run_model.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,40 @@
import json
import uuid
from unittest.mock import MagicMock

import polars as pl

from ert.analysis.event import AnalysisCompleteEvent, DataSection
from ert.run_models.update_run_model import UpdateRunModel
from ert.storage.blob_data import BlobStorageData


def test_that_send_smoother_event_persists_observation_report_on_analysis_complete():
    """AnalysisCompleteEvent must persist the observation report as a blob.

    ``send_smoother_event`` should look up the target ensemble by the event's
    ``ensemble_id`` and store the report DataFrame via ``save_blob_data`` with
    an ``observation_report`` BlobStorageData descriptor.
    """
    # Mocked instance: only the real send_smoother_event body executes.
    model = MagicMock(spec=UpdateRunModel)
    model._storage = MagicMock()
    mock_ensemble = MagicMock()
    ensemble_id = str(uuid.uuid4())
    mock_ensemble.id = uuid.UUID(ensemble_id)
    model._storage.get_ensemble.return_value = mock_ensemble

    data_section = DataSection(
        header=["observation_key", "status"],
        data=[("OBS_1", "Active"), ("OBS_2", "Deactivated, outlier")],
    )
    event = AnalysisCompleteEvent(data=data_section, ensemble_id=ensemble_id)

    UpdateRunModel.send_smoother_event(
        model, iteration=0, run_id=uuid.uuid4(), event=event
    )

    model._storage.get_ensemble.assert_called_once_with(ensemble_id)
    mock_ensemble.save_blob_data.assert_called_once()

    saved_blob_data, saved_df = mock_ensemble.save_blob_data.call_args[0]
    assert isinstance(saved_blob_data, BlobStorageData)
    assert saved_blob_data.blob_type == "observation_report"
    assert saved_blob_data.ensemble_id == ensemble_id
    # File name carries a random report id, so only check the fixed affixes.
    assert saved_blob_data.uri.startswith("observation_report_")
    assert saved_blob_data.uri.endswith(".parquet")
    assert isinstance(saved_df, pl.DataFrame)
    assert saved_df.columns == ["observation_key", "status"]
    assert len(saved_df) == 2
89 changes: 76 additions & 13 deletions tests/ert/unit_tests/storage/test_local_ensemble.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import json
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
Expand All @@ -15,6 +16,7 @@
from ert.config.response_config import InvalidResponseFile
from ert.exceptions import StorageError
from ert.storage import open_storage
from ert.storage.blob_data import BlobStorageData
from ert.storage.local_ensemble import _write_responses_to_storage
from ert.storage.mode import ModeError

Expand Down Expand Up @@ -683,41 +685,102 @@ async def run_test():
asyncio.run(run_test())


def test_that_save_transition_data_writes_file_to_disk(tmp_path):
def test_that_save_blob_data_writes_file_to_disk(tmp_path):
    """save_blob_data must write the DataFrame as parquet under blobs/."""
    with open_storage(tmp_path, mode="w") as storage:
        experiment = storage.create_experiment()
        ensemble = storage.create_ensemble(
            experiment, ensemble_size=1, iteration=0, name="prior"
        )

        df = pl.DataFrame({"key": ["value"]})
        blob_data = BlobStorageData(
            blob_type="test",
            uri="report.parquet",
            file_size=0,
            ensemble_id=str(ensemble.id),
        )
        ensemble.save_blob_data(blob_data, df)

        parquet_path = ensemble._path / "blobs" / "report.parquet"
        assert parquet_path.exists()
        # Round-trip the payload to confirm it is valid parquet, not just a file.
        loaded = pl.read_parquet(parquet_path)
        assert loaded["key"][0] == "value"


def test_that_save_transition_data_creates_transition_directory(tmp_path):
def test_that_save_blob_data_creates_blobs_directory(tmp_path):
    """The blobs/ directory must be created lazily on first save."""
    with open_storage(tmp_path, mode="w") as storage:
        experiment = storage.create_experiment()
        ensemble = storage.create_ensemble(
            experiment, ensemble_size=1, iteration=0, name="prior"
        )

        blob_dir = ensemble._path / "blobs"
        # Precondition: a fresh ensemble has no blob directory yet.
        assert not blob_dir.exists()

        df = pl.DataFrame({"x": [1]})
        blob_data = BlobStorageData(
            blob_type="test",
            uri="data.parquet",
            file_size=0,
            ensemble_id=str(ensemble.id),
        )
        ensemble.save_blob_data(blob_data, df)
        assert blob_dir.is_dir()


def test_that_save_transition_data_raises_in_read_mode(tmp_path):
def test_that_save_blob_data_raises_in_read_mode(tmp_path):
    """@require_write must reject save_blob_data on read-only storage."""
    with open_storage(tmp_path, mode="w") as storage:
        experiment = storage.create_experiment()
        storage.create_ensemble(experiment, ensemble_size=1, iteration=0, name="prior")

    # Reopen read-only: writes must now be refused.
    with open_storage(tmp_path, mode="r") as storage:
        ensemble = next(iter(storage.ensembles))
        df = pl.DataFrame({"x": [1]})
        blob_data = BlobStorageData(
            blob_type="test",
            uri="data.parquet",
            file_size=0,
            ensemble_id="fake",
        )
        with pytest.raises(ModeError):
            ensemble.save_blob_data(blob_data, df)


def test_that_save_blob_data_writes_parquet_and_json_to_disk(tmp_path):
    """Saving a blob writes both the parquet payload and its JSON metadata sidecar."""
    with open_storage(tmp_path, mode="w") as storage:
        experiment = storage.create_experiment()
        ensemble = storage.create_ensemble(
            experiment, ensemble_size=1, iteration=0, name="prior"
        )

        report = pl.DataFrame(
            {
                "observation_key": ["OBS_1", "OBS_2"],
                "status": ["Active", "Deactivated, outlier"],
                "value": [1.5, 2.0],
            }
        )
        descriptor = BlobStorageData(
            blob_type="observation_report",
            uri="observation_report.parquet",
            file_size=0,
            ensemble_id=str(ensemble.id),
        )

        ensemble.save_blob_data(descriptor, report)

        blob_dir = ensemble._path / "blobs"
        payload_path = blob_dir / "observation_report.parquet"
        sidecar_path = blob_dir / "observation_report.parquet.json"
        assert payload_path.exists()
        assert sidecar_path.exists()

        # Payload survives a parquet round trip with schema and rows intact.
        stored = pl.read_parquet(payload_path)
        assert stored.columns == ["observation_key", "status", "value"]
        assert len(stored) == 2
        assert stored["observation_key"][0] == "OBS_1"

        # Sidecar holds the serialized BlobStorageData descriptor.
        metadata = json.loads(sidecar_path.read_text(encoding="utf-8"))
        assert metadata["blob_type"] == "observation_report"
        assert metadata["uri"] == "observation_report.parquet"
        assert metadata["ensemble_id"] == str(ensemble.id)
Loading