Skip to content

Commit e6775a1

Browse files
authored
Add more bindings and tests (#444)
* Add more bindings and tests * Copilot fixes
1 parent 6ce9b5c commit e6775a1

6 files changed

Lines changed: 312 additions & 6 deletions

File tree

.github/workflows/pytest.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,12 @@ jobs:
2626
sudo apt update
2727
sudo apt install -y cmake build-essential doxygen libtbb-dev
2828
python -m pip install --upgrade pip
29-
pip install pytest pytest-cov
29+
pip install pytest pytest-cov pandas polars
3030
3131
- name: Install package
3232
run: |
3333
pip install -e .
3434
3535
- name: Run tests with pytest
3636
run: |
37-
pytest test/*py -v --cov=src/dsf --cov-report=term-missing
37+
pytest test -v --cov=src/dsf --cov-report=term-missing -o "python_files=Test_*.py test_*.py Bind_*.py"

src/dsf/bindings.cpp

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,45 @@ PYBIND11_MODULE(dsf_cpp, m) {
7878

7979
m.def("get_log_level", &spdlog::get_level, "Get the current global log level");
8080

81+
// Bind Street class to mobility submodule
82+
pybind11::class_<dsf::mobility::Street>(mobility, "Street")
83+
.def(
84+
"id", &dsf::mobility::Street::id, dsf::g_docstrings.at("dsf::Edge::id").c_str())
85+
.def("source",
86+
&dsf::mobility::Street::source,
87+
dsf::g_docstrings.at("dsf::Edge::source").c_str())
88+
.def("target",
89+
&dsf::mobility::Street::target,
90+
dsf::g_docstrings.at("dsf::Edge::target").c_str())
91+
.def("geometry",
92+
&dsf::mobility::Street::geometry,
93+
dsf::g_docstrings.at("dsf::Edge::geometry").c_str())
94+
.def("name",
95+
&dsf::mobility::Street::name,
96+
dsf::g_docstrings.at("dsf::mobility::Road::name").c_str())
97+
.def("length",
98+
&dsf::mobility::Street::length,
99+
dsf::g_docstrings.at("dsf::mobility::Road::length").c_str())
100+
.def("maxSpeed",
101+
&dsf::mobility::Street::maxSpeed,
102+
dsf::g_docstrings.at("dsf::mobility::Road::maxSpeed").c_str());
103+
104+
// Bind RoadJunction class to mobility submodule
105+
pybind11::class_<dsf::mobility::RoadJunction>(mobility, "RoadJunction")
106+
.def("id",
107+
&dsf::mobility::RoadJunction::id,
108+
dsf::g_docstrings.at("dsf::Node::id").c_str())
109+
.def("geometry",
110+
&dsf::mobility::RoadJunction::geometry,
111+
dsf::g_docstrings.at("dsf::Node::geometry").c_str())
112+
.def("capacity",
113+
&dsf::mobility::RoadJunction::capacity,
114+
dsf::g_docstrings.at("dsf::mobility::RoadJunction::capacity").c_str())
115+
.def(
116+
"transportCapacity",
117+
&dsf::mobility::RoadJunction::transportCapacity,
118+
dsf::g_docstrings.at("dsf::mobility::RoadJunction::transportCapacity").c_str());
119+
81120
// Bind Measurement to main module (can be used across different contexts)
82121
pybind11::class_<dsf::Measurement<double>>(m, "Measurement")
83122
.def(pybind11::init<double, double, std::size_t>(),
@@ -120,6 +159,21 @@ PYBIND11_MODULE(dsf_cpp, m) {
120159
.def("nTrafficLights",
121160
&dsf::mobility::RoadNetwork::nTrafficLights,
122161
dsf::g_docstrings.at("dsf::mobility::RoadNetwork::nTrafficLights").c_str())
162+
// Bind node and edge Network accessors, which return a ref or a cost ref
163+
// node should return a RoadJunction and edge should return a Street
164+
.def(
165+
"node",
166+
static_cast<dsf::mobility::RoadJunction& (
167+
dsf::mobility::RoadNetwork::*)(dsf::Id)>(&dsf::mobility::RoadNetwork::node),
168+
pybind11::arg("nodeId"),
169+
pybind11::return_value_policy::reference_internal,
170+
dsf::g_docstrings.at("dsf::Network::node").c_str())
171+
.def("edge",
172+
static_cast<dsf::mobility::Street& (dsf::mobility::RoadNetwork::*)(dsf::Id)>(
173+
&dsf::mobility::RoadNetwork::edge),
174+
pybind11::arg("edgeId"),
175+
pybind11::return_value_policy::reference_internal,
176+
dsf::g_docstrings.at("dsf::Network::edge").c_str())
123177
.def("capacity",
124178
&dsf::mobility::RoadNetwork::capacity,
125179
dsf::g_docstrings.at("dsf::mobility::RoadNetwork::capacity").c_str())
@@ -470,7 +524,6 @@ PYBIND11_MODULE(dsf_cpp, m) {
470524
self.setSpeedFunction(
471525
dsf::SpeedFunction::CUSTOM,
472526
[func_ptr](dsf::mobility::Street const& street) -> double {
473-
// No GIL needed — this is pure C
474527
return func_ptr(street.maxSpeed(), street.density(true));
475528
});
476529
break;
@@ -499,9 +552,9 @@ PYBIND11_MODULE(dsf_cpp, m) {
499552
.def(
500553
"setInitTime",
501554
[](dsf::mobility::FirstOrderDynamics& self, pybind11::object datetime_obj) {
502-
auto const epoch =
503-
pybind11::cast<std::time_t>(datetime_obj.attr("timestamp")());
504-
self.setInitTime(epoch);
555+
auto const epoch_seconds =
556+
pybind11::cast<double>(datetime_obj.attr("timestamp")());
557+
self.setInitTime(static_cast<std::time_t>(epoch_seconds));
505558
},
506559
pybind11::arg("datetime"),
507560
dsf::g_docstrings.at("dsf::Dynamics::setInitTime").c_str())
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import dsf
2+
import dsf_cpp
3+
import pytest
4+
5+
6+
def test_root_module_reexports_cpp_bindings():
7+
assert dsf.__version__ == dsf_cpp.__version__
8+
assert dsf.mobility is dsf_cpp.mobility
9+
assert dsf.mdt is dsf_cpp.mdt
10+
assert hasattr(dsf_cpp, "Measurement")
11+
12+
13+
def test_log_level_round_trip():
14+
original_level = dsf.get_log_level()
15+
try:
16+
dsf.set_log_level(dsf.LogLevel.DEBUG)
17+
assert dsf.get_log_level() == dsf.LogLevel.DEBUG
18+
19+
dsf.set_log_level(dsf.LogLevel.INFO)
20+
assert dsf.get_log_level() == dsf.LogLevel.INFO
21+
finally:
22+
dsf.set_log_level(original_level)
23+
24+
25+
def test_log_to_file_accepts_string_path(tmp_path):
26+
dsf.log_to_file(str(tmp_path / "dsf_bindings.log"))
27+
28+
29+
def test_measurement_properties_are_mutable():
30+
measurement = dsf_cpp.Measurement(12.5, 3.0, 10)
31+
32+
assert measurement.mean == pytest.approx(12.5)
33+
assert measurement.std == pytest.approx(3.0)
34+
assert measurement.n == 10
35+
assert isinstance(measurement.is_valid, bool)
36+
37+
measurement.mean = 7.25
38+
measurement.std = 1.5
39+
measurement.n = 4
40+
measurement.is_valid = False
41+
42+
assert measurement.mean == pytest.approx(7.25)
43+
assert measurement.std == pytest.approx(1.5)
44+
assert measurement.n == 4
45+
assert measurement.is_valid is False

test/bindings/Test_bindings_mdt.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import numpy as np
2+
import pytest
3+
4+
from dsf import mdt
5+
6+
EXPECTED_COLUMNS = [
7+
"uid",
8+
"trajectory_id",
9+
"lon",
10+
"lat",
11+
"timestamp_in",
12+
"timestamp_out",
13+
]
14+
15+
16+
class InvalidShapeFrame:
17+
columns = ["uid", "timestamp", "lat", "lon"]
18+
19+
def to_numpy(self):
20+
return np.array([1.0, 2.0, 3.0], dtype=np.float64)
21+
22+
23+
def test_trajectory_collection_from_pandas_to_pandas(trajectory_pandas_df):
24+
collection = mdt.TrajectoryCollection(trajectory_pandas_df)
25+
collection.filter(0.5, 150.0, 0)
26+
27+
result = collection.to_pandas()
28+
29+
assert list(result.columns) == EXPECTED_COLUMNS
30+
assert len(result) > 0
31+
assert {1, 2}.issubset(set(result["uid"].astype(int).tolist()))
32+
33+
34+
def test_trajectory_collection_from_polars_to_polars(trajectory_polars_df):
35+
polars = pytest.importorskip("polars")
36+
37+
collection = mdt.TrajectoryCollection(trajectory_polars_df)
38+
collection.filter(0.5, 150.0, 0)
39+
40+
result = collection.to_polars()
41+
42+
assert isinstance(result, polars.DataFrame)
43+
assert result.columns == EXPECTED_COLUMNS
44+
assert result.height > 0
45+
assert {1, 2}.issubset(set(result["uid"].cast(polars.Int64).to_list()))
46+
47+
48+
def test_trajectory_collection_to_csv_writes_expected_header(
49+
trajectory_pandas_df, tmp_path
50+
):
51+
collection = mdt.TrajectoryCollection(trajectory_pandas_df)
52+
collection.filter(0.5, 150.0, 0)
53+
54+
output_file = tmp_path / "trajectory_export.csv"
55+
collection.to_csv(str(output_file))
56+
57+
assert output_file.exists()
58+
59+
header = output_file.read_text(encoding="utf-8").splitlines()[0]
60+
assert header == "uid;trajectory_id;lon;lat;timestamp_in;timestamp_out"
61+
62+
63+
def test_trajectory_collection_rejects_non_2d_numpy_input():
64+
with pytest.raises(RuntimeError, match="2D numpy array"):
65+
mdt.TrajectoryCollection(InvalidShapeFrame())
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
from datetime import datetime, timezone
2+
3+
import numpy as np
4+
import pytest
5+
6+
from dsf import mobility
7+
8+
9+
def test_road_network_import_counts(loaded_road_network):
10+
assert loaded_road_network.nNodes() == 120
11+
assert loaded_road_network.nEdges() == 436
12+
13+
14+
def test_shortest_path_returns_path_collection(loaded_road_network):
15+
path_map = loaded_road_network.shortestPath(0, 119, mobility.PathWeight.LENGTH)
16+
17+
assert isinstance(path_map, mobility.PathCollection)
18+
assert len(path_map) > 0
19+
assert 0 in path_map
20+
assert 119 not in path_map
21+
assert isinstance(path_map[0], list)
22+
assert len(path_map[0]) > 0
23+
24+
25+
def test_path_collection_dict_protocol_and_explode():
26+
path_collection = mobility.PathCollection()
27+
28+
assert len(path_collection) == 0
29+
30+
path_collection[0] = [1]
31+
path_collection[1] = [2]
32+
33+
assert 0 in path_collection
34+
assert path_collection[0] == [1]
35+
assert set(path_collection.keys()) == {0, 1}
36+
assert path_collection.items() == {0: [1], 1: [2]}
37+
38+
all_paths = path_collection.explode(0, 2)
39+
assert all_paths == [[0, 1, 2]]
40+
41+
with pytest.raises(KeyError):
42+
_ = path_collection[99]
43+
44+
45+
def test_dynamics_set_init_time_accepts_epoch_and_datetime(dynamics):
46+
epoch = 1_700_000_000
47+
dynamics.setInitTime(epoch)
48+
assert dynamics.time() == epoch
49+
50+
dt = datetime.fromtimestamp(epoch + 3600, tz=timezone.utc)
51+
dynamics.setInitTime(dt)
52+
assert dynamics.time() == int(dt.timestamp())
53+
54+
55+
def test_dynamics_accepts_origin_destination_overloads(dynamics):
56+
node_array = np.array([0, 1, 2], dtype=np.uint64)
57+
58+
dynamics.setDestinationNodes([0, 1, 2])
59+
dynamics.setDestinationNodes({0: 0.6, 1: 0.4})
60+
dynamics.setDestinationNodes(node_array)
61+
62+
dynamics.setOriginNodes({0: 1.0, 1: 1.0, 2: 1.0})
63+
dynamics.setOriginNodes(node_array)
64+
65+
66+
def test_dynamics_smoke_step_with_linear_speed(dynamics):
67+
dynamics.setWeightFunction(mobility.PathWeight.LENGTH)
68+
dynamics.setSpeedFunction(mobility.SpeedFunction.LINEAR, 0.8)
69+
dynamics.setDestinationNodes([0, 1, 2])
70+
dynamics.setOriginNodes({3: 1.0, 4: 1.0, 5: 1.0})
71+
72+
dynamics.updatePaths()
73+
dynamics.addAgentsUniformly(1)
74+
75+
assert dynamics.nAgents() == 1
76+
77+
previous_step = dynamics.time_step()
78+
dynamics.evolve(False)
79+
assert dynamics.time_step() == previous_step + 1
80+
81+
82+
def test_compute_betweenness_rejects_invalid_weight(loaded_road_network):
83+
with pytest.raises(Exception, match="Invalid weight function"):
84+
loaded_road_network.computeBetweennessCentralities("invalid")

test/bindings/conftest.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from pathlib import Path
2+
3+
import pytest
4+
5+
from dsf import mobility
6+
7+
TEST_DATA_DIR = Path(__file__).resolve().parents[1] / "data"
8+
MANHATTAN_EDGES = TEST_DATA_DIR / "manhattan_edges.csv"
9+
MANHATTAN_NODES = TEST_DATA_DIR / "manhattan_nodes.csv"
10+
11+
12+
@pytest.fixture(scope="session")
13+
def manhattan_edges_path() -> Path:
14+
return MANHATTAN_EDGES
15+
16+
17+
@pytest.fixture(scope="session")
18+
def manhattan_nodes_path() -> Path:
19+
return MANHATTAN_NODES
20+
21+
22+
@pytest.fixture()
23+
def empty_road_network():
24+
return mobility.RoadNetwork()
25+
26+
27+
@pytest.fixture()
28+
def loaded_road_network(manhattan_edges_path: Path, manhattan_nodes_path: Path):
29+
network = mobility.RoadNetwork()
30+
network.importEdges(str(manhattan_edges_path), ";")
31+
network.importNodeProperties(str(manhattan_nodes_path), ";")
32+
return network
33+
34+
35+
@pytest.fixture()
36+
def dynamics(loaded_road_network):
37+
return mobility.Dynamics(loaded_road_network, seed=69)
38+
39+
40+
@pytest.fixture()
41+
def trajectory_rows():
42+
return {
43+
"uid": [1, 1, 1, 2, 2, 2],
44+
"timestamp": [1000, 1600, 2200, 1000, 1600, 2200],
45+
"lat": [44.4949, 44.4950, 44.4951, 45.4064, 45.4065, 45.4066],
46+
"lon": [11.3426, 11.3427, 11.3428, 11.8767, 11.8768, 11.8769],
47+
}
48+
49+
50+
@pytest.fixture()
51+
def trajectory_pandas_df(trajectory_rows):
52+
pandas = pytest.importorskip("pandas")
53+
return pandas.DataFrame(trajectory_rows)
54+
55+
56+
@pytest.fixture()
57+
def trajectory_polars_df(trajectory_rows):
58+
polars = pytest.importorskip("polars")
59+
return polars.DataFrame(trajectory_rows)

0 commit comments

Comments
 (0)