Skip to content

Commit 6a5a445

Browse files
authored
Merge pull request #442 from physycom/refactor
Refactor SQL saving
2 parents 5a9a0d0 + 714e021 commit 6a5a445

2 files changed

Lines changed: 224 additions & 112 deletions

File tree

src/dsf/mobility/FirstOrderDynamics.cpp

Lines changed: 162 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -850,6 +850,7 @@ namespace dsf::mobility {
850850
spdlog::debug("There are {} agents left in the list.", m_agents.size());
851851
}
852852

853+
// Init Street Data methods
853854
void FirstOrderDynamics::m_initStreetTable() const {
854855
if (!this->database()) {
855856
throw std::runtime_error(
@@ -873,6 +874,53 @@ namespace dsf::mobility {
873874

874875
spdlog::info("Initialized road_data table in the database.");
875876
}
877+
void FirstOrderDynamics::m_saveStreetDataSQL(
878+
const std::string& datetime,
879+
const std::int64_t time_step,
880+
const std::int64_t simulation_id,
881+
tbb::concurrent_vector<StreetDataRecord> streetDataRecords) const {
882+
if (streetDataRecords.empty()) {
883+
spdlog::debug("No street data records to save for time step {}.", time_step);
884+
return;
885+
}
886+
SQLite::Statement insertStmt(
887+
*this->database(),
888+
"INSERT INTO road_data (datetime, time_step, simulation_id, street_id, "
889+
"coil, density_vpk, avg_speed_kph, std_speed_kph, n_observations, counts, "
890+
"queue_length) "
891+
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)");
892+
893+
for (auto const& record : streetDataRecords) {
894+
insertStmt.bind(1, datetime);
895+
insertStmt.bind(2, time_step);
896+
insertStmt.bind(3, simulation_id);
897+
insertStmt.bind(4, static_cast<std::int64_t>(record.streetId));
898+
if (record.coilName.has_value()) {
899+
insertStmt.bind(5, record.coilName.value());
900+
} else {
901+
insertStmt.bind(5);
902+
}
903+
insertStmt.bind(6, record.density);
904+
if (record.avgSpeed.has_value()) {
905+
insertStmt.bind(7, record.avgSpeed.value());
906+
insertStmt.bind(8, record.stdSpeed.value());
907+
} else {
908+
insertStmt.bind(7);
909+
insertStmt.bind(8);
910+
}
911+
insertStmt.bind(9, static_cast<std::int64_t>(record.nObservations.value_or(0)));
912+
if (record.counts.has_value()) {
913+
insertStmt.bind(10, static_cast<std::int64_t>(record.counts.value()));
914+
} else {
915+
insertStmt.bind(10);
916+
}
917+
insertStmt.bind(11, static_cast<std::int64_t>(record.queueLength));
918+
insertStmt.exec();
919+
insertStmt.reset();
920+
}
921+
}
922+
// End Street Data methods
923+
// Init Avg Stats methods
876924
void FirstOrderDynamics::m_initAvgStatsTable() const {
877925
if (!this->database()) {
878926
throw std::runtime_error(
@@ -896,6 +944,45 @@ namespace dsf::mobility {
896944

897945
spdlog::info("Initialized avg_stats table in the database.");
898946
}
947+
void FirstOrderDynamics::m_saveAvgStatsSQL(const std::string& datetime,
948+
const std::int64_t time_step,
949+
const std::int64_t simulation_id,
950+
const std::size_t n_valid_edges,
951+
const double mean_speed,
952+
const double std_speed,
953+
const double mean_density,
954+
const double std_density,
955+
const double mean_traveltime,
956+
const double meanQueueLength) const {
957+
SQLite::Statement insertStmt(
958+
*this->database(),
959+
"INSERT INTO avg_stats ("
960+
"simulation_id, datetime, time_step, n_ghost_agents, n_agents, "
961+
"mean_speed_kph, std_speed_kph, mean_density_vpk, std_density_vpk, "
962+
"mean_travel_time_s, mean_queue_length) "
963+
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)");
964+
insertStmt.bind(1, simulation_id);
965+
insertStmt.bind(2, datetime);
966+
insertStmt.bind(3, time_step);
967+
insertStmt.bind(4, static_cast<std::int64_t>(m_agents.size()));
968+
insertStmt.bind(5, static_cast<std::int64_t>(this->nAgents()));
969+
970+
if (n_valid_edges > 0) {
971+
insertStmt.bind(6, mean_speed);
972+
insertStmt.bind(7, std_speed);
973+
insertStmt.bind(10, mean_traveltime);
974+
} else {
975+
insertStmt.bind(6);
976+
insertStmt.bind(7);
977+
insertStmt.bind(10);
978+
}
979+
insertStmt.bind(8, mean_density);
980+
insertStmt.bind(9, std_density);
981+
insertStmt.bind(11, meanQueueLength);
982+
insertStmt.exec();
983+
}
984+
// End Avg Stats methods
985+
// Init Travel Data methods
899986
void FirstOrderDynamics::m_initTravelDataTable() const {
900987
if (!this->database()) {
901988
throw std::runtime_error(
@@ -913,6 +1000,28 @@ namespace dsf::mobility {
9131000

9141001
spdlog::info("Initialized travel_data table in the database.");
9151002
}
1003+
void FirstOrderDynamics::m_saveTravelDataSQL(
1004+
const std::string& datetime,
1005+
const std::int64_t time_step,
1006+
const std::int64_t simulation_id,
1007+
tbb::concurrent_vector<std::pair<double, double>> travelDTs) const {
1008+
SQLite::Statement insertStmt(*this->database(),
1009+
"INSERT INTO travel_data (datetime, time_step, "
1010+
"simulation_id, distance_m, travel_time_s) "
1011+
"VALUES (?, ?, ?, ?, ?)");
1012+
1013+
for (auto const& [distance, time] : travelDTs) {
1014+
insertStmt.bind(1, datetime);
1015+
insertStmt.bind(2, time_step);
1016+
insertStmt.bind(3, simulation_id);
1017+
insertStmt.bind(4, distance);
1018+
insertStmt.bind(5, time);
1019+
insertStmt.exec();
1020+
insertStmt.reset();
1021+
}
1022+
}
1023+
// End Travel Data methods
1024+
// Init Agent Data methods
9161025
void FirstOrderDynamics::m_initAgentDataTable() const {
9171026
if (!this->database()) {
9181027
throw std::runtime_error(
@@ -930,6 +1039,33 @@ namespace dsf::mobility {
9301039

9311040
spdlog::info("Initialized agent_data table in the database.");
9321041
}
1042+
void FirstOrderDynamics::m_saveAgentDataSQL(
1043+
const std::int64_t time_step,
1044+
const std::int64_t simulation_id,
1045+
tbb::concurrent_unordered_map<Id,
1046+
std::vector<std::tuple<Id, std::time_t, std::time_t>>>
1047+
agentDataRecords) const {
1048+
if (agentDataRecords.empty()) {
1049+
spdlog::debug("No agent data records to save for time step {}.", time_step);
1050+
return;
1051+
}
1052+
SQLite::Statement insertStmt(*this->database(),
1053+
"INSERT INTO agent_data (simulation_id, "
1054+
"agent_id, edge_id, time_step_in, time_step_out)"
1055+
"VALUES (?, ?, ?, ?, ?)");
1056+
for (auto const& [edge_id, data] : agentDataRecords) {
1057+
for (auto const& [agent_id, ts_in, ts_out] : data) {
1058+
insertStmt.bind(1, simulation_id);
1059+
insertStmt.bind(2, static_cast<std::int64_t>(agent_id));
1060+
insertStmt.bind(3, static_cast<std::int64_t>(edge_id));
1061+
insertStmt.bind(4, static_cast<std::int64_t>(ts_in));
1062+
insertStmt.bind(5, static_cast<std::int64_t>(ts_out));
1063+
insertStmt.exec();
1064+
insertStmt.reset();
1065+
}
1066+
}
1067+
}
1068+
// End Agent Data methods
9331069
void FirstOrderDynamics::m_dumpSimInfo() const {
9341070
// Dump simulation info (parameters) to the database, if connected
9351071
if (!this->database()) {
@@ -1471,18 +1607,6 @@ namespace dsf::mobility {
14711607
m_savingInterval.has_value() &&
14721608
(m_savingInterval.value() == 0 ||
14731609
this->time_step() % m_savingInterval.value() == 0);
1474-
1475-
// Struct to collect street data for batch insert after parallel section
1476-
struct StreetDataRecord {
1477-
Id streetId;
1478-
std::optional<std::string> coilName;
1479-
double density;
1480-
std::optional<double> avgSpeed;
1481-
std::optional<double> stdSpeed;
1482-
std::optional<std::size_t> nObservations;
1483-
std::optional<std::size_t> counts;
1484-
std::size_t queueLength;
1485-
};
14861610
tbb::concurrent_vector<StreetDataRecord> streetDataRecords;
14871611

14881612
spdlog::debug("Init evolve at time {}", this->time_step());
@@ -1594,120 +1718,46 @@ namespace dsf::mobility {
15941718
}
15951719

15961720
// Batch insert street data collected during parallel section
1597-
if (m_bSaveStreetData && !streetDataRecords.empty()) {
1598-
SQLite::Statement insertStmt(
1599-
*this->database(),
1600-
"INSERT INTO road_data (datetime, time_step, simulation_id, street_id, "
1601-
"coil, density_vpk, avg_speed_kph, std_speed_kph, n_observations, counts, "
1602-
"queue_length) "
1603-
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)");
1604-
1605-
for (auto const& record : streetDataRecords) {
1606-
insertStmt.bind(1, datetime);
1607-
insertStmt.bind(2, step);
1608-
insertStmt.bind(3, simulationId);
1609-
insertStmt.bind(4, static_cast<std::int64_t>(record.streetId));
1610-
if (record.coilName.has_value()) {
1611-
insertStmt.bind(5, record.coilName.value());
1612-
} else {
1613-
insertStmt.bind(5);
1614-
}
1615-
insertStmt.bind(6, record.density);
1616-
if (record.avgSpeed.has_value()) {
1617-
insertStmt.bind(7, record.avgSpeed.value());
1618-
insertStmt.bind(8, record.stdSpeed.value());
1619-
} else {
1620-
insertStmt.bind(7);
1621-
insertStmt.bind(8);
1622-
}
1623-
insertStmt.bind(9, static_cast<std::int64_t>(record.nObservations.value_or(0)));
1624-
if (record.counts.has_value()) {
1625-
insertStmt.bind(10, static_cast<std::int64_t>(record.counts.value()));
1626-
} else {
1627-
insertStmt.bind(10);
1628-
}
1629-
insertStmt.bind(11, static_cast<std::int64_t>(record.queueLength));
1630-
insertStmt.exec();
1631-
insertStmt.reset();
1632-
}
1721+
if (m_bSaveStreetData) {
1722+
this->m_saveStreetDataSQL(
1723+
datetime, step, simulationId, std::move(streetDataRecords));
16331724
}
16341725

16351726
if (m_bSaveTravelData && !m_travelDTs.empty()) {
1636-
SQLite::Statement insertStmt(*this->database(),
1637-
"INSERT INTO travel_data (datetime, time_step, "
1638-
"simulation_id, distance_m, travel_time_s) "
1639-
"VALUES (?, ?, ?, ?, ?)");
1640-
1641-
for (auto const& [distance, time] : m_travelDTs) {
1642-
insertStmt.bind(1, datetime);
1643-
insertStmt.bind(2, step);
1644-
insertStmt.bind(3, simulationId);
1645-
insertStmt.bind(4, distance);
1646-
insertStmt.bind(5, time);
1647-
insertStmt.exec();
1648-
insertStmt.reset();
1649-
}
1727+
this->m_saveTravelDataSQL(datetime, step, simulationId, std::move(m_travelDTs));
16501728
m_travelDTs.clear();
16511729
}
16521730

16531731
if (m_bSaveAgentData) {
1654-
auto agentData = Street::agentData();
1655-
SQLite::Statement insertStmt(*this->database(),
1656-
"INSERT INTO agent_data (simulation_id, "
1657-
"agent_id, edge_id, time_step_in, time_step_out)"
1658-
"VALUES (?, ?, ?, ?, ?)");
1659-
for (auto const& [edge_id, data] : agentData) {
1660-
for (auto const& [agent_id, ts_in, ts_out] : data) {
1661-
insertStmt.bind(1, simulationId);
1662-
insertStmt.bind(2, static_cast<std::int64_t>(agent_id));
1663-
insertStmt.bind(3, static_cast<std::int64_t>(edge_id));
1664-
insertStmt.bind(4, static_cast<std::int64_t>(ts_in));
1665-
insertStmt.bind(5, static_cast<std::int64_t>(ts_out));
1666-
insertStmt.exec();
1667-
insertStmt.reset();
1668-
}
1669-
}
1732+
this->m_saveAgentDataSQL(step, simulationId, Street::agentData());
16701733
}
16711734

16721735
if (m_bSaveAverageStats) { // Average Stats Table
1736+
double meanSpeed{0.}, stdSpeed{0.}, meanDensity{0.}, stdDensity{0.},
1737+
meanTravelTime{0.}, meanQueueLength{0.};
16731738
auto const validEdges = nValidEdges.load();
16741739
auto const edgeCount = static_cast<double>(numEdges);
1675-
auto const meanDensity = mean_density.load() / edgeCount;
1676-
auto const meanQueueLength = mean_queue_length.load() / edgeCount;
1677-
auto const densityVariance =
1678-
std::max(0.0, std_density.load() / edgeCount - meanDensity * meanDensity);
1679-
1680-
SQLite::Statement insertStmt(
1681-
*this->database(),
1682-
"INSERT INTO avg_stats ("
1683-
"simulation_id, datetime, time_step, n_ghost_agents, n_agents, "
1684-
"mean_speed_kph, std_speed_kph, mean_density_vpk, std_density_vpk, "
1685-
"mean_travel_time_s, mean_queue_length) "
1686-
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)");
1687-
insertStmt.bind(1, simulationId);
1688-
insertStmt.bind(2, datetime);
1689-
insertStmt.bind(3, step);
1690-
insertStmt.bind(4, static_cast<std::int64_t>(m_agents.size()));
1691-
insertStmt.bind(5, static_cast<std::int64_t>(this->nAgents()));
1692-
16931740
if (validEdges > 0) {
1694-
auto const validEdgeCount = static_cast<double>(validEdges);
1695-
auto const meanSpeed = mean_speed.load() / validEdgeCount;
1696-
auto const meanTravelTime = mean_traveltime.load() / validEdgeCount;
1697-
auto const speedVariance =
1698-
std::max(0.0, std_speed.load() / validEdgeCount - meanSpeed * meanSpeed);
1699-
insertStmt.bind(6, meanSpeed);
1700-
insertStmt.bind(7, std::sqrt(speedVariance));
1701-
insertStmt.bind(10, meanTravelTime);
1702-
} else {
1703-
insertStmt.bind(6);
1704-
insertStmt.bind(7);
1705-
insertStmt.bind(10);
1741+
meanSpeed = mean_speed.load() / validEdges;
1742+
stdSpeed = std::sqrt(
1743+
std::max(0.0, std_speed.load() / validEdges - meanSpeed * meanSpeed));
1744+
meanDensity = mean_density.load() / edgeCount;
1745+
stdDensity = std::sqrt(
1746+
std::max(0.0, std_density.load() / edgeCount - meanDensity * meanDensity));
1747+
meanTravelTime = mean_traveltime.load() / validEdges;
1748+
meanQueueLength = mean_queue_length.load() / edgeCount;
17061749
}
1707-
insertStmt.bind(8, meanDensity);
1708-
insertStmt.bind(9, std::sqrt(densityVariance));
1709-
insertStmt.bind(11, meanQueueLength);
1710-
insertStmt.exec();
1750+
1751+
this->m_saveAvgStatsSQL(datetime,
1752+
step,
1753+
simulationId,
1754+
validEdges,
1755+
meanSpeed,
1756+
stdSpeed,
1757+
meanDensity,
1758+
stdDensity,
1759+
meanTravelTime,
1760+
meanQueueLength);
17111761
}
17121762

17131763
if (transaction.has_value()) {

0 commit comments

Comments
 (0)