from collections import OrderedDict, defaultdict
import copy
import time

import attr
import dask
from dask.distributed import Client

from .variable import VarIntent, VarType
from .process import (
    filter_variables,
    get_process_cls,
    get_target_variable,
    RuntimeSignal,
    SimulationStage,
)
from .utils import AttrMapping, Frozen, variables_dict
from .formatting import repr_model


def _flatten_keys(key_seq):
    """returns a flat list of keys, i.e., ``('foo', 'bar')`` tuples, from
    a nested sequence.

    """
    flat_keys = []

    for key in key_seq:
        if not isinstance(key, tuple):
            flat_keys += _flatten_keys(key)
        else:
            flat_keys.append(key)

    return flat_keys


def get_model_variables(p_mapping, **kwargs):
    """Get variables in the model (processes mapping) as a list of
    ``(process_name, var_name)`` tuples.

    **kwargs may be used to return only a subset of the variables.

    """
    var_keys = []

    for p_name, proc in p_mapping.items():
        var_keys += [
            (p_name, var_name) for var_name in filter_variables(proc, **kwargs)
        ]

    return var_keys


def get_reverse_lookup(processes_cls):
    """Return a dictionary with process classes as keys and process names
    as values.

    Additionally, the returned dictionary maps all parent classes
    to one (str) or several (list) process names.

    """
    reverse_lookup = defaultdict(list)

    for p_name, p_cls in processes_cls.items():
        # exclude `object` base class from lookup
        for cls in p_cls.mro()[:-1]:
            reverse_lookup[cls].append(p_name)

    return {k: v[0] if len(v) == 1 else v for k, v in reverse_lookup.items()}
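

# Illustrative sketch (not part of xarray-simlab's API): a standalone toy
# version of the reverse-lookup idea above, showing why a parent class shared
# by several processes maps to a *list* of names, and how a lookup like this
# can be used to reject ambiguous targets (as `_ModelBuilder._get_foreign_ref`
# does). Both helper names and the toy classes in the usage note are
# hypothetical, chosen only for this example.
def _sketch_reverse_lookup(classes_by_name):
    """Map each class in each MRO (except ``object``) to one name or a list."""
    lookup = {}
    for name, cls in classes_by_name.items():
        # cls.mro()[-1] is always `object`, which would match every process
        for base in cls.mro()[:-1]:
            lookup.setdefault(base, []).append(name)
    # unwrap single-element lists so unambiguous classes map to a plain str
    return {k: v[0] if len(v) == 1 else v for k, v in lookup.items()}


def _sketch_resolve_process_name(lookup, target_cls):
    """Return the unique process name for `target_cls`, or raise (hypothetical)."""
    name = lookup.get(target_cls)
    if name is None:
        raise KeyError(f"Process class '{target_cls.__name__}' missing in model")
    if isinstance(name, list):
        raise ValueError(
            f"Process class '{target_cls.__name__}' is used by multiple "
            f"processes: {name}"
        )
    return name


# Usage (hypothetical toy classes):
#
#     class Base: ...
#     class A(Base): ...
#     class B(Base): ...
#
#     lookup = _sketch_reverse_lookup({"a": A, "b": B})
#     lookup[A]                                   # 'a'
#     lookup[Base]                                # ['a', 'b']
#     _sketch_resolve_process_name(lookup, Base)  # raises ValueError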


def get_global_refs(processes_cls):
    """Return a dictionary with global names as keys and
    ('process_name', var) tuples (or lists of those tuples) as values.

    """
    temp_refs = defaultdict(list)

    for p_name, p_cls in processes_cls.items():
        for var in variables_dict(p_cls).values():
            global_name = var.metadata.get("global_name")

            if var.metadata["var_type"] != VarType.GLOBAL and global_name is not None:
                temp_refs[global_name].append((p_name, var))

    global_refs = {k: v if len(v) > 1 else v[0] for k, v in temp_refs.items()}

    return global_refs
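

# Illustrative sketch (not part of xarray-simlab's API): how process
# dependencies and model inputs can be derived from variable keys, following
# the rules documented in `_ModelBuilder.get_process_dependencies` and
# `_ModelBuilder.get_input_variables`. As an assumed simplification, each
# process is described by a hypothetical dict with an "in" list (keys read with
# intent 'in' or 'inout') and an "out" list (keys set with intent 'out'). Unlike
# the real builder, inputs are sorted rather than kept in declaration order,
# and group variables are not handled.
def _sketch_deps_and_inputs(processes):
    """Return ``(dependencies, inputs)`` for ``{name: {"in": [...], "out": [...]}}``."""
    # every key each process refers to, whatever the intent
    all_keys = {
        name: set(p.get("in", [])) | set(p.get("out", []))
        for name, p in processes.items()
    }

    # a process depends on another one if the latter sets (intent='out')
    # a key that the former also refers to
    deps = {name: set() for name in processes}
    for producer, p in processes.items():
        for key in p.get("out", []):
            for consumer in processes:
                if consumer != producer and key in all_keys[consumer]:
                    deps[consumer].add(producer)

    # model inputs: keys that are read somewhere but set by no process
    in_keys = {k for p in processes.values() for k in p.get("in", [])}
    out_keys = {k for p in processes.values() for k in p.get("out", [])}
    inputs = sorted(in_keys - out_keys)

    return {k: sorted(v) for k, v in deps.items()}, inputs


# Usage (hypothetical processes and variables):
#
#     deps, inputs = _sketch_deps_and_inputs({
#         "grid": {"in": [("grid", "length")], "out": [("grid", "x")]},
#         "init": {"in": [("grid", "x"), ("init", "loc")],
#                  "out": [("profile", "u")]},
#         "profile": {"in": [("profile", "u")]},
#     })
#     deps    # {'grid': [], 'init': ['grid'], 'profile': ['init']}
#     inputs  # [('grid', 'length'), ('init', 'loc')]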


class _ModelBuilder:
    """Used to iteratively build a new model.

    This builder implements the following tasks:

    - Attach the model instance to each process and assign their given
      name in model.
    - Create a "state", i.e., a mapping used to store active simulation data
    - Create a cache for fastpath access to the (meta)data of all variables
      defined in the model
    - Define for each variable of the model its corresponding key
      (in state or on-demand)
    - Find variables that are model inputs
    - Find process dependencies and sort processes (DAG)
    - Find the processes that implement the method relative to each
      step of a simulation

    """

    def __init__(self, processes_cls):
        self._processes_cls = processes_cls
        self._processes_obj = {k: cls() for k, cls in processes_cls.items()}

        self._reverse_lookup = get_reverse_lookup(processes_cls)

        self._all_vars = get_model_variables(processes_cls)
        self._global_vars = get_model_variables(processes_cls, var_type=VarType.GLOBAL)
        self._global_refs = get_global_refs(processes_cls)

        self._input_vars = None

        self._dep_processes = None
        self._sorted_processes = None

        # a cache for group keys
        self._group_keys = {}

    def bind_processes(self, model_obj):
        for p_name, p_obj in self._processes_obj.items():
            p_obj.__xsimlab_model__ = model_obj
            p_obj.__xsimlab_name__ = p_name

    def set_state(self):
        state = {}

        # bind state to each process in the model
        for p_obj in self._processes_obj.values():
            p_obj.__xsimlab_state__ = state

        return state

    def create_variable_cache(self):
        """Create a cache for fastpath access to the (meta)data of all
        variables defined in the model.

        """
        var_cache = {}

        for p_name, p_cls in self._processes_cls.items():
            for v_name, attrib in variables_dict(p_cls).items():
                var_cache[(p_name, v_name)] = {
                    "name": f"{p_name}__{v_name}",
                    "attrib": attrib,
                    "metadata": attrib.metadata.copy(),
                    "value": None,
                }

        # retrieve/update metadata for global variables
        for key in self._global_vars:
            metadata = var_cache[key]["metadata"]
            _, ref_var = self._get_global_ref(var_cache[key]["attrib"])
            ref_metadata = {
                k: v for k, v in ref_var.metadata.items() if k not in metadata
            }
            metadata.update(ref_metadata)

        return var_cache

    def _get_global_ref(self, var):
        """Return the reference to a global variable as a ('process_name', var) tuple
        (check that a reference exists and is unique).

        """
        global_name = var.metadata["global_name"]
        ref = self._global_refs.get(global_name)

        if ref is None:
            raise KeyError(
                f"No variable with global name '{global_name}' found in model"
            )
        elif isinstance(ref, list):
            # explicit concatenation: adjacent string literals followed by
            # `.join()` would use the whole message as the join separator
            raise ValueError(
                f"Found multiple variables with global name '{global_name}' in model: "
                + ", ".join([str(r) for r in ref])
            )

        return ref

    def _get_foreign_ref(self, p_name, var):
        """Return the reference to a foreign variable as a ('process_name', var) tuple
        (check that a reference exists and is unique).

        """
        target_p_cls, target_var = get_target_variable(var)
        target_p_name = self._reverse_lookup.get(target_p_cls, None)

        if target_p_name is None:
            raise KeyError(
                f"Process class '{target_p_cls.__name__}' "
                "missing in Model but required "
                f"by foreign variable '{var.name}' "
                f"declared in process '{p_name}'"
            )

        elif isinstance(target_p_name, list):
            raise ValueError(
                "Process class {!r} required by foreign variable '{}.{}' "
                "is used (possibly via one of its child classes) by multiple "
                "processes: {}".format(
                    target_p_cls.__name__,
                    p_name,
                    var.name,
                    ", ".join(["{!r}".format(n) for n in target_p_name]),
                )
            )

        # go through global reference
        if target_var.metadata["var_type"] == VarType.GLOBAL:
            target_p_name, target_var = self._get_global_ref(target_var)

        return target_p_name, target_var

    def _get_var_key(self, p_name, var):
        """Get state and/or on-demand keys for variable `var` declared in
        process `p_name`.

        Returned key(s) are either None (if no key), a tuple or a list
        of tuples (for group variables).

        A key tuple looks like ``('foo', 'bar')`` where 'foo' is the
        name of any process in the model and 'bar' is the name of a
        variable declared in that process.

        """
        state_key = None
        od_key = None

        var_type = var.metadata["var_type"]

        if var_type in (VarType.VARIABLE, VarType.INDEX, VarType.OBJECT):
            state_key = (p_name, var.name)

        elif var_type == VarType.ON_DEMAND:
            od_key = (p_name, var.name)

        elif var_type == VarType.FOREIGN:
            state_key, od_key = self._get_var_key(*self._get_foreign_ref(p_name, var))

        elif var_type == VarType.GLOBAL:
            state_key, od_key = self._get_var_key(*self._get_global_ref(var))

        elif var_type in (VarType.GROUP, VarType.GROUP_DICT):
            var_group = var.metadata["group"]
            state_key, od_key = self._get_group_var_keys(var_group)

        return state_key, od_key

    def _get_group_var_keys(self, group):
        """Get from cache or find model-wise state and on-demand keys
        for all variables related to a group (except group variables).

        """
        if group in self._group_keys:
            return self._group_keys[group]

        state_keys = []
        od_keys = []

        for p_name, p_obj in self._processes_obj.items():
            for var in filter_variables(p_obj, group=group).values():
                state_key, od_key = self._get_var_key(p_name, var)

                if state_key is not None:
                    state_keys.append(state_key)
                if od_key is not None:
                    od_keys.append(od_key)

        self._group_keys[group] = state_keys, od_keys

        return state_keys, od_keys

    def set_process_keys(self):
        """Find state and/or on-demand keys for all variables in a model and
        store them in their respective process, i.e., the following
        attributes:

        __xsimlab_state_keys__  (state keys)
        __xsimlab_od_keys__     (on-demand keys)

        """
        for p_name, p_obj in self._processes_obj.items():
            for var in filter_variables(p_obj).values():
                state_key, od_key = self._get_var_key(p_name, var)

                if state_key is not None:
                    p_obj.__xsimlab_state_keys__[var.name] = state_key
                if od_key is not None:
                    p_obj.__xsimlab_od_keys__[var.name] = od_key

    def ensure_no_intent_conflict(self):
        """Raise an error if more than one variable with
        intent='out' targets the same variable.

        """

        def filter_out(var):
            return (
                var.metadata["intent"] == VarIntent.OUT
                and var.metadata["var_type"] != VarType.ON_DEMAND
            )

        targets = defaultdict(list)

        for p_name, p_obj in self._processes_obj.items():
            for var in filter_variables(p_obj, func=filter_out).values():
                target_key = p_obj.__xsimlab_state_keys__.get(var.name)
                targets[target_key].append((p_name, var.name))

        conflicts = {k: v for k, v in targets.items() if len(v) > 1}

        if conflicts:
            conflicts_str = {
                k: " and ".join(["'{}.{}'".format(*i) for i in v])
                for k, v in conflicts.items()
            }
            msg = "\n".join(
                [f"'{'.'.join(k)}' set by: {v}" for k, v in conflicts_str.items()]
            )

            raise ValueError(f"Conflict(s) found in given variable intents:\n{msg}")

    def get_variables(self, **kwargs):
        if not len(kwargs):
            return self._all_vars
        else:
            return get_model_variables(self._processes_cls, **kwargs)

    def get_input_variables(self):
        """Get all input variables in the model as a list of
        ``(process_name, var_name)`` tuples.

        Model input variables meet the following conditions:

        - model-wise (i.e., in all processes), there is no variable with
          intent='out' targeting those variables (in state keys).
        - although group variables always have intent='in', they are not
          model inputs.

        """

        def filter_in(var):
            return (
                var.metadata["var_type"] != VarType.GROUP
                and var.metadata["var_type"] != VarType.GROUP_DICT
                and var.metadata["intent"] != VarIntent.OUT
            )

        def filter_out(var):
            return var.metadata["intent"] == VarIntent.OUT

        in_keys = []
        out_keys = []

        for p_obj in self._processes_obj.values():
            in_keys += [
                p_obj.__xsimlab_state_keys__.get(var.name)
                for var in filter_variables(p_obj, func=filter_in).values()
            ]
            out_keys += [
                p_obj.__xsimlab_state_keys__.get(var.name)
                for var in filter_variables(p_obj, func=filter_out).values()
            ]

        input_vars = [k for k in set(in_keys) - set(out_keys) if k is not None]

        # order consistent with variable and process declaration
        self._input_vars = [k for k in self.get_variables() if k in input_vars]

        return self._input_vars

    def get_processes_to_validate(self):
        """Return a dictionary where keys are each process of the model and
        values are lists of the names of other processes for which to trigger
        validators right after its execution.

        Useful for triggering validators of variables defined in other
        processes when new values are set through foreign variables.

        """
        processes_to_validate = {k: set() for k in self._processes_obj}

        for p_name, p_obj in self._processes_obj.items():
            out_foreign_vars = filter_variables(
                p_obj, var_type=VarType.FOREIGN, intent=VarIntent.OUT
            )

            for var in out_foreign_vars.values():
                pn, _ = p_obj.__xsimlab_state_keys__[var.name]
                processes_to_validate[p_name].add(pn)

        return {k: list(v) for k, v in processes_to_validate.items()}

    def get_process_dependencies(self, custom_dependencies={}):
        """Return a dictionary where keys are each process of the model and
        values are lists of the names of dependent processes (or empty
        lists for processes that have no dependencies).

        Process 1 depends on process 2 if the latter declares a
        variable (resp. a foreign variable) with intent='out' that
        itself (resp. its target variable) is needed in process 1.

        """
        self._dep_processes = {k: set() for k in self._processes_obj}

        d_keys = {}  # all state/on-demand keys for each process

        for p_name, p_obj in self._processes_obj.items():
            d_keys[p_name] = _flatten_keys(
                [
                    p_obj.__xsimlab_state_keys__.values(),
                    p_obj.__xsimlab_od_keys__.values(),
                ]
            )

        # actually add custom dependencies
        for p_name, deps in custom_dependencies.items():
            self._dep_processes[p_name].update(deps)

        for p_name, p_obj in self._processes_obj.items():
            for var in filter_variables(p_obj, intent=VarIntent.OUT).values():
                if var.metadata["var_type"] == VarType.ON_DEMAND:
                    key = p_obj.__xsimlab_od_keys__[var.name]
                else:
                    key = p_obj.__xsimlab_state_keys__[var.name]

                for pn in self._processes_obj:
                    if pn != p_name and key in d_keys[pn]:
                        self._dep_processes[pn].add(p_name)

        self._dep_processes = {k: list(v) for k, v in self._dep_processes.items()}

        return self._dep_processes

    def _sort_processes(self):
        """Sort processes based on their dependencies (return a list of sorted
        process names).

        Stack-based depth-first search traversal.

        This is based on Tarjan's method for topological sorting.

        Part of the code below is copied and modified from:

        - dask 0.14.3 (Copyright (c) 2014-2015, Continuum Analytics, Inc.
          and contributors)
          Licensed under the BSD 3 License
          http://dask.pydata.org

        """
        ordered = []

        # Nodes whose descendants have been completely explored.
        # These nodes are guaranteed to not be part of a cycle.
        completed = set()

        # All nodes that have been visited in the current traversal. Because
        # we are doing depth-first search, going "deeper" should never result
        # in visiting a node that has already been seen. The `seen` and
        # `completed` sets are mutually exclusive; it is okay to visit a node
        # that has already been added to `completed`.
        seen = set()

        for key in self._dep_processes:
            if key in completed:
                continue
            nodes = [key]
            while nodes:
                # Keep current node on the stack until all descendants are
                # visited
                cur = nodes[-1]
                if cur in completed:
                    # Already fully traversed descendants of cur
                    nodes.pop()
                    continue
                seen.add(cur)

                # Add direct descendants of cur to nodes stack
                next_nodes = []
                for nxt in self._dep_processes[cur]:
                    if nxt not in completed:
                        if nxt in seen:
                            # Cycle detected!
                            cycle = [nxt]
                            while nodes[-1] != nxt:
                                cycle.append(nodes.pop())
                            cycle.append(nodes.pop())
                            cycle.reverse()
                            cycle = "->".join(cycle)
                            raise RuntimeError(
                                f"Cycle detected in process graph: {cycle}"
                            )
                        next_nodes.append(nxt)

                if next_nodes:
                    nodes.extend(next_nodes)
                else:
                    # cur has no more descendants to explore,
                    # so we're done with it
                    ordered.append(cur)
                    completed.add(cur)
                    seen.remove(cur)
                    nodes.pop()

        return ordered

    def get_sorted_processes(self):
        self._sorted_processes = OrderedDict(
            [(p_name, self._processes_obj[p_name]) for p_name in self._sort_processes()]
        )
        return self._sorted_processes
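

# Illustrative sketch (not part of xarray-simlab's API): the ordering computed
# by `_ModelBuilder._sort_processes` can be cross-checked with the standard
# library's `graphlib.TopologicalSorter` (Python >= 3.9). Its input has the
# same shape as `Model.dependent_processes`, i.e. ``{process: [processes it
# depends on]}``, which is the "node -> predecessors" mapping graphlib expects.
# The helper name is hypothetical, and the relative order of independent
# processes may differ from the one produced by `_sort_processes`.
def _sketch_sort_processes(dep_processes):
    """Return process names sorted so that each comes after its dependencies."""
    import graphlib  # imported lazily: only available in Python >= 3.9

    try:
        return list(graphlib.TopologicalSorter(dep_processes).static_order())
    except graphlib.CycleError as err:
        # err.args[1] holds the list of nodes forming the detected cycle
        cycle = "->".join(err.args[1])
        raise RuntimeError(f"Cycle detected in process graph: {cycle}") from err


# Usage (hypothetical process names):
#
#     _sketch_sort_processes({"init": ["grid"], "profile": ["init"], "grid": []})
#     # ['grid', 'init', 'profile']
#     _sketch_sort_processes({"a": ["b"], "b": ["a"]})  # raises RuntimeError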


class Model(AttrMapping):
    """An immutable collection of process units that together form a
    computational model.

    This collection is ordered such that the computational flow is
    consistent with process inter-dependencies.

    Ordering doesn't need to be explicitly provided; it is dynamically
    computed using the process interfaces.

    Process interfaces are also used for automatically retrieving
    the model inputs, i.e., all the variables that require setting a
    value before running the model.

    """

    active = []

    def __init__(self, processes, custom_dependencies={}):
        """
        Parameters
        ----------
        processes : dict
            Dictionary with process names as keys and classes (decorated with
            :func:`process`) as values.
        custom_dependencies : dict
            Dictionary of custom dependencies: keys are process names and
            values are iterables of the names of the processes they depend on
            (a single process name given as a str is also accepted).

        Raises
        ------
        :exc:`NotAProcessClassError`
            If values in ``processes`` are not classes decorated with
            :func:`process`.

        """
        builder = _ModelBuilder({k: get_process_cls(v) for k, v in processes.items()})

        builder.bind_processes(self)
        builder.set_process_keys()

        self._state = builder.set_state()
        self._var_cache = builder.create_variable_cache()

        self._all_vars = builder.get_variables()
        self._all_vars_dict = None

        self._index_vars = builder.get_variables(var_type=VarType.INDEX)
        self._index_vars_dict = None

        self._od_vars = builder.get_variables(var_type=VarType.ON_DEMAND)

        builder.ensure_no_intent_conflict()

        self._input_vars = builder.get_input_variables()
        self._input_vars_dict = None

        self._processes_to_validate = builder.get_processes_to_validate()

        # clean custom dependencies
        self._custom_dependencies = {}
        for p_name, c_deps in custom_dependencies.items():
            c_deps = {c_deps} if isinstance(c_deps, str) else set(c_deps)
            self._custom_dependencies[p_name] = c_deps

        self._dep_processes = builder.get_process_dependencies(
            self._custom_dependencies
        )
        self._processes = builder.get_sorted_processes()

        super(Model, self).__init__(self._processes)
        self._initialized = True

    def _get_vars_dict_from_cache(self, attr_name):
        dict_attr_name = attr_name + "_dict"

        if getattr(self, dict_attr_name) is None:
            vars_d = defaultdict(list)

            for p_name, var_name in getattr(self, attr_name):
                vars_d[p_name].append(var_name)

            setattr(self, dict_attr_name, dict(vars_d))

        return getattr(self, dict_attr_name)

    @property
    def all_vars(self):
        """Returns all variables in the model as a list of
        ``(process_name, var_name)`` tuples (or an empty list).

        """
        return self._all_vars

    @property
    def all_vars_dict(self):
        """Returns all variables in the model as a dictionary of lists of
        variable names grouped by process.

        """
        return self._get_vars_dict_from_cache("_all_vars")

    @property
    def index_vars(self):
        """Returns all index variables in the model as a list of
        ``(process_name, var_name)`` tuples (or an empty list).

        """
        return self._index_vars

    @property
    def index_vars_dict(self):
        """Returns all index variables in the model as a dictionary of lists of
        variable names grouped by process.

        """
        return self._get_vars_dict_from_cache("_index_vars")

    @property
    def input_vars(self):
        """Returns all variables that require setting a value before running
        the model.

        A list of ``(process_name, var_name)`` tuples (or an empty list)
        is returned.

        """
        return self._input_vars

    @property
    def input_vars_dict(self):
        """Returns all variables that require setting a value before running
        the model.

        Unlike :attr:`Model.input_vars`, a dictionary of lists of
        variable names grouped by process is returned.

        """
        return self._get_vars_dict_from_cache("_input_vars")

    @property
    def dependent_processes(self):
        """Returns a dictionary where keys are process names and values are
        lists of the names of dependent processes.

        """
        return self._dep_processes

    def visualize(
        self, show_only_variable=None, show_inputs=False, show_variables=False
    ):
        """Render the model as a graph using dot (requires graphviz).

        Parameters
        ----------
        show_only_variable : tuple, optional
            Show only a variable (and all other variables sharing the
            same value) given as a tuple ``(process_name, variable_name)``.
            Deactivated by default.
        show_inputs : bool, optional
            If True, show all input variables in the graph (default: False).
            Ignored if ``show_only_variable`` is not None.
        show_variables : bool, optional
            If True, also show the other variables (default: False).
            Ignored if ``show_only_variable`` is not None.

        See Also
        --------
        :func:`dot.dot_graph`

        """
        from .dot import dot_graph

        return dot_graph(
            self,
            show_only_variable=show_only_variable,
            show_inputs=show_inputs,
            show_variables=show_variables,
        )

    @property
    def state(self):
        """Returns a mapping of model variables and their current value.

        Mapping keys are in the form of ``('process_name', 'var_name')`` tuples.

        This mapping does not include "on demand" variables.
        """
        return self._state

    def update_state(
        self, input_vars, validate=True, ignore_static=False, ignore_invalid_keys=True
    ):
        """Update the model's state (only input variables) with new values.

        Prior to updating the model's state, first convert the values for model
        variables that have a converter, otherwise copy the values.

        Parameters
        ----------
        input_vars : dict_like
            A mapping where keys are in the form of
            ``('process_name', 'var_name')`` tuples and values are
            the input values to set in the model state.
        validate : bool, optional
            If True (default), run the variable validators after setting the
            new values.
        ignore_static : bool, optional
            If True, sets the values even for static variables. Otherwise
            (default), raises a ``ValueError`` in order to prevent updating
            values of static variables.
        ignore_invalid_keys : bool, optional
            If True (default), ignores keys in ``input_vars`` that do not
            correspond to input variables in the model. Otherwise, raises
            a ``KeyError``.

        """
        for key, value in input_vars.items():

            if key not in self.input_vars:
                if ignore_invalid_keys:
                    continue
                else:
                    raise KeyError(f"{key} is not a valid input variable in model")

            var = self._var_cache[key]["attrib"]

            if not ignore_static and var.metadata.get("static", False):
                raise ValueError(f"Cannot set value for static variable {key}")

            if var.converter is not None:
                self._state[key] = var.converter(value)
            else:
                self._state[key] = copy.copy(value)

        if validate:
            p_names = set([pn for pn, _ in input_vars if pn in self._processes])
            self.validate(p_names)

    @property
    def cache(self):
        """Returns a mapping of model variables and some of their (meta)data cached for
        fastpath access.

        Mapping keys are in the form of ``('process_name', 'var_name')`` tuples.

        """
        return self._var_cache

    def update_cache(self, var_key):
        """Update the model's cache for a given model variable.

        This is generally not really needed, except for on demand variables
        where this might optimize multiple accesses to the variable value between
        two simulation stages.

        No copy is performed.

        Parameters
        ----------
        var_key : tuple
            Variable key in the form of a ``('process_name', 'var_name')``
            tuple.

        """
        p_name, v_name = var_key
        self._var_cache[var_key]["value"] = getattr(self._processes[p_name], v_name)

    def validate(self, p_names=None):
        """Run the variable validators of all or some of the processes
        in the model.

        Parameters
        ----------
        p_names : list, optional
            Names of the processes to validate. If None is given (default),
            validators are run for all processes.

        """
        if p_names is None:
            processes = self._processes.values()
        else:
            processes = [self._processes[pn] for pn in p_names]

        for p_obj in processes:
            attr.validate(p_obj)

    def _call_hooks(self, hooks, runtime_context, stage, level, trigger):
        try:
            event_hooks = hooks[stage][level][trigger]
        except KeyError:
            return RuntimeSignal.NONE

        signals = []

        for h in event_hooks:
            s = h(self, Frozen(runtime_context), Frozen(self.state))

            if s is None:
                s = RuntimeSignal(0)
            else:
                s = RuntimeSignal(s)

            signals.append(s)

        # Signal with highest value has highest priority
        return RuntimeSignal(max([s.value for s in signals]))

    def _execute_process(
        self, p_obj, stage, runtime_context, hooks, validate, state=None
    ):
        """Internal process execution method, which calls the process object's
        executor.

        A state may be passed to the executor instead of using the executor's
        state (this is to avoid stateful objects when calling the executor
        during execution of a Dask graph).

        The process executor returns a partial state (only the variables that
        have been updated by the executor, which will be needed for executing
        further tasks in the Dask graph).

        This method returns this updated state as well as any runtime signal returned
        by the hook functions and/or the executor (the one with highest priority).

        """
        executor = p_obj.__xsimlab_executor__
        p_name = p_obj.__xsimlab_name__

        signal_pre = self._call_hooks(hooks, runtime_context, stage, "process", "pre")

        if signal_pre.value > 0:
            return p_name, ({}, signal_pre)

        state_out, signal_out = executor.execute(
            p_obj, stage, runtime_context, state=state
        )

        signal_post = self._call_hooks(hooks, runtime_context, stage, "process", "post")
        if signal_post.value > signal_out.value:
            signal_out = signal_post

        if validate:
            self.validate(self._processes_to_validate[p_name])

        return p_name, (state_out, signal_out)

    def _build_dask_graph(self, execute_args):
        """Build a custom, 'stateless' graph of tasks (process execution) that will
        be passed to a Dask scheduler.

        """

        def exec_process(p_obj, model_state, exec_outputs):
            # update model state with output states from all dependent processes,
            # gather signals returned by all dependent processes and keep the one
            # with the highest priority
            state = {}
            signal = RuntimeSignal.NONE

            state.update(model_state)

            for _, (state_out, signal_out) in exec_outputs:
                state.update(state_out)

                if signal_out.value > signal.value:
                    signal = signal_out

            if signal == RuntimeSignal.BREAK:
                # received a BREAK signal from the execution of a dependent process
                # -> skip execution of current process as well as all downstream processes
                # in the graph (by forwarding the signal).
                return p_obj.__xsimlab_name__, ({}, signal)
            else:
                return self._execute_process(p_obj, *execute_args, state=state)

        dsk = {}
        for p_name, p_deps in self._dep_processes.items():
            dsk[p_name] = (exec_process, self._processes[p_name], self._state, p_deps)

        # add a node to gather output signals and state from all executed processes
        dsk["_gather"] = (
            lambda exec_outputs: dict(exec_outputs),
            list(self._processes),
        )

        return dsk

    def _merge_exec_outputs(self, exec_outputs) -> RuntimeSignal:
        """Collect and merge process execution outputs (from dask graph).

        - combine all output states and update model's state.
        - sort all output runtime signals and return the signal with highest priority.

        """
        new_state = {}
        signal = RuntimeSignal.NONE

        # process order matters for properly updating state!
        for p_name in self._processes:
            state_out, signal_out = exec_outputs[p_name]
            new_state.update(state_out)

            if signal_out.value > signal.value:
                signal = signal_out

        self._state.update(new_state)

        # need to re-assign the updated state to all processes
        # for access between simulation stages (e.g., save snapshots)
        for p_obj in self._processes.values():
            p_obj.__xsimlab_state__ = self._state

        return signal

    def _clear_od_cache(self):
        """Clear cached values of on-demand variables."""

        for key in self._od_vars:
            self._state.pop(key, None)

    def execute(
        self,
        stage,
        runtime_context,
        hooks=None,
        validate=False,
        parallel=False,
        scheduler=None,
    ):
        """Run one stage of a simulation.

        Parameters
        ----------
        stage : {'initialize', 'run_step', 'finalize_step', 'finalize'}
            Name of the simulation stage.
        runtime_context : dict
            Dictionary containing runtime variables (e.g., time step
            duration, current step).
        hooks : dict, optional
            Runtime hook callables, grouped by simulation stage, level and
            trigger pre/post.
        validate : bool, optional
            If True, run the variable validators in the corresponding
            processes after a process (maybe) sets values through its foreign
            variables (default: False). This is useful for debugging but
            it may significantly impact performance.
        parallel : bool, optional
            If True, run the simulation stage in parallel using Dask
            (default: False).
        scheduler : str, optional
            Dask's scheduler used to run the stage in parallel
            (Dask's threads scheduler is used as fallback).

        Returns
        -------
        signal : :class:`RuntimeSignal`
            Signal with highest priority among all signals returned by hook
            functions and/or process runtime methods, if any. Otherwise,
            returns ``RuntimeSignal.NONE``.

        Notes
        -----
        Even when run in parallel, xarray-simlab ensures that processes will
        not be executed before their dependent processes. However, race
        conditions or performance issues may still occur under certain
        circumstances that require extra care. In particular:

        - The gain in performance when running the processes in parallel
          highly depends on the graph structure. It might not be worth the
          extra complexity and overhead.
        - If a multi-threaded scheduler is used, then the code implemented
          in the process classes must be thread-safe. Also, it should release
          the Python Global Interpreter Lock (GIL) as much as possible in order
          to see a gain in performance.
        - Multi-process or distributed schedulers may have very poor performance,
          especially when a lot of data (model state) is shared between the model
          processes. The way xarray-simlab scatters/gathers this data between the
          scheduler and the workers is not optimized at all. Additionally, those
          schedulers may not work well with the given ``hooks`` and/or when the