Emulating using CNNs

[1]:
import iris

from utils import get_bc_ppe_data

from esem import cnn_model
from esem.utils import get_random_params

import iris.quickplot as qplt
import matplotlib.pyplot as plt
%matplotlib inline

Read in the parameters and data

[2]:
# Load the ensemble inputs (ppe_params) and the matching model output
# (ppe_aaod) using the helper from the local utils module.
# NOTE(review): the two are sliced with the same index further down, so they
# are presumably aligned member-by-member along their first axis - confirm.
ppe_params, ppe_aaod = get_bc_ppe_data()
[3]:
## Ensure the time dimension is last - this is treated as the color 'channel'.
## Cube.transpose reorders the cube in place, so only apply it when time is
## not already the final dimension; otherwise re-running this cell would
## permute the axes a second time.
(time_dim,) = ppe_aaod.coord_dims('time')
if time_dim != ppe_aaod.ndim - 1:
    ppe_aaod.transpose((0,2,3,1))
[4]:
# Hold back the first n_test ensemble members for testing and train on the rest.
n_test = 5

X_test = ppe_params[:n_test]
X_train = ppe_params[n_test:]

Y_test = ppe_aaod[:n_test]
Y_train = ppe_aaod[n_test:]
[5]:
# Display the training cube to check its shape and that time is the last dimension.
Y_train
[5]:
Absorption Optical Thickness - Total 550Nm (1) job latitude longitude time
Shape 34 96 192 12
Dimension coordinates
job x - - -
latitude - x - -
longitude - - x -
time - - - x
Attributes
CDI Climate Data Interface version 1.9.8 (https://mpimet.mpg.de/cdi)
CDO Climate Data Operators version 1.9.8 (https://mpimet.mpg.de/cdo)
Conventions CF-1.5
NCO netCDF Operators version 4.8.1 (Homepage = http://nco.sf.net, Code = h...
advection Lin & Rood
echam_version 6.3.02
frequency mon
grid_type gaussian
history Sat Nov 23 17:22:40 2019: cdo monavg BC_PPE_PD_AAOD_t.nc BC_PPE_PD_AAOD_monthly.nc
Sat...
institution MPIMET
jsbach_version 3.10
operating_system Linux 3.0.101-0.46.1_1.0502.8871-cray_ari_c x86_64
physics Modified ECMWF physics
radiation Using PSrad/RRTMG radiation
source ECHAM6.3
truncation 63
user_name user name not available
Cell methods
mean: time

Set up and run the model

[6]:
# Build the CNN emulator from the training parameters and the training cube.
model = cnn_model(X_train, Y_train)
[7]:
# Fit the emulator with its default settings; the per-epoch training and
# validation loss are printed as the fit progresses.
model.train()
Epoch 1/100
4/4 [==============================] - 3s 71ms/step - loss: 1.1407 - val_loss: 0.4622
Epoch 2/100
4/4 [==============================] - 0s 19ms/step - loss: 1.1396 - val_loss: 0.4620
Epoch 3/100
4/4 [==============================] - 0s 17ms/step - loss: 1.1391 - val_loss: 0.4618
Epoch 4/100
4/4 [==============================] - 0s 19ms/step - loss: 1.1382 - val_loss: 0.4613
Epoch 5/100
4/4 [==============================] - ETA: 0s - loss: 1.048 - 0s 18ms/step - loss: 1.1360 - val_loss: 0.4577
Epoch 6/100
4/4 [==============================] - 0s 17ms/step - loss: 1.1256 - val_loss: 0.4536
Epoch 7/100
4/4 [==============================] - 0s 18ms/step - loss: 1.1114 - val_loss: 0.4446
Epoch 8/100
4/4 [==============================] - 0s 19ms/step - loss: 1.0957 - val_loss: 0.4372
Epoch 9/100
4/4 [==============================] - 0s 20ms/step - loss: 1.0746 - val_loss: 0.4246
Epoch 10/100
4/4 [==============================] - 0s 18ms/step - loss: 1.0569 - val_loss: 0.4152
Epoch 11/100
4/4 [==============================] - 0s 20ms/step - loss: 1.0326 - val_loss: 0.4073
Epoch 12/100
4/4 [==============================] - 0s 19ms/step - loss: 1.0055 - val_loss: 0.4009
Epoch 13/100
4/4 [==============================] - 0s 20ms/step - loss: 0.9815 - val_loss: 0.3892
Epoch 14/100
4/4 [==============================] - 0s 21ms/step - loss: 0.9567 - val_loss: 0.3797
Epoch 15/100
4/4 [==============================] - 0s 22ms/step - loss: 0.9285 - val_loss: 0.3679
Epoch 16/100
4/4 [==============================] - 0s 22ms/step - loss: 0.9015 - val_loss: 0.3472
Epoch 17/100
4/4 [==============================] - 0s 21ms/step - loss: 0.8915 - val_loss: 0.3343
Epoch 18/100
4/4 [==============================] - 0s 20ms/step - loss: 0.8643 - val_loss: 0.3255
Epoch 19/100
4/4 [==============================] - 0s 19ms/step - loss: 0.8362 - val_loss: 0.3176
Epoch 20/100
4/4 [==============================] - 0s 19ms/step - loss: 0.8201 - val_loss: 0.3009
Epoch 21/100
4/4 [==============================] - 0s 19ms/step - loss: 0.7953 - val_loss: 0.2862
Epoch 22/100
4/4 [==============================] - 0s 19ms/step - loss: 0.7777 - val_loss: 0.2778
Epoch 23/100
4/4 [==============================] - 0s 18ms/step - loss: 0.7578 - val_loss: 0.2747
Epoch 24/100
4/4 [==============================] - 0s 19ms/step - loss: 0.7465 - val_loss: 0.2592
Epoch 25/100
4/4 [==============================] - 0s 22ms/step - loss: 0.7223 - val_loss: 0.2570
Epoch 26/100
4/4 [==============================] - 0s 19ms/step - loss: 0.7036 - val_loss: 0.2337
Epoch 27/100
4/4 [==============================] - 0s 20ms/step - loss: 0.6804 - val_loss: 0.2253
Epoch 28/100
4/4 [==============================] - 0s 18ms/step - loss: 0.6666 - val_loss: 0.2308
Epoch 29/100
4/4 [==============================] - 0s 18ms/step - loss: 0.6502 - val_loss: 0.2109
Epoch 30/100
4/4 [==============================] - 0s 19ms/step - loss: 0.6315 - val_loss: 0.1978
Epoch 31/100
4/4 [==============================] - 0s 18ms/step - loss: 0.6138 - val_loss: 0.1819
Epoch 32/100
4/4 [==============================] - 0s 19ms/step - loss: 0.5970 - val_loss: 0.1738
Epoch 33/100
4/4 [==============================] - 0s 21ms/step - loss: 0.5801 - val_loss: 0.1640
Epoch 34/100
4/4 [==============================] - 0s 22ms/step - loss: 0.5622 - val_loss: 0.1643
Epoch 35/100
4/4 [==============================] - 0s 18ms/step - loss: 0.5502 - val_loss: 0.1508
Epoch 36/100
4/4 [==============================] - 0s 18ms/step - loss: 0.5405 - val_loss: 0.1625
Epoch 37/100
4/4 [==============================] - 0s 18ms/step - loss: 0.5311 - val_loss: 0.1517
Epoch 38/100
4/4 [==============================] - 0s 18ms/step - loss: 0.5191 - val_loss: 0.1331
Epoch 39/100
4/4 [==============================] - 0s 19ms/step - loss: 0.5063 - val_loss: 0.1264
Epoch 40/100
4/4 [==============================] - 0s 18ms/step - loss: 0.4976 - val_loss: 0.1217
Epoch 41/100
4/4 [==============================] - 0s 18ms/step - loss: 0.4865 - val_loss: 0.1159
Epoch 42/100
4/4 [==============================] - 0s 18ms/step - loss: 0.4715 - val_loss: 0.1112
Epoch 43/100
4/4 [==============================] - 0s 19ms/step - loss: 0.4604 - val_loss: 0.1131
Epoch 44/100
4/4 [==============================] - 0s 19ms/step - loss: 0.4498 - val_loss: 0.1002
Epoch 45/100
4/4 [==============================] - 0s 21ms/step - loss: 0.4404 - val_loss: 0.0969
Epoch 46/100
4/4 [==============================] - 0s 20ms/step - loss: 0.4320 - val_loss: 0.0938
Epoch 47/100
4/4 [==============================] - 0s 18ms/step - loss: 0.4242 - val_loss: 0.0905
Epoch 48/100
4/4 [==============================] - 0s 18ms/step - loss: 0.4173 - val_loss: 0.0874
Epoch 49/100
4/4 [==============================] - 0s 19ms/step - loss: 0.4105 - val_loss: 0.0903
Epoch 50/100
4/4 [==============================] - 0s 18ms/step - loss: 0.4032 - val_loss: 0.0825
Epoch 51/100
4/4 [==============================] - 0s 21ms/step - loss: 0.3931 - val_loss: 0.0872
Epoch 52/100
4/4 [==============================] - 0s 22ms/step - loss: 0.3891 - val_loss: 0.0839
Epoch 53/100
4/4 [==============================] - 0s 22ms/step - loss: 0.3813 - val_loss: 0.0768
Epoch 54/100
4/4 [==============================] - 0s 22ms/step - loss: 0.3767 - val_loss: 0.0764
Epoch 55/100
4/4 [==============================] - 0s 21ms/step - loss: 0.3720 - val_loss: 0.0776
Epoch 56/100
4/4 [==============================] - 0s 21ms/step - loss: 0.3671 - val_loss: 0.0723
Epoch 57/100
4/4 [==============================] - 0s 20ms/step - loss: 0.3601 - val_loss: 0.0719
Epoch 58/100
4/4 [==============================] - 0s 19ms/step - loss: 0.3570 - val_loss: 0.0708
Epoch 59/100
4/4 [==============================] - 0s 20ms/step - loss: 0.3562 - val_loss: 0.0704
Epoch 60/100
4/4 [==============================] - 0s 20ms/step - loss: 0.3519 - val_loss: 0.0699
Epoch 61/100
4/4 [==============================] - 0s 19ms/step - loss: 0.3488 - val_loss: 0.0689
Epoch 62/100
4/4 [==============================] - 0s 20ms/step - loss: 0.3444 - val_loss: 0.0727
Epoch 63/100
4/4 [==============================] - 0s 19ms/step - loss: 0.3427 - val_loss: 0.0795
Epoch 64/100
4/4 [==============================] - 0s 20ms/step - loss: 0.3420 - val_loss: 0.0691
Epoch 65/100
4/4 [==============================] - 0s 19ms/step - loss: 0.3360 - val_loss: 0.0678
Epoch 66/100
4/4 [==============================] - 0s 21ms/step - loss: 0.3336 - val_loss: 0.0664
Epoch 67/100
4/4 [==============================] - 0s 19ms/step - loss: 0.3303 - val_loss: 0.0689
Epoch 68/100
4/4 [==============================] - 0s 19ms/step - loss: 0.3297 - val_loss: 0.0669
Epoch 69/100
4/4 [==============================] - 0s 19ms/step - loss: 0.3246 - val_loss: 0.0659
Epoch 70/100
4/4 [==============================] - 0s 19ms/step - loss: 0.3246 - val_loss: 0.0682
Epoch 71/100
4/4 [==============================] - 0s 20ms/step - loss: 0.3217 - val_loss: 0.0660
Epoch 72/100
4/4 [==============================] - 0s 18ms/step - loss: 0.3196 - val_loss: 0.0656
Epoch 73/100
4/4 [==============================] - 0s 19ms/step - loss: 0.3201 - val_loss: 0.0667
Epoch 74/100
4/4 [==============================] - 0s 18ms/step - loss: 0.3166 - val_loss: 0.0653
Epoch 75/100
4/4 [==============================] - 0s 18ms/step - loss: 0.3151 - val_loss: 0.0666
Epoch 76/100
4/4 [==============================] - 0s 19ms/step - loss: 0.3137 - val_loss: 0.0657
Epoch 77/100
4/4 [==============================] - 0s 18ms/step - loss: 0.3143 - val_loss: 0.0648
Epoch 78/100
4/4 [==============================] - 0s 18ms/step - loss: 0.3112 - val_loss: 0.0647
Epoch 79/100
4/4 [==============================] - 0s 18ms/step - loss: 0.3084 - val_loss: 0.0674
Epoch 80/100
4/4 [==============================] - 0s 19ms/step - loss: 0.3087 - val_loss: 0.0650
Epoch 81/100
4/4 [==============================] - 0s 19ms/step - loss: 0.3056 - val_loss: 0.0709
Epoch 82/100
4/4 [==============================] - 0s 18ms/step - loss: 0.3082 - val_loss: 0.0651
Epoch 83/100
4/4 [==============================] - 0s 18ms/step - loss: 0.3048 - val_loss: 0.0644
Epoch 84/100
4/4 [==============================] - 0s 19ms/step - loss: 0.3036 - val_loss: 0.0669
Epoch 85/100
4/4 [==============================] - 0s 18ms/step - loss: 0.3032 - val_loss: 0.0662
Epoch 86/100
4/4 [==============================] - 0s 18ms/step - loss: 0.3030 - val_loss: 0.0672
Epoch 87/100
4/4 [==============================] - 0s 18ms/step - loss: 0.3016 - val_loss: 0.0691
Epoch 88/100
4/4 [==============================] - 0s 18ms/step - loss: 0.3014 - val_loss: 0.0706
Epoch 89/100
4/4 [==============================] - 0s 18ms/step - loss: 0.3006 - val_loss: 0.0648
Epoch 90/100
4/4 [==============================] - 0s 18ms/step - loss: 0.2988 - val_loss: 0.0670
Epoch 91/100
4/4 [==============================] - 0s 17ms/step - loss: 0.2985 - val_loss: 0.0670
Epoch 92/100
4/4 [==============================] - 0s 21ms/step - loss: 0.2977 - val_loss: 0.0635
Epoch 93/100
4/4 [==============================] - 0s 20ms/step - loss: 0.2967 - val_loss: 0.0638
Epoch 94/100
4/4 [==============================] - 0s 23ms/step - loss: 0.2969 - val_loss: 0.0654
Epoch 95/100
4/4 [==============================] - 0s 19ms/step - loss: 0.2954 - val_loss: 0.0648
Epoch 96/100
4/4 [==============================] - 0s 21ms/step - loss: 0.2973 - val_loss: 0.0638
Epoch 97/100
4/4 [==============================] - 0s 23ms/step - loss: 0.2949 - val_loss: 0.0668
Epoch 98/100
4/4 [==============================] - 0s 24ms/step - loss: 0.2945 - val_loss: 0.0641
Epoch 99/100
4/4 [==============================] - 0s 22ms/step - loss: 0.2936 - val_loss: 0.0648
Epoch 100/100
4/4 [==============================] - 0s 23ms/step - loss: 0.2932 - val_loss: 0.0651
[8]:
# Predict for the held-out parameter sets. The first return value (m) is used
# below as a cube with 'sample' and 'time' coordinates; the second (v) is
# presumably the predictive variance - TODO confirm. It is not used later here.
m, v = model.predict(X_test.to_numpy())
[9]:
## Compare the emulator with the held-out simulations:
##   top row    - time mean of the first test member, predicted vs. simulated
##   bottom row - difference between the predicted and simulated means taken
##                over all test members and all time steps
plt.figure(figsize=(12, 8))

plt.subplot(2, 2, 1)
predicted_first = m[0].collapsed('time', iris.analysis.MEAN)
qplt.pcolormesh(predicted_first)
plt.gca().set_title('Predicted')
plt.gca().coastlines()

plt.subplot(2, 2, 2)
simulated_first = Y_test[0].collapsed('time', iris.analysis.MEAN)
qplt.pcolormesh(simulated_first)
plt.gca().set_title('Test')
plt.gca().coastlines()

plt.subplot(2, 2, 3)
predicted_mean = m.collapsed(['sample', 'time'], iris.analysis.MEAN)
simulated_mean = Y_test.collapsed(['job', 'time'], iris.analysis.MEAN)
# Diverging colormap with symmetric limits so that zero difference is neutral.
qplt.pcolormesh(predicted_mean - simulated_mean, cmap='RdBu_r', vmin=-0.01, vmax=0.01)
plt.gca().coastlines()
# The trailing semicolon stops the Text(...) repr being echoed as cell output.
plt.gca().set_title('Difference');
C:\Users\duncan\miniconda3\envs\climatebench\lib\site-packages\iris\coords.py:1803: UserWarning: Coordinate 'longitude' is not bounded, guessing contiguous bounds.
  warnings.warn(
C:\Users\duncan\miniconda3\envs\climatebench\lib\site-packages\iris\coords.py:1803: UserWarning: Coordinate 'latitude' is not bounded, guessing contiguous bounds.
  warnings.warn(
C:\Users\duncan\miniconda3\envs\climatebench\lib\site-packages\iris\coords.py:1803: UserWarning: Coordinate 'longitude' is not bounded, guessing contiguous bounds.
  warnings.warn(
C:\Users\duncan\miniconda3\envs\climatebench\lib\site-packages\iris\coords.py:1803: UserWarning: Coordinate 'latitude' is not bounded, guessing contiguous bounds.
  warnings.warn(
C:\Users\duncan\miniconda3\envs\climatebench\lib\site-packages\iris\coords.py:1979: UserWarning: Collapsing a non-contiguous coordinate. Metadata may not be fully descriptive for 'sample'.
  warnings.warn(msg.format(self.name()))
C:\Users\duncan\miniconda3\envs\climatebench\lib\site-packages\iris\coords.py:1979: UserWarning: Collapsing a non-contiguous coordinate. Metadata may not be fully descriptive for 'job'.
  warnings.warn(msg.format(self.name()))
C:\Users\duncan\miniconda3\envs\climatebench\lib\site-packages\iris\coords.py:1803: UserWarning: Coordinate 'longitude' is not bounded, guessing contiguous bounds.
  warnings.warn(
C:\Users\duncan\miniconda3\envs\climatebench\lib\site-packages\iris\coords.py:1803: UserWarning: Coordinate 'latitude' is not bounded, guessing contiguous bounds.
  warnings.warn(
[9]:
Text(0.5, 1.0, 'Difference')
../_images/examples_Emulating_using_ConvNets_11_2.png
[10]:
# Compute summary statistics of the emulator over int(1e5) randomly drawn
# parameter sets, evaluated in batches of 1000. The first argument (3) is
# presumably the number of input parameters - TODO confirm against
# get_random_params.
# NOTE(review): this rebinds `m`, which the plotting cell above reads; if that
# cell is re-run after this one it will use this result, not the test prediction.
m, sd = model.batch_stats(get_random_params(3, int(1e5)), batch_size=1000)
<tqdm.auto.tqdm object at 0x0000019580079D60>
[ ]: