Emulating using CNNs

[2]:
import os
## Work around a local h5py/HDF5 version mismatch: make HDF5 warn instead of aborting
os.environ["HDF5_DISABLE_VERSION_CHECK"] = "1"
[3]:
import iris

from utils import get_bc_ppe_data

from esem import cnn_model
from esem.utils import get_random_params

import iris.quickplot as qplt
import matplotlib.pyplot as plt
%matplotlib inline

Read in the parameters and data

[4]:
ppe_params, ppe_aaod = get_bc_ppe_data()
[5]:
## Make time the last dimension: the CNN treats the 12 months as the image 'channels'
## (Cube.transpose works in place, so nothing needs to be reassigned)
ppe_aaod.transpose((0,2,3,1))
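In array terms, this transpose moves the time axis from position 1 to the end, so each 96 × 192 latitude/longitude map carries 12 monthly "channels", much as an RGB image carries 3 colour channels. A minimal NumPy sketch on a dummy array; the original (job, time, latitude, longitude) ordering is inferred from the transpose tuple and the cube summary shown for `Y_train`, and the ensemble size of 39 from the 34 training plus 5 test members.

```python
import numpy as np

# Dummy stand-in for the AAOD data, ordered (job, time, latitude, longitude)
# as assumed for ppe_aaod before the transpose
aaod = np.zeros((39, 12, 96, 192), dtype=np.float32)

# Same axis permutation as ppe_aaod.transpose((0, 2, 3, 1)):
# new axis i is old axis order[i], so time (old axis 1) moves to the end
channels_last = np.transpose(aaod, (0, 2, 3, 1))

print(channels_last.shape)  # (39, 96, 192, 12): job, lat, lon, time as channels
```

`np.transpose` returns a reordered view and leaves `aaod` itself unchanged.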
[6]:
## Hold out the first n_test ensemble members for testing and train on the rest
n_test = 5

X_test, X_train = ppe_params[:n_test], ppe_params[n_test:]
Y_test, Y_train = ppe_aaod[:n_test], ppe_aaod[n_test:]
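Slicing off the first `n_test` rows gives a deterministic split. If the ensemble members happen to be ordered systematically, a seeded random hold-out is a common alternative; the sketch below shows one way to draw it with NumPy. It is an illustrative swap, not what this notebook does, and `params`, `member_ids`, `p_train` and `p_test` are hypothetical stand-ins (named to avoid clobbering the notebook's `X_*`/`Y_*` variables).

```python
import numpy as np

n_members, n_test = 39, 5  # 34 training + 5 test members, as in this notebook

# Hypothetical stand-ins for the parameter table (3 parameters per member)
# and for the member index of each output field
params = np.random.default_rng(0).uniform(size=(n_members, 3))
member_ids = np.arange(n_members)

# Seeded random permutation, so the hold-out set is reproducible
rng = np.random.default_rng(42)
perm = rng.permutation(n_members)
# Sort each subset so the member order (and any 'job' coordinate) stays monotonic
test_idx, train_idx = np.sort(perm[:n_test]), np.sort(perm[n_test:])

# Use the same indices for inputs and outputs so each row stays paired with its field
p_test, p_train = params[test_idx], params[train_idx]
ids_test, ids_train = member_ids[test_idx], member_ids[train_idx]

print(p_train.shape, p_test.shape)  # (34, 3) (5, 3)
```

With the objects used here, the same index arrays should apply as `ppe_params.iloc[train_idx]` and `ppe_aaod[train_idx]`.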
[7]:
Y_train
[7]:
Absorption Optical Thickness - Total 550Nm (1)    job  latitude  longitude  time
Shape                                              34        96        192    12
Dimension coordinates
    job                                             x         -          -     -
    latitude                                        -         x          -     -
    longitude                                       -         -          x     -
    time                                            -         -          -     x
Attributes
    CDI               Climate Data Interface version 1.9.8 (https://mpimet.mpg.de/cdi)
    CDO               Climate Data Operators version 1.9.8 (https://mpimet.mpg.de/cdo)
    Conventions       CF-1.5
    NCO               netCDF Operators version 4.8.1 (Homepage = http://nco.sf.net, Code = h...
    advection         Lin & Rood
    echam_version     6.3.02
    frequency         mon
    grid_type         gaussian
    history           Sat Nov 23 17:22:40 2019: cdo monavg BC_PPE_PD_AAOD_t.nc BC_PPE_PD_AAOD_monthly.nc
                      Sat...
    institution       MPIMET
    jsbach_version    3.10
    operating_system  Linux 3.0.101-0.46.1_1.0502.8871-cray_ari_c x86_64
    physics           Modified ECMWF physics
    radiation         Using PSrad/RRTMG radiation
    source            ECHAM6.3
    truncation        63
    user_name         user name not available
Cell methods
    mean              time

Set up and train the model

[8]:
model = cnn_model(X_train, Y_train)
[9]:
model.train()
Epoch 1/100
4/4 [==============================] - 2s 329ms/step - loss: 1.2647 - val_loss: 0.4618
Epoch 2/100
4/4 [==============================] - 1s 172ms/step - loss: 1.1055 - val_loss: 0.4615
Epoch 3/100
4/4 [==============================] - 1s 168ms/step - loss: 1.1964 - val_loss: 0.4610
Epoch 4/100
4/4 [==============================] - 1s 171ms/step - loss: 1.0877 - val_loss: 0.4589
Epoch 5/100
4/4 [==============================] - 1s 172ms/step - loss: 1.3223 - val_loss: 0.4561
Epoch 6/100
4/4 [==============================] - 1s 170ms/step - loss: 1.0132 - val_loss: 0.4535
Epoch 7/100
4/4 [==============================] - 1s 169ms/step - loss: 1.1708 - val_loss: 0.4544
Epoch 8/100
4/4 [==============================] - 1s 177ms/step - loss: 1.2777 - val_loss: 0.4419
Epoch 9/100
4/4 [==============================] - 1s 196ms/step - loss: 1.0134 - val_loss: 0.4406
Epoch 10/100
4/4 [==============================] - 1s 197ms/step - loss: 1.0496 - val_loss: 0.4189
Epoch 11/100
4/4 [==============================] - 1s 179ms/step - loss: 0.8041 - val_loss: 0.4329
Epoch 12/100
4/4 [==============================] - 1s 170ms/step - loss: 1.1344 - val_loss: 0.4080
Epoch 13/100
4/4 [==============================] - 1s 173ms/step - loss: 0.9210 - val_loss: 0.3832
Epoch 14/100
4/4 [==============================] - 1s 171ms/step - loss: 0.8290 - val_loss: 0.3700
Epoch 15/100
4/4 [==============================] - 1s 179ms/step - loss: 0.9029 - val_loss: 0.3750
Epoch 16/100
4/4 [==============================] - 1s 198ms/step - loss: 0.8040 - val_loss: 0.3580
Epoch 17/100
4/4 [==============================] - 1s 169ms/step - loss: 0.9728 - val_loss: 0.3320
Epoch 18/100
4/4 [==============================] - 1s 170ms/step - loss: 0.8510 - val_loss: 0.3260
Epoch 19/100
4/4 [==============================] - 1s 171ms/step - loss: 0.6017 - val_loss: 0.3425
Epoch 20/100
4/4 [==============================] - 1s 170ms/step - loss: 0.5955 - val_loss: 0.3306
Epoch 21/100
4/4 [==============================] - 1s 174ms/step - loss: 0.7674 - val_loss: 0.2911
Epoch 22/100
4/4 [==============================] - 1s 168ms/step - loss: 0.6860 - val_loss: 0.2662
Epoch 23/100
4/4 [==============================] - 1s 171ms/step - loss: 0.7090 - val_loss: 0.2540
Epoch 24/100
4/4 [==============================] - 1s 173ms/step - loss: 0.6631 - val_loss: 0.2368
Epoch 25/100
4/4 [==============================] - 1s 171ms/step - loss: 0.6363 - val_loss: 0.2402
Epoch 26/100
4/4 [==============================] - 1s 170ms/step - loss: 0.4646 - val_loss: 0.2249
Epoch 27/100
4/4 [==============================] - 1s 172ms/step - loss: 0.6093 - val_loss: 0.2016
Epoch 28/100
4/4 [==============================] - 1s 171ms/step - loss: 0.5923 - val_loss: 0.1915
Epoch 29/100
4/4 [==============================] - 1s 168ms/step - loss: 0.4313 - val_loss: 0.1813
Epoch 30/100
4/4 [==============================] - 1s 173ms/step - loss: 0.4832 - val_loss: 0.1717
Epoch 31/100
4/4 [==============================] - 1s 176ms/step - loss: 0.5329 - val_loss: 0.1609
Epoch 32/100
4/4 [==============================] - 1s 172ms/step - loss: 0.4968 - val_loss: 0.1572
Epoch 33/100
4/4 [==============================] - 1s 174ms/step - loss: 0.5791 - val_loss: 0.1503
Epoch 34/100
4/4 [==============================] - 1s 173ms/step - loss: 0.4020 - val_loss: 0.1417
Epoch 35/100
4/4 [==============================] - 1s 170ms/step - loss: 0.5336 - val_loss: 0.1327
Epoch 36/100
4/4 [==============================] - 1s 172ms/step - loss: 0.4808 - val_loss: 0.1267
Epoch 37/100
4/4 [==============================] - 1s 171ms/step - loss: 0.4852 - val_loss: 0.1332
Epoch 38/100
4/4 [==============================] - 1s 176ms/step - loss: 0.4243 - val_loss: 0.1177
Epoch 39/100
4/4 [==============================] - 1s 172ms/step - loss: 0.3802 - val_loss: 0.1113
Epoch 40/100
4/4 [==============================] - 1s 173ms/step - loss: 0.5062 - val_loss: 0.1103
Epoch 41/100
4/4 [==============================] - 1s 172ms/step - loss: 0.4019 - val_loss: 0.1086
Epoch 42/100
4/4 [==============================] - 1s 175ms/step - loss: 0.4136 - val_loss: 0.1003
Epoch 43/100
4/4 [==============================] - 1s 171ms/step - loss: 0.3220 - val_loss: 0.0984
Epoch 44/100
4/4 [==============================] - 1s 172ms/step - loss: 0.3285 - val_loss: 0.0946
Epoch 45/100
4/4 [==============================] - 1s 173ms/step - loss: 0.3216 - val_loss: 0.0845
Epoch 46/100
4/4 [==============================] - 1s 180ms/step - loss: 0.2456 - val_loss: 0.0815
Epoch 47/100
4/4 [==============================] - 1s 172ms/step - loss: 0.3025 - val_loss: 0.0778
Epoch 48/100
4/4 [==============================] - 1s 175ms/step - loss: 0.2722 - val_loss: 0.0804
Epoch 49/100
4/4 [==============================] - 1s 175ms/step - loss: 0.3915 - val_loss: 0.0739
Epoch 50/100
4/4 [==============================] - 1s 182ms/step - loss: 0.4554 - val_loss: 0.0749
Epoch 51/100
4/4 [==============================] - 1s 174ms/step - loss: 0.2702 - val_loss: 0.0712
Epoch 52/100
4/4 [==============================] - 1s 169ms/step - loss: 0.3218 - val_loss: 0.0702
Epoch 53/100
4/4 [==============================] - 1s 169ms/step - loss: 0.3559 - val_loss: 0.0676
Epoch 54/100
4/4 [==============================] - 1s 188ms/step - loss: 0.3849 - val_loss: 0.0663
Epoch 55/100
4/4 [==============================] - 1s 195ms/step - loss: 0.3528 - val_loss: 0.0663
Epoch 56/100
4/4 [==============================] - 1s 169ms/step - loss: 0.4359 - val_loss: 0.0661
Epoch 57/100
4/4 [==============================] - 1s 175ms/step - loss: 0.4033 - val_loss: 0.0661
Epoch 58/100
4/4 [==============================] - 1s 172ms/step - loss: 0.3583 - val_loss: 0.0648
Epoch 59/100
4/4 [==============================] - 1s 189ms/step - loss: 0.2553 - val_loss: 0.0664
Epoch 60/100
4/4 [==============================] - 1s 176ms/step - loss: 0.2613 - val_loss: 0.0655
Epoch 61/100
4/4 [==============================] - 1s 198ms/step - loss: 0.3042 - val_loss: 0.0647
Epoch 62/100
4/4 [==============================] - 1s 198ms/step - loss: 0.3289 - val_loss: 0.0629
Epoch 63/100
4/4 [==============================] - 1s 186ms/step - loss: 0.3522 - val_loss: 0.0680
Epoch 64/100
4/4 [==============================] - 1s 174ms/step - loss: 0.3994 - val_loss: 0.0612
Epoch 65/100
4/4 [==============================] - 1s 187ms/step - loss: 0.3320 - val_loss: 0.0731
Epoch 66/100
4/4 [==============================] - 1s 173ms/step - loss: 0.2637 - val_loss: 0.0632
Epoch 67/100
4/4 [==============================] - 1s 183ms/step - loss: 0.2442 - val_loss: 0.0664
Epoch 68/100
4/4 [==============================] - 1s 216ms/step - loss: 0.4290 - val_loss: 0.0621
Epoch 69/100
4/4 [==============================] - 1s 186ms/step - loss: 0.3970 - val_loss: 0.0641
Epoch 70/100
4/4 [==============================] - 1s 180ms/step - loss: 0.2437 - val_loss: 0.0613
Epoch 71/100
4/4 [==============================] - 1s 182ms/step - loss: 0.3692 - val_loss: 0.0647
Epoch 72/100
4/4 [==============================] - 1s 173ms/step - loss: 0.3234 - val_loss: 0.0618
Epoch 73/100
4/4 [==============================] - 1s 179ms/step - loss: 0.3490 - val_loss: 0.0672
Epoch 74/100
4/4 [==============================] - 1s 178ms/step - loss: 0.3524 - val_loss: 0.0611
Epoch 75/100
4/4 [==============================] - 1s 173ms/step - loss: 0.3622 - val_loss: 0.0620
Epoch 76/100
4/4 [==============================] - 1s 177ms/step - loss: 0.3948 - val_loss: 0.0658
Epoch 77/100
4/4 [==============================] - 1s 170ms/step - loss: 0.2945 - val_loss: 0.0608
Epoch 78/100
4/4 [==============================] - 1s 171ms/step - loss: 0.3269 - val_loss: 0.0638
Epoch 79/100
4/4 [==============================] - 1s 171ms/step - loss: 0.3167 - val_loss: 0.0647
Epoch 80/100
4/4 [==============================] - 1s 181ms/step - loss: 0.3143 - val_loss: 0.0646
Epoch 81/100
4/4 [==============================] - 1s 191ms/step - loss: 0.3812 - val_loss: 0.0682
Epoch 82/100
4/4 [==============================] - 1s 245ms/step - loss: 0.3734 - val_loss: 0.0680
Epoch 83/100
4/4 [==============================] - 1s 246ms/step - loss: 0.3135 - val_loss: 0.0641
Epoch 84/100
4/4 [==============================] - 1s 208ms/step - loss: 0.2190 - val_loss: 0.0662
Epoch 85/100
4/4 [==============================] - 1s 173ms/step - loss: 0.3904 - val_loss: 0.0680
Epoch 86/100
4/4 [==============================] - 1s 211ms/step - loss: 0.3203 - val_loss: 0.0693
Epoch 87/100
4/4 [==============================] - 1s 263ms/step - loss: 0.3333 - val_loss: 0.0649
Epoch 88/100
4/4 [==============================] - 1s 172ms/step - loss: 0.2221 - val_loss: 0.0668
Epoch 89/100
4/4 [==============================] - 1s 214ms/step - loss: 0.3148 - val_loss: 0.0658
Epoch 90/100
4/4 [==============================] - 1s 170ms/step - loss: 0.4059 - val_loss: 0.0717
Epoch 91/100
4/4 [==============================] - 1s 170ms/step - loss: 0.2097 - val_loss: 0.0657
Epoch 92/100
4/4 [==============================] - 1s 170ms/step - loss: 0.2152 - val_loss: 0.0628
Epoch 93/100
4/4 [==============================] - 1s 171ms/step - loss: 0.2845 - val_loss: 0.0644
Epoch 94/100
4/4 [==============================] - 1s 168ms/step - loss: 0.2331 - val_loss: 0.0658
Epoch 95/100
4/4 [==============================] - 1s 178ms/step - loss: 0.3202 - val_loss: 0.0653
Epoch 96/100
4/4 [==============================] - 1s 176ms/step - loss: 0.1792 - val_loss: 0.0677
Epoch 97/100
4/4 [==============================] - 1s 170ms/step - loss: 0.3031 - val_loss: 0.0691
Epoch 98/100
4/4 [==============================] - 1s 174ms/step - loss: 0.2034 - val_loss: 0.0690
Epoch 99/100
4/4 [==============================] - 1s 170ms/step - loss: 0.2300 - val_loss: 0.0703
Epoch 100/100
4/4 [==============================] - 1s 169ms/step - loss: 0.2708 - val_loss: 0.0712
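The validation loss falls from 0.46 to about 0.066 by epoch 55 and then levels off between roughly 0.061 and 0.073, ending at 0.0712 against a minimum of 0.0608 at epoch 77. A small helper for pulling these numbers out of a Keras progress log (copied as text) makes that easy to check. `parse_keras_log` is a hypothetical helper, and only three lines of the log above are included as sample input.

```python
import re

# Matches the tail of a Keras progress line, e.g.
# "4/4 [====] - 1s 170ms/step - loss: 0.2945 - val_loss: 0.0608"
LOSS_RE = re.compile(r"loss: ([0-9.]+) - val_loss: ([0-9.]+)")
EPOCH_RE = re.compile(r"Epoch (\d+)/\d+")


def parse_keras_log(text):
    """Return a list of (epoch, loss, val_loss) tuples from a Keras fit() log."""
    records, epoch = [], None
    for line in text.splitlines():
        m_epoch = EPOCH_RE.search(line)
        if m_epoch:
            epoch = int(m_epoch.group(1))
            continue
        m_loss = LOSS_RE.search(line)
        if m_loss and epoch is not None:
            records.append((epoch, float(m_loss.group(1)), float(m_loss.group(2))))
    return records


log_text = """Epoch 1/100
4/4 [==============================] - 2s 329ms/step - loss: 1.2647 - val_loss: 0.4618
Epoch 77/100
4/4 [==============================] - 1s 170ms/step - loss: 0.2945 - val_loss: 0.0608
Epoch 100/100
4/4 [==============================] - 1s 169ms/step - loss: 0.2708 - val_loss: 0.0712"""

records = parse_keras_log(log_text)
best_epoch, _, best_val = min(records, key=lambda r: r[2])
print(best_epoch, best_val)  # 77 0.0608
```

In plain Keras, `tf.keras.callbacks.EarlyStopping(monitor="val_loss", restore_best_weights=True)` automates this choice; whether `model.train()` forwards callbacks to Keras is not shown here.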
[14]:
m, v = model.predict(X_test.to_numpy())
[15]:
## Compare the emulator with the held-out simulations, averaged over the 12 months
plt.figure(figsize=(12, 8))

## Predicted AAOD for the first test member
plt.subplot(2, 2, 1)
qplt.pcolormesh(m[0].collapsed('time', iris.analysis.MEAN))
plt.gca().set_title('Predicted')
plt.gca().coastlines()

## Simulated AAOD for the same member
plt.subplot(2, 2, 2)
qplt.pcolormesh(Y_test[0].collapsed('time', iris.analysis.MEAN))
plt.gca().set_title('Test')
plt.gca().coastlines()

## Prediction minus simulation, averaged over all test members and months
mean_pred = m.collapsed(['sample', 'time'], iris.analysis.MEAN)
mean_test = Y_test.collapsed(['job', 'time'], iris.analysis.MEAN)
plt.subplot(2, 2, 3)
qplt.pcolormesh(mean_pred - mean_test, cmap='RdBu_r', vmin=-0.01, vmax=0.01)
plt.gca().coastlines()
plt.gca().set_title('Difference')
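[Figure: annual-mean AAOD for the first test member, Predicted (top left) and Test (top right), with the Difference (prediction minus simulation, averaged over all test members) below]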
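The "Difference" panel averages over both the test members and the months before subtracting, so it shows a mean bias map rather than the error of any single member. The same reduction on plain arrays, with a root-mean-square error added as one summary number, might look like this. The arrays are random stand-ins shaped (member, latitude, longitude, time) like `m` and `Y_test`, and the RMSE is an extra diagnostic that the notebook does not compute.

```python
import numpy as np

rng = np.random.default_rng(0)
# Random stand-ins shaped (member, latitude, longitude, time)
pred = rng.normal(0.01, 0.002, size=(5, 96, 192, 12))
truth = rng.normal(0.01, 0.002, size=(5, 96, 192, 12))

# Equivalent of collapsing 'sample'/'job' and 'time' with an unweighted MEAN, then subtracting
bias_map = pred.mean(axis=(0, 3)) - truth.mean(axis=(0, 3))

# Extra summary (not in the original): RMSE over every member, grid cell and month
rmse = np.sqrt(np.mean((pred - truth) ** 2))

print(bias_map.shape)  # (96, 192)
```

Positive and negative errors in individual members can cancel in `bias_map`, which is why a magnitude measure such as the RMSE is a useful complement.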
[16]:
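## Mean and standard deviation of the emulated AAOD over 100,000 random parameter sets, evaluated 1,000 at a time
m, sd = model.batch_stats(get_random_params(3, int(1e5)), batch_size=1000)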
100%|##########| 100000/100000 [26:03<00:00, 44.36sample/s]
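`batch_stats` returns the mean and standard deviation of the predictions over all the sample points (hence `m, sd`), working through them in batches. Holding every prediction at once would take 100,000 × 96 × 192 × 12 ≈ 2.2 × 10^10 values, about 88 GB even at 4 bytes each. One way to accumulate such statistics batch by batch is to merge per-batch means and sums of squared deviations with Chan et al.'s parallel update. The sketch below does that with a toy `predict_batch` function; it illustrates the idea only and is not ESEm's implementation, and both function names are hypothetical.

```python
import numpy as np


def batched_mean_std(predict_batch, samples, batch_size):
    """Mean and population standard deviation (ddof=0) of predict_batch over all
    samples, merging per-batch statistics with Chan et al.'s parallel update."""
    n_total, mean, m2 = 0, None, None
    for start in range(0, len(samples), batch_size):
        out = predict_batch(samples[start:start + batch_size])  # (batch, ...) array
        n_b = out.shape[0]
        mean_b = out.mean(axis=0)
        m2_b = ((out - mean_b) ** 2).sum(axis=0)  # sum of squared deviations
        if mean is None:
            n_total, mean, m2 = n_b, mean_b, m2_b
            continue
        delta = mean_b - mean
        n_new = n_total + n_b
        mean = mean + delta * n_b / n_new
        m2 = m2 + m2_b + delta ** 2 * n_total * n_b / n_new
        n_total = n_new
    return mean, np.sqrt(m2 / n_total)


rng = np.random.default_rng(1)
samples = rng.uniform(size=(10_000, 3))  # toy parameter sets, 3 parameters each
weights = np.array([1.0, -2.0, 0.5])


def predict_batch(x):
    # Toy stand-in for the emulator: returns a (batch, 4) "field" per call
    return np.outer(x @ weights, np.arange(1, 5))


mean_toy, sd_toy = batched_mean_std(predict_batch, samples, batch_size=1000)
```

Merging means and squared deviations, rather than raw sums of x and x², avoids the cancellation error that the naive formula can suffer when the standard deviation is small relative to the mean.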