Create paper emulation figure

[1]:
import os
## Work around a broken local HDF5 install by disabling its version check.
## This must be set before importing any library that links against HDF5.
os.environ["HDF5_DISABLE_VERSION_CHECK"] = "1"
[2]:
import iris

from utils import get_bc_ppe_data

from GCEm import cnn_model, gp_model
from GCEm.utils import get_random_params

import iris.quickplot as qplt
import iris.analysis.maths as imath
import matplotlib.pyplot as plt
%matplotlib inline

Read in the parameters and data

[3]:
ppe_params, ppe_aaod = get_bc_ppe_data()
/Users/watson-parris/miniconda3/envs/gcem/lib/python3.8/site-packages/iris/__init__.py:249: IrisDeprecation: setting the 'Future' property 'netcdf_promote' is deprecated and will be removed in a future release. Please remove code that sets this property.
  warn_deprecated(msg.format(name))
[4]:
## Ensure the time dimension is last - this is treated as the color 'channel'
## ppe_aaod.transpose((0,2,3,1))
## Here we instead collapse over time, taking the annual-mean AAOD
ppe_aaod = ppe_aaod.collapsed('time', iris.analysis.MEAN)
WARNING:root:Creating guessed bounds as none exist in file
WARNING:root:Creating guessed bounds as none exist in file
WARNING:root:Creating guessed bounds as none exist in file
[5]:
n_test = 5

X_test, X_train = ppe_params[:n_test], ppe_params[n_test:]
Y_test, Y_train = ppe_aaod[:n_test], ppe_aaod[n_test:]
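As a quick, purely illustrative sanity check (not part of the original notebook, and assuming ppe_params is a pandas DataFrame, as its use of .values below suggests), the split sizes can be inspected directly:

## Illustrative only: confirm the train/test split sizes
print(X_train.shape, X_test.shape)   ## expect (34, 3) and (5, 3) parameter sets
print(Y_train.shape, Y_test.shape)   ## expect (34, 96, 192) and (5, 96, 192) AAOD fields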
[6]:
Y_train
[6]:
Absorption Optical Thickness - Total 550Nm (1)    (job: 34; latitude: 96; longitude: 192)
    Dimension coordinates:
        job                   x   -   -
        latitude              -   x   -
        longitude             -   -   x
    Scalar coordinates:
        time                  2017-07-02 10:30:00, bound=(2017-01-01 00:00:00, 2017-12-31 21:00:00)
    Attributes:
        CDI                   Climate Data Interface version 1.9.8 (https://mpimet.mpg.de/cdi)
        CDO                   Climate Data Operators version 1.9.8 (https://mpimet.mpg.de/cdo)
        Conventions           CF-1.5
        NCO                   netCDF Operators version 4.8.1 (Homepage = http://nco.sf.net, Code = h...
        advection             Lin & Rood
        echam_version         6.3.02
        frequency             mon
        grid_type             gaussian
        history               Sat Nov 23 17:22:40 2019: cdo monavg BC_PPE_PD_AAOD_t.nc BC_PPE_PD_AAOD_monthly.nc
                              Sat...
        institution           MPIMET
        jsbach_version        3.10
        operating_system      Linux 3.0.101-0.46.1_1.0502.8871-cray_ari_c x86_64
        physics               Modified ECMWF physics
        radiation             Using PSrad/RRTMG radiation
        source                ECHAM6.3
        truncation            63
        user_name             user name not available
    Cell methods:
        mean                  time
        mean                  time

Set up and run the models
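Both emulators follow the same workflow used in the cells below: build the model from the training ensemble, train it, then predict the held-out members. A minimal sketch of those calls (the second value returned by predict is ignored here, just as in the cells below):

## Generic emulator workflow (illustrative sketch of the calls used below)
model = gp_model(X_train, Y_train)             ## or cnn_model(X_train, Y_train)
model.train()                                  ## fit the emulator to the 34 training members
prediction, _ = model.predict(X_test.values)   ## emulate the 5 held-out members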

[7]:
nn_model = cnn_model(X_train, Y_train)
[8]:
nn_model.model.model.summary()
Model: "decoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 3)]               0
_________________________________________________________________
dense (Dense)                (None, 221184)            884736
_________________________________________________________________
reshape (Reshape)            (None, 96, 192, 12)       0
_________________________________________________________________
conv2d_transpose (Conv2DTran (None, 96, 192, 12)       2172
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 96, 192, 1)        181
=================================================================
Total params: 887,089
Trainable params: 887,089
Non-trainable params: 0
_________________________________________________________________
[9]:
nn_model.train()
Epoch 1/100
4/4 [==============================] - 1s 284ms/step - loss: 0.8766 - val_loss: 0.4800
Epoch 2/100
4/4 [==============================] - 0s 115ms/step - loss: 1.0869 - val_loss: 0.4778
Epoch 3/100
4/4 [==============================] - 0s 100ms/step - loss: 1.1900 - val_loss: 0.4759
Epoch 4/100
4/4 [==============================] - 0s 106ms/step - loss: 1.0656 - val_loss: 0.4741
Epoch 5/100
4/4 [==============================] - 0s 99ms/step - loss: 1.2058 - val_loss: 0.4699
Epoch 6/100
4/4 [==============================] - 0s 112ms/step - loss: 1.2211 - val_loss: 0.4672
Epoch 7/100
4/4 [==============================] - 0s 105ms/step - loss: 1.0138 - val_loss: 0.4624
Epoch 8/100
4/4 [==============================] - 0s 101ms/step - loss: 0.9261 - val_loss: 0.4645
Epoch 9/100
4/4 [==============================] - 0s 112ms/step - loss: 0.9354 - val_loss: 0.4569
Epoch 10/100
4/4 [==============================] - 0s 100ms/step - loss: 1.1767 - val_loss: 0.4494
Epoch 11/100
4/4 [==============================] - 0s 101ms/step - loss: 1.1077 - val_loss: 0.4310
Epoch 12/100
4/4 [==============================] - 0s 101ms/step - loss: 0.8575 - val_loss: 0.4343
Epoch 13/100
4/4 [==============================] - 0s 101ms/step - loss: 0.8162 - val_loss: 0.4476
Epoch 14/100
4/4 [==============================] - 0s 102ms/step - loss: 0.8983 - val_loss: 0.4181
Epoch 15/100
4/4 [==============================] - 1s 104ms/step - loss: 0.8968 - val_loss: 0.4003
Epoch 16/100
4/4 [==============================] - 0s 100ms/step - loss: 0.8450 - val_loss: 0.3989
Epoch 17/100
4/4 [==============================] - 0s 103ms/step - loss: 0.7290 - val_loss: 0.4107
Epoch 18/100
4/4 [==============================] - 0s 100ms/step - loss: 0.8785 - val_loss: 0.3785
Epoch 19/100
4/4 [==============================] - 0s 106ms/step - loss: 0.7236 - val_loss: 0.3918
Epoch 20/100
4/4 [==============================] - 0s 103ms/step - loss: 0.8572 - val_loss: 0.3857
Epoch 21/100
4/4 [==============================] - 0s 102ms/step - loss: 0.8973 - val_loss: 0.3543
Epoch 22/100
4/4 [==============================] - 0s 101ms/step - loss: 0.9261 - val_loss: 0.3282
Epoch 23/100
4/4 [==============================] - 0s 102ms/step - loss: 0.5686 - val_loss: 0.3123
Epoch 24/100
4/4 [==============================] - 0s 100ms/step - loss: 0.9107 - val_loss: 0.3041
Epoch 25/100
4/4 [==============================] - 0s 101ms/step - loss: 0.7340 - val_loss: 0.2931
Epoch 26/100
4/4 [==============================] - 0s 100ms/step - loss: 0.7675 - val_loss: 0.2825
Epoch 27/100
4/4 [==============================] - 0s 105ms/step - loss: 0.5598 - val_loss: 0.2795
Epoch 28/100
4/4 [==============================] - 0s 101ms/step - loss: 0.6441 - val_loss: 0.2640
Epoch 29/100
4/4 [==============================] - 0s 100ms/step - loss: 0.8158 - val_loss: 0.2579
Epoch 30/100
4/4 [==============================] - 0s 102ms/step - loss: 0.7065 - val_loss: 0.2447
Epoch 31/100
4/4 [==============================] - 0s 105ms/step - loss: 0.6948 - val_loss: 0.2351
Epoch 32/100
4/4 [==============================] - 0s 125ms/step - loss: 0.7187 - val_loss: 0.2226
Epoch 33/100
4/4 [==============================] - 0s 122ms/step - loss: 0.5605 - val_loss: 0.2154
Epoch 34/100
4/4 [==============================] - 0s 104ms/step - loss: 0.5716 - val_loss: 0.2117
Epoch 35/100
4/4 [==============================] - 0s 107ms/step - loss: 0.6035 - val_loss: 0.1991
Epoch 36/100
4/4 [==============================] - 0s 107ms/step - loss: 0.6084 - val_loss: 0.1892
Epoch 37/100
4/4 [==============================] - 0s 103ms/step - loss: 0.6416 - val_loss: 0.1786
Epoch 38/100
4/4 [==============================] - 0s 100ms/step - loss: 0.4608 - val_loss: 0.1749
Epoch 39/100
4/4 [==============================] - 0s 100ms/step - loss: 0.3582 - val_loss: 0.1674
Epoch 40/100
4/4 [==============================] - 0s 113ms/step - loss: 0.4616 - val_loss: 0.1607
Epoch 41/100
4/4 [==============================] - 0s 117ms/step - loss: 0.5972 - val_loss: 0.1558
Epoch 42/100
4/4 [==============================] - 0s 121ms/step - loss: 0.3961 - val_loss: 0.1471
Epoch 43/100
4/4 [==============================] - 0s 111ms/step - loss: 0.5399 - val_loss: 0.1404
Epoch 44/100
4/4 [==============================] - 0s 100ms/step - loss: 0.4118 - val_loss: 0.1385
Epoch 45/100
4/4 [==============================] - 0s 126ms/step - loss: 0.3920 - val_loss: 0.1316
Epoch 46/100
4/4 [==============================] - 0s 121ms/step - loss: 0.4126 - val_loss: 0.1237
Epoch 47/100
4/4 [==============================] - 1s 247ms/step - loss: 0.3554 - val_loss: 0.1202
Epoch 48/100
4/4 [==============================] - 0s 129ms/step - loss: 0.3429 - val_loss: 0.1145
Epoch 49/100
4/4 [==============================] - 1s 128ms/step - loss: 0.3157 - val_loss: 0.1075
Epoch 50/100
4/4 [==============================] - 1s 117ms/step - loss: 0.4217 - val_loss: 0.1051
Epoch 51/100
4/4 [==============================] - 0s 112ms/step - loss: 0.4561 - val_loss: 0.1078
Epoch 52/100
4/4 [==============================] - 0s 121ms/step - loss: 0.4475 - val_loss: 0.0932
Epoch 53/100
4/4 [==============================] - 0s 123ms/step - loss: 0.2861 - val_loss: 0.0896
Epoch 54/100
4/4 [==============================] - 0s 115ms/step - loss: 0.3469 - val_loss: 0.0886
Epoch 55/100
4/4 [==============================] - 0s 126ms/step - loss: 0.3332 - val_loss: 0.0920
Epoch 56/100
4/4 [==============================] - 0s 117ms/step - loss: 0.3290 - val_loss: 0.0793
Epoch 57/100
4/4 [==============================] - 1s 136ms/step - loss: 0.3292 - val_loss: 0.0792
Epoch 58/100
4/4 [==============================] - 0s 104ms/step - loss: 0.2757 - val_loss: 0.0760
Epoch 59/100
4/4 [==============================] - 0s 110ms/step - loss: 0.2637 - val_loss: 0.0720
Epoch 60/100
4/4 [==============================] - 0s 123ms/step - loss: 0.4001 - val_loss: 0.0729
Epoch 61/100
4/4 [==============================] - 0s 110ms/step - loss: 0.4507 - val_loss: 0.0666
Epoch 62/100
4/4 [==============================] - 1s 131ms/step - loss: 0.3879 - val_loss: 0.0695
Epoch 63/100
4/4 [==============================] - 0s 106ms/step - loss: 0.2597 - val_loss: 0.0610
Epoch 64/100
4/4 [==============================] - 0s 111ms/step - loss: 0.3288 - val_loss: 0.0725
Epoch 65/100
4/4 [==============================] - 0s 102ms/step - loss: 0.2841 - val_loss: 0.0575
Epoch 66/100
4/4 [==============================] - 0s 101ms/step - loss: 0.2730 - val_loss: 0.0544
Epoch 67/100
4/4 [==============================] - 0s 101ms/step - loss: 0.2694 - val_loss: 0.0517
Epoch 68/100
4/4 [==============================] - 0s 110ms/step - loss: 0.3846 - val_loss: 0.0573
Epoch 69/100
4/4 [==============================] - 0s 100ms/step - loss: 0.3429 - val_loss: 0.0528
Epoch 70/100
4/4 [==============================] - 0s 106ms/step - loss: 0.2360 - val_loss: 0.0478
Epoch 71/100
4/4 [==============================] - 0s 100ms/step - loss: 0.3651 - val_loss: 0.0486
Epoch 72/100
4/4 [==============================] - 0s 107ms/step - loss: 0.2351 - val_loss: 0.0469
Epoch 73/100
4/4 [==============================] - 0s 101ms/step - loss: 0.2993 - val_loss: 0.0476
Epoch 74/100
4/4 [==============================] - 0s 105ms/step - loss: 0.2739 - val_loss: 0.0523
Epoch 75/100
4/4 [==============================] - 0s 103ms/step - loss: 0.3909 - val_loss: 0.0438
Epoch 76/100
4/4 [==============================] - 0s 109ms/step - loss: 0.2526 - val_loss: 0.0423
Epoch 77/100
4/4 [==============================] - 0s 103ms/step - loss: 0.2566 - val_loss: 0.0485
Epoch 78/100
4/4 [==============================] - 0s 101ms/step - loss: 0.2969 - val_loss: 0.0411
Epoch 79/100
4/4 [==============================] - 0s 99ms/step - loss: 0.2602 - val_loss: 0.0413
Epoch 80/100
4/4 [==============================] - 0s 126ms/step - loss: 0.1863 - val_loss: 0.0382
Epoch 81/100
4/4 [==============================] - 0s 101ms/step - loss: 0.3251 - val_loss: 0.0394
Epoch 82/100
4/4 [==============================] - 0s 109ms/step - loss: 0.3038 - val_loss: 0.0376
Epoch 83/100
4/4 [==============================] - 0s 105ms/step - loss: 0.3729 - val_loss: 0.0405
Epoch 84/100
4/4 [==============================] - 0s 107ms/step - loss: 0.2542 - val_loss: 0.0389
Epoch 85/100
4/4 [==============================] - 0s 109ms/step - loss: 0.2379 - val_loss: 0.0403
Epoch 86/100
4/4 [==============================] - 0s 117ms/step - loss: 0.2620 - val_loss: 0.0350
Epoch 87/100
4/4 [==============================] - 0s 108ms/step - loss: 0.2459 - val_loss: 0.0356
Epoch 88/100
4/4 [==============================] - 0s 107ms/step - loss: 0.1902 - val_loss: 0.0368
Epoch 89/100
4/4 [==============================] - 1s 134ms/step - loss: 0.3153 - val_loss: 0.0362
Epoch 90/100
4/4 [==============================] - 0s 116ms/step - loss: 0.3270 - val_loss: 0.0371
Epoch 91/100
4/4 [==============================] - 1s 123ms/step - loss: 0.3058 - val_loss: 0.0393
Epoch 92/100
4/4 [==============================] - 0s 115ms/step - loss: 0.3535 - val_loss: 0.0359
Epoch 93/100
4/4 [==============================] - 0s 102ms/step - loss: 0.3056 - val_loss: 0.0379
Epoch 94/100
4/4 [==============================] - 0s 106ms/step - loss: 0.3331 - val_loss: 0.0354
Epoch 95/100
4/4 [==============================] - 0s 99ms/step - loss: 0.2074 - val_loss: 0.0364
Epoch 96/100
4/4 [==============================] - 0s 99ms/step - loss: 0.1920 - val_loss: 0.0364
Epoch 97/100
4/4 [==============================] - 0s 101ms/step - loss: 0.3333 - val_loss: 0.0329
Epoch 98/100
4/4 [==============================] - 0s 99ms/step - loss: 0.2105 - val_loss: 0.0340
Epoch 99/100
4/4 [==============================] - 0s 105ms/step - loss: 0.3683 - val_loss: 0.0355
Epoch 100/100
4/4 [==============================] - 0s 100ms/step - loss: 0.3073 - val_loss: 0.0361
[10]:
## For comparison, a linear model gave - loss: 0.3566 - val_loss: 0.0867
[11]:
nn_prediction, _ = nn_model.predict(X_test.values)
[12]:
gp_model_ = gp_model(X_train, Y_train, kernel=['Bias', 'Linear'])
gp_model_.train()
[13]:
gp_prediction, _ = gp_model_.predict(X_test.values)
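Before plotting, it can be helpful to attach a single number to each emulator's error. The short sketch below is not part of the original notebook; it reuses the same mean-difference construction as the figure code that follows and reduces it to a global mean absolute difference, assuming both predictions are iris cubes on the same grid as Y_test:

## Illustrative only: global mean absolute difference between each emulator and the truth
def global_mean_abs_diff(prediction, truth):
    diff = prediction.collapsed(['sample'], iris.analysis.MEAN) - truth.collapsed(['job'], iris.analysis.MEAN)
    return imath.abs(diff).collapsed(['latitude', 'longitude'], iris.analysis.MEAN).data

print("GP: ", global_mean_abs_diff(gp_prediction, Y_test))
print("CNN:", global_mean_abs_diff(nn_prediction, Y_test))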
[14]:
import matplotlib
import cartopy.crs as ccrs
import iris.plot as iplt


plt.figure(figsize=(30, 10))
matplotlib.rcParams['font.size'] = 24

plt.subplot(2,3,1, projection=ccrs.Mollweide())
plt.annotate("(a)", (0.,1.), xycoords='axes fraction')
iplt.pcolormesh(imath.log10(Y_test[0]), vmin=-4, vmax=-1)
plt.gca().set_title('Truth')
plt.gca().coastlines()

plt.subplot(2,3,2, projection=ccrs.Mollweide())
plt.annotate("(b)", (0.,1.), xycoords='axes fraction')
iplt.pcolormesh(imath.log10(gp_prediction[0]), vmin=-4, vmax=-1)
plt.gca().set_title('GP')
plt.gca().coastlines()

plt.subplot(2,3,3, projection=ccrs.Mollweide())
plt.annotate("(c)", (0.,1.), xycoords='axes fraction')
im=iplt.pcolormesh(imath.log10(nn_prediction[0]), vmin=-4, vmax=-1)
plt.gca().set_title('CNN')
plt.colorbar(im, fraction=0.046, pad=0.04, label='log(AAOD)')
plt.gca().coastlines()

plt.subplot(2,3,5, projection=ccrs.Mollweide())
plt.annotate("(d)", (0.,1.), xycoords='axes fraction')
iplt.pcolormesh((gp_prediction.collapsed(['sample'], iris.analysis.MEAN)-Y_test.collapsed(['job'], iris.analysis.MEAN)), cmap='RdBu_r', vmin=-0.001, vmax=0.001)
plt.gca().coastlines()
plt.gca().set_title('Difference')

plt.subplot(2,3,6, projection=ccrs.Mollweide())
plt.annotate("(e)", (0.,1.), xycoords='axes fraction')
im=iplt.pcolormesh((nn_prediction.collapsed(['sample'], iris.analysis.MEAN)-Y_test.collapsed(['job'], iris.analysis.MEAN)), cmap='RdBu_r', vmin=-1e-3, vmax=1e-3)
cb = plt.colorbar(im, fraction=0.046, pad=0.04)
cb.ax.set_yticklabels(["{:.2e}".format(i) for i in cb.get_ticks()])  ## format the colorbar tick labels in scientific notation
plt.gca().coastlines()
plt.gca().set_title('Difference')

plt.savefig('BCPPE_emulator_paper.png', transparent=True)
/Users/watson-parris/miniconda3/envs/gcem/lib/python3.8/site-packages/iris/coords.py:1410: UserWarning: Collapsing a non-contiguous coordinate. Metadata may not be fully descriptive for 'sample'.
  warnings.warn(msg.format(self.name()))
[Output figure: global maps of log(AAOD) for (a) Truth, (b) GP and (c) CNN, with (d) GP and (e) CNN differences from the truth]
[15]:

COLOR = 'white'
matplotlib.rcParams['text.color'] = COLOR
matplotlib.rcParams['axes.labelcolor'] = COLOR
matplotlib.rcParams['xtick.color'] = COLOR
matplotlib.rcParams['ytick.color'] = COLOR
matplotlib.rcParams['font.size'] = 20

plt.figure(figsize=(30, 10))

plt.subplot(2,3,1, projection=ccrs.Mollweide())
plt.annotate("(a)", (0.,1.), xycoords='axes fraction')
iplt.pcolormesh(imath.log10(Y_test[0]), vmin=-4, vmax=-1)
plt.gca().set_title('Truth')
plt.gca().coastlines()

plt.subplot(2,3,2, projection=ccrs.Mollweide())
plt.annotate("(b)", (0.,1.), xycoords='axes fraction')
iplt.pcolormesh(imath.log10(gp_prediction[0]), vmin=-4, vmax=-1)
plt.gca().set_title('GP')
plt.gca().coastlines()

plt.subplot(2,3,3, projection=ccrs.Mollweide())
plt.annotate("(c)", (0.,1.), xycoords='axes fraction')
im=iplt.pcolormesh(imath.log10(nn_prediction[0]), vmin=-4, vmax=-1)
plt.gca().set_title('CNN')
plt.colorbar(im, fraction=0.046, pad=0.04, label='log(AAOD)')
plt.gca().coastlines()

plt.subplot(2,3,5, projection=ccrs.Mollweide())
plt.annotate("(d)", (0.,1.), xycoords='axes fraction')
iplt.pcolormesh((gp_prediction.collapsed(['sample'], iris.analysis.MEAN)-Y_test.collapsed(['job'], iris.analysis.MEAN)), cmap='RdBu_r', vmin=-0.001, vmax=0.001)
plt.gca().coastlines()
plt.gca().set_title('Difference')

plt.subplot(2,3,6, projection=ccrs.Mollweide())
plt.annotate("(e)", (0.,1.), xycoords='axes fraction')
im=iplt.pcolormesh((nn_prediction.collapsed(['sample'], iris.analysis.MEAN)-Y_test.collapsed(['job'], iris.analysis.MEAN)), cmap='RdBu_r', vmin=-1e-3, vmax=1e-3)
cb=plt.colorbar(im, fraction=0.046, pad=0.04)
cb.ax.set_yticklabels(["{:.2e}".format(i) for i in cb.get_ticks()])  ## format the colorbar tick labels in scientific notation
plt.gca().coastlines()
plt.gca().set_title('Difference')

plt.savefig('BCPPE_emulator_talk.png', transparent=True)
/Users/watson-parris/miniconda3/envs/gcem/lib/python3.8/site-packages/iris/coords.py:1410: UserWarning: Collapsing a non-contiguous coordinate. Metadata may not be fully descriptive for 'sample'.
  warnings.warn(msg.format(self.name()))
[Output figure: the same panels as above, restyled with white labels for the talk version]
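The cell above sets the global matplotlib text color settings to white for the transparent talk figure, so any later plots in the session would inherit them. Not part of the original notebook, but the defaults can be restored with:

## Optional: reset matplotlib to its default style after the talk figure
matplotlib.rcParams.update(matplotlib.rcParamsDefault)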