Skip to content

Commit 7863722

Browse files
committed
some proper fix to the code
1 parent b127396 commit 7863722

5 files changed

Lines changed: 33 additions & 14 deletions

File tree

src/diffpy/srmise/applications/extract.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,21 @@
1818
import numpy as np
1919

2020

21+
def _baseline_namespace():
22+
"""Return the baseline classes supported by the CLI."""
23+
from diffpy.srmise.baselines.arbitrary import Arbitrary
24+
from diffpy.srmise.baselines.fromsequence import FromSequence
25+
from diffpy.srmise.baselines.nanospherical import NanoSpherical
26+
from diffpy.srmise.baselines.polynomial import Polynomial
27+
28+
return {
29+
"Arbitrary": Arbitrary,
30+
"FromSequence": FromSequence,
31+
"NanoSpherical": NanoSpherical,
32+
"Polynomial": Polynomial,
33+
}
34+
35+
2136
def main():
2237
"""Default SrMise entry-point."""
2338

@@ -483,12 +498,16 @@ def main():
483498

484499
bl = NanoSpherical()
485500
options.baseline = parsepars(bl, options.bspherical)
486-
try:
487-
options.baseline = eval("baselines." + options.baseline)
488-
except Exception as err:
489-
print(err)
490-
print("Could not create baseline '%s'. Exiting." % options.baseline)
491-
return
501+
try:
502+
options.baseline = eval(
503+
options.baseline,
504+
{"__builtins__": {}},
505+
_baseline_namespace(),
506+
)
507+
except Exception as err:
508+
print(err)
509+
print("Could not create baseline '%s'. Exiting." % options.baseline)
510+
return
492511

493512
filename = args[0]
494513

src/diffpy/srmise/modelcluster.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1538,8 +1538,8 @@ def prune(self):
15381538
# Create model with ith peak removed, and distant peaks effectively fixed
15391539
lo = max(i - peak_range, 0)
15401540
hi = min(i + peak_range + 1, len(best_model))
1541-
check_models[i] = type(best_model)(best_model[lo:i]).copy()
1542-
check_models[i].extend(type(best_model)(best_model[i + 1 : hi]).copy())
1541+
check_models[i] = best_model[lo:i].copy()
1542+
check_models[i].extend(best_model[i + 1: hi].copy())
15431543
prune_mc.model = check_models[i]
15441544

15451545
msg = [

src/diffpy/srmise/modelevaluators/aic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def evaluate(self, fit, count_fixed=False, kshift=0):
8585
if self.chisq is None:
8686
self.chisq = self.chi_squared(fit.value(), fit.y_cluster, fit.error_cluster)
8787

88-
self.stat = self.chisq + self.parpenalty(k, n)
88+
self.stat = self.chisq + self.parpenalty(k)
8989

9090
return self.stat
9191

@@ -106,7 +106,7 @@ def minpoints(self, npars):
106106

107107
return 1
108108

109-
def parpenalty(self, k, n=None):
109+
def parpenalty(self, k):
110110
"""Returns the cost for adding k parameters to the current model
111111
cluster.
112112
@@ -169,7 +169,7 @@ def growth_justified(self, fit, k_prime):
169169
logger.warning("AIC.growth_justified(): too few data to evaluate quality reliably.")
170170
n = self.minpoints(k_actual)
171171

172-
penalty = self.parpenalty(k_test, n) - self.parpenalty(k_actual, n)
172+
penalty = self.parpenalty(k_test) - self.parpenalty(k_actual)
173173

174174
return penalty < self.chisq
175175

src/diffpy/srmise/peaks/gaussianoverr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,7 @@ def max(self, pars):
534534

535535
guesspars = [[2.7, 0.15, 5], [3.7, 0.3, 5]]
536536
guess_peaks = Peaks([pf.actualize(p, "pwa") for p in guesspars])
537-
cluster = ModelCluster(guess_peaks, r, y, err, None, AICc, [pf])
537+
cluster = ModelCluster(guess_peaks, None, r, y, err, None, AICc, [pf])
538538

539539
qual1 = cluster.quality()
540540
print(qual1.stat)

src/diffpy/srmise/peaks/terminationripples.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -416,12 +416,12 @@ def extend_grid(self, r, dr):
416416

417417
guesspars = [[2.7, 0.15, 5], [3.7, 0.3, 5]]
418418
guess_peaks = Peaks([pf2.actualize(p, "pwa") for p in guesspars])
419-
cluster = ModelCluster(guess_peaks, r, y_ripple, err, None, AICc, [pf2])
419+
cluster = ModelCluster(guess_peaks, None, r, y_ripple, err, None, AICc, [pf2])
420420

421421
qual1 = cluster.quality()
422422
print(qual1.stat)
423423
cluster.fit()
424-
yfit = cluster.calc()
424+
yfit = cluster.valuebl()
425425
qual2 = cluster.quality()
426426
print(qual2.stat)
427427

0 commit comments

Comments
 (0)