Skip to content

Commit 9d1974b

Browse files
committed
bugfixes
1 parent 72ade0f commit 9d1974b

File tree

2 files changed

+5
-9
lines changed

2 files changed

+5
-9
lines changed

examples/friedman_demo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def f(x): # Friedman function where first variable is the functional variable
4040
mod.traceplot()
4141
mod.plot(idxMV=tt) # Evaluate training data fit
4242
mod.plot(Xtest=Xtest, Ytest=Ytest, idxMV=tt) # Evaluate test data fit
43-
mod.sobol()
43+
mod.mvSobol()
4444
mod.plotSobol(idxMV=tt)
4545

4646
# All posterior predictive samples
@@ -65,7 +65,7 @@ def f(x): # Friedman function where first variable is the functional variable
6565
mod.traceplot()
6666
mod.plot() # Evaluate training data fit
6767
mod.plot(Xtest=Xtest, Ytest=Ytest) # Evaluate test data fit
68-
mod.sobol()
68+
mod.mvSobol()
6969
mod.plotSobol()
7070

7171
# All posterior predictive samples

mvBayes/mvBayes.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1101,7 +1101,7 @@ def plot(
11011101

11021102
# Plot Residuals on top of Y
11031103
fig.add_subplot(2, 2, 1)
1104-
mseOverall = self.getMSE(R)
1104+
mseOverall = self.getMSE(R) * Ytest.shape[1]
11051105
if not self.basisInfo.center and not self.basisInfo.scale:
11061106
legendLab = "Y"
11071107
else:
@@ -1118,7 +1118,7 @@ def plot(
11181118
plt.xscale(xscale)
11191119
plt.ylabel("Residuals")
11201120
plt.xlabel(xlabel)
1121-
plt.title(f"Overall MSE = {mseOverall:.4g}")
1121+
plt.title(f"Overall MSE = {mseOverall/Ytest.shape[1]:.4g}")
11221122
plt.legend()
11231123

11241124
# Plot each basis, scaled by residuals
@@ -1148,11 +1148,6 @@ def plot(
11481148
* self.basisInfo.Yscale
11491149
)
11501150
mseBasis[k] = np.mean(RbasisCoefs[:, k] ** 2)
1151-
1152-
basisScaled = (
1153-
np.outer(coefs[:, k], self.basisInfo.basis[k, :])
1154-
* self.basisInfo.Yscale
1155-
)
11561151
varBasis[k] = self.basisInfo.varExplained[k]*(Ytest.shape[0]-1)/(Ytest.shape[0])
11571152
mseOrder = np.argsort(mseBasis)[::-1]
11581153

@@ -1236,6 +1231,7 @@ def plot(
12361231
plt.close(fig)
12371232

12381233
return
1234+
12391235

12401236
def mvSobol(
12411237
self, totalSobol=True, idxSamples="final", nMC=None, showPlot=False, **kwargs

0 commit comments

Comments
 (0)