Training a 2D Convolutional Neural Network for hyperspectral data classification
In this notebook, you will train and apply a two-dimensional convolutional neural network for classification of hyperspectral data from Bílá Louka, Krkonoše mountains, Czechia. Please start by reading about the dataset and area of interest here.
Prerequisites (this notebook can be run either online using Google Colab or on your local machine):
- A Google account for accessing Google Colab (link to notebook). If running the exercise in Google Colab, please copy the notebook to your own Google Drive and follow the exercise; you don't need to download the dataset.
or
- A Python environment with the necessary libraries (manual). Download this notebook (click the download button in the top right corner of the page) and follow the exercise.
- Downloaded data (module4.zip/theme4_exercise_machine_learning). The dataset consists of:
- Hyperspectral RPAS imagery of Bílá Louka, Czechia (50.728N, 15.682E) acquired in August of 2020 and resampled to 54 spectral bands with ground sampling distance of 9 cm: BL_202008_imagery.tif
- a raster with reference data: BL_202008_reference.tif
- Pretrained models and corresponding classified rasters: /sample_results/*
Tasks
- Preprocess imagery for Deep Learning
- Classify the hyperspectral image using 2D CNN
- Observe how hyperparameter values alter classification results
- Evaluate your results and compare to our pretrained classifier
- Optional: Classify an urban scene
Structure of this exercise
What you are going to encounter during this exercise:
- Load libraries, set paths
- Load and Preprocess training data
- Neural Network Definition / Training
- Apply Network
- Evaluate Result
- Sample Solution
0. Load external libraries and set paths
First, we need to import external libraries:
- torch, torch.nn, torch.optim, torchnet - PyTorch-related libraries for deep learning
- numpy - arrays to hold our data
- matplotlib.pyplot - drawing images
- sklearn.model_selection - cross-validation implemented in scikit-learn
- sklearn.metrics - computing accuracy metrics using scikit-learn
- time.perf_counter - tracking how long individual functions take to run
- os.path - path manipulation
- tqdm - showing progress bars during training
- etrainee_m4_utils.image_preprocessing - our library holding functions for image tiling, preprocessing, etc.
- etrainee_m4_utils.inference_utils - our library for correctly exporting classified images
- etrainee_m4_utils.visualisation_utils - our library for visualising the data
Two external libraries are not imported directly in this notebook, but are used by functions in image_preprocessing and inference_utils:
- gdal - Manipulates spatial data
- scipy.io - Reads .mat files
import torch
import torch.nn as nn
import torch.optim as optim
import torchnet as tnt
import numpy as np
import matplotlib.pyplot as plt
from os.path import join
from time import perf_counter
from tqdm import notebook as tqdm
from sklearn.metrics import classification_report
from sklearn.model_selection import KFold, StratifiedKFold, train_test_split
from etrainee_m4_utils import image_preprocessing
from etrainee_m4_utils import inference_utils
from etrainee_m4_utils import visualisation_utils
# GLOBAL SETTINGS
plt.rcParams['figure.figsize'] = [5, 5]
np.set_printoptions(precision=2, suppress=True) # Array print precision
# Set dataset name (used by visualisation functions) - 'bila_louka' or 'pavia_centre'
# default: 'pavia_centre'
ds_name = 'bila_louka'
# Get a list of class names
_, class_names = visualisation_utils._create_colorlist_classnames(ds_name=ds_name)
Please fill in the correct paths to your training and reference rasters (pointing the root_path variable to the project folder should be enough):
root_path = 'C:/folder/where/this/project/is/saved'
# PATHS TO TRAINING DATA
# Krkonose
imagery_path = join(root_path, 'BL_202008_imagery.tif')
reference_path = join(root_path, 'BL_202008_reference.tif')
# Pavia
#imagery_path = join(root_path, 'Pavia.mat')
#reference_path = join(root_path, 'Pavia_gt.mat')
# PATH TO SAVE MODELS
model_save_folder = join(root_path, 'models')
# PATH TO SAVE CLASSIFIED IMAGE
out_path = join(root_path, 'results/Krkonose_2D_CNN.tif')
# PATH TO THE SAMPLE RESULTS
sample_result_path = join(root_path, 'sample_results/2D_CNN_sample_result.tif')
1. Loading and preprocessing training data
1.1. Data loading into NumPy
Let's start by reading the image into a numpy array; in the background, this is done using either scipy or GDAL.
The result of our function is a dictionary named loaded_raster, which contains two numpy arrays under keys imagery and reference. As we can see, the loaded hyperspectral dataset has 1088 by 1088 pixels with 54 spectral bands. The raster containing our reference data has the same dimensions in height and width.
For loading most raster datasets, we created a read_gdal() function in the image_preprocessing module. But loading .mat files for the Pavia City Centre requires a specific function (read_pavia_centre()). Both read_pavia_centre() and read_gdal() return a dictionary containing two numpy arrays with keys imagery and reference.
If using the Pavia City Centre dataset, you may notice that the original image has a shape of (1096, 1096, 102), but to make the data easier to tile, we crop the image to (1088, 1088, 102) here.
loaded_raster = image_preprocessing.read_gdal(imagery_path, reference_path)
# loaded_raster = image_preprocessing.read_pavia_centre(imagery_path,
# reference_path, out_shape=(1088, 1088, 102))
print(f'Loaded imagery shape {loaded_raster["imagery"].shape}')
print(f'Loaded reference shape {loaded_raster["reference"].shape}')
visualisation_utils.show_img_ref(loaded_raster['imagery'][:, :, [25, 15, 5]],
loaded_raster['reference'],
ds_name=ds_name)
1.2. Image tiling
With the data loaded into a numpy array, the next step is to divide the image into individual tiles, which will be the input for our neural network.
As we want to perform convolution only in the spatial dimensions, we divide the hyperspectral image into tiles of a given shape, each keeping all spectral bands. Standard tile sizes are powers of two, for example 2^7 = 128 or 2^8 = 256.
Overlap is important because classified tiles can show inconsistencies ("jagged" edges) at their boundaries; combining the results of overlapping tiles avoids these inconsistencies.
As you can see, the tiling procedure transformed the original 1088x1088 image into 256 tiles of 128x128 pixels, each with the original number of spectral bands.
tile_shape = (128, 128)
overlap = 64
dataset_tiles = image_preprocessing.tile_training(loaded_raster,
tile_shape,
overlap)
print(f'Tiled imagery shape {dataset_tiles["imagery"].shape}')
print(f'Tiled reference shape {dataset_tiles["reference"].shape}')
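The tile count above follows from the stride, which is the tile size minus the overlap. The sketch below is a hypothetical NumPy re-implementation of overlapping tiling, for illustration only. The helper names are ours, and the sketch assumes that the tiles cover the image exactly, as they do for 1088 px with 128 px tiles and a 64 px stride. image_preprocessing.tile_training may handle edges or the array layout differently.

```python
import numpy as np

def sliding_window_tiles(image, tile_shape=(128, 128), overlap=64):
    """Cut an (H, W, bands) array into overlapping tiles (hypothetical helper).

    Assumes (H - tile_h) and (W - tile_w) are divisible by the stride,
    so that the tiles cover the whole image without a remainder.
    """
    tile_h, tile_w = tile_shape
    stride_h, stride_w = tile_h - overlap, tile_w - overlap
    height, width = image.shape[:2]
    tiles = []
    for row in range(0, height - tile_h + 1, stride_h):
        for col in range(0, width - tile_w + 1, stride_w):
            tiles.append(image[row:row + tile_h, col:col + tile_w])
    return np.stack(tiles)

def n_tiles(length, tile, overlap):
    """Number of tiles along one axis for stride = tile - overlap."""
    return (length - tile) // (tile - overlap) + 1

# 1088 px, 128 px tiles, 64 px overlap: (1088 - 128) / 64 + 1 = 16 tiles per axis
print(n_tiles(1088, 128, 64) ** 2)  # 256

small = np.arange(8 * 8 * 2).reshape(8, 8, 2)   # tiny synthetic "image" with 2 bands
print(sliding_window_tiles(small, (4, 4), 2).shape)  # (9, 4, 4, 2)
```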
1.3. Tile filtration
However, some of the created tiles do not contain any training data, so we filter them out and keep only the tiles with field-collected reference data.
This process reduces the size of our dataset from 256 to 226 tiles.
filtered_tiles = image_preprocessing.filter_useful_tiles(dataset_tiles,
nodata_vals=[0],
is_training=True)
print(f'Filtered imagery shape {filtered_tiles["imagery"].shape}')
print(f'Filtered reference shape {filtered_tiles["reference"].shape}')
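Below is a minimal sketch of the filtering criterion. It assumes that a tile is kept whenever its reference contains at least one pixel that is not a nodata value. filter_labelled_tiles is a hypothetical helper, not the filter_useful_tiles implementation.

```python
import numpy as np

def filter_labelled_tiles(imagery, reference, nodata_vals=(0,)):
    """Keep only tiles whose reference has at least one labelled pixel.

    imagery: (N, H, W, bands); reference: (N, H, W). Hypothetical helper.
    """
    labelled = ~np.isin(reference, nodata_vals)   # True where a class label exists
    keep = labelled.any(axis=(1, 2))              # one True/False flag per tile
    return imagery[keep], reference[keep]

ref = np.zeros((3, 4, 4), dtype=np.uint8)
ref[1, 0, 0] = 2                                  # only tile 1 has a reference pixel
img = np.random.rand(3, 4, 4, 5)
img_f, ref_f = filter_labelled_tiles(img, ref)
print(img_f.shape, ref_f.shape)  # (1, 4, 4, 5) (1, 4, 4)
```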
1.4. Data normalization
After filtering the tiles to only include training data, we can move on to the final part of the preprocessing: data normalization. In machine learning, it is common to bring all input features to a comparable range before classification, so that bands with large values do not dominate training.
normalized_tiles, unique, counts = image_preprocessing.normalize_tiles(filtered_tiles, nodata_vals=[0], is_training=True)
print(f'Preprocessed imagery shape {normalized_tiles["imagery"].shape}')
print(f'Preprocessed reference shape {normalized_tiles["reference"].shape}')
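This notebook does not show the exact scheme used by normalize_tiles. As an illustration only, here is one common choice: per-band standardisation (z-score). The sketch returns the statistics so that the same transform could be reused on new imagery. zscore_per_band is a hypothetical helper, and the (N, H, W, bands) layout is an assumption.

```python
import numpy as np

def zscore_per_band(tiles, eps=1e-8):
    """Standardise each spectral band (last axis) to zero mean and unit variance.

    tiles: (N, H, W, bands). Statistics are computed over all tiles and pixels.
    Returns the normalised tiles and the per-band mean and standard deviation.
    """
    mean = tiles.mean(axis=(0, 1, 2), keepdims=True)
    std = tiles.std(axis=(0, 1, 2), keepdims=True)
    return (tiles - mean) / (std + eps), mean, std

rng = np.random.default_rng(0)
tiles = rng.normal(loc=5.0, scale=3.0, size=(4, 16, 16, 3))  # synthetic data
normalized, band_mean, band_std = zscore_per_band(tiles)
print(normalized.mean(axis=(0, 1, 2)), normalized.std(axis=(0, 1, 2)))
```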
1.5. Splitting data for training/testing
Our reference dataset has to be split into three groups: training, validation and test.
In this step we divide the data into train+val and test groups; the validation subsets are created later by cross-validation (Section 2.4). The variable train_fraction sets the fraction of the reference tiles used for training and validation. By default, a third of the reference tiles is used for training and validation, while the remaining two thirds are used for testing model performance. The split is done with the scikit-learn function train_test_split, which draws a random subset of whole tiles. It is not stratified by class (that would require its stratify argument), so class proportions can differ between the two groups.
train_fraction = 1/3
X_train, X_test, y_train, y_test = train_test_split(normalized_tiles['imagery'], normalized_tiles['reference'], train_size=train_fraction)
training = {'imagery': X_train, 'reference': y_train}
test = {'imagery': X_test, 'reference': y_test}
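The split is random at the tile level rather than stratified, so it is worth checking how the class proportions compare between the two groups. The following is a small sketch that assumes integer labels with 0 as nodata. class_fractions is a hypothetical helper that could be applied, for example, to training['reference'] and test['reference'].

```python
import numpy as np

def class_fractions(reference, n_class, nodata=0):
    """Fraction of labelled pixels belonging to each class; nodata pixels are excluded."""
    labels = reference[reference != nodata].astype(np.int64)
    counts = np.bincount(labels, minlength=n_class)
    return counts / counts.sum()

example_ref = np.array([[0, 1, 1],
                        [2, 0, 0]])
print(class_fractions(example_ref, n_class=3))  # fractions for classes 0, 1, 2
```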
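1.6. Conversion to PyTorch Tensors
The resulting preprocessed tiles are wrapped in a torchnet TensorDataset, from which the PyTorch DataLoader later draws batches as tensors during training. The cell also prints the class labels and the number of pixels in each class.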
dataset = tnt.dataset.TensorDataset([training['imagery'],
training['reference']])
print(f'Class labels: \n{unique}\n')
print(f'Number of pixels in a class: \n{counts}')
2. Neural network definition
After preprocessing our data, we can move on to defining our neural network and the functions for training. You can either train your own neural network or use the one we already trained for you (sample_results/2D_CNN_sample_trained.pt). If you are using the pretrained network, please run only the following code snippet (2.1.) and skip ahead to section 3.
2.1. Network structure
Our network is named SpatialNet, and it is based on the popular U-Net / SegNet encoder-decoder architectures. Unlike U-Net, its decoder has no skip connections from the encoder. The structure of our network is defined in the SpatialNet class, which has three methods:
- __init__ - This method runs automatically when an instance of the class is created. It defines the individual layers of the network (2D convolutions, transposed convolutions, max pooling and a dropout layer).
- init_weights - Randomly initialises the convolution weights from a normal distribution (Kaiming/He initialisation).
- forward - Defines how data flow through the network during a forward pass (the network structure). PyTorch's automatic differentiation (autograd) derives the backward pass from this structure, so no separate backward method is needed.
class SpatialNet(nn.Module):
    """Encoder-decoder network (U-Net / SegNet style) for semantic segmentation."""

    def __init__(self, args):
        """
        Initialize the model.
        n_channel, int, number of input channels (spectral bands)
        size_e, int list, number of feature maps of the encoder convolutions
        size_d, int list, number of feature maps of the decoder convolutions
        n_class, int, the number of classes
        """
        super(SpatialNet, self).__init__()  # necessary for all classes extending nn.Module
        self.maxpool = nn.MaxPool2d(2, 2, return_indices=False)
        self.dropout = nn.Dropout2d(p=0.5, inplace=True)
        self.n_channels = args['n_channel']
        self.size_e = args['size_e']
        self.size_d = args['size_d']
        self.n_class = args['n_class']

        def conv_layer_2d(in_ch, out_ch, k_size=3, conv_bias=False):
            """Create default conv layer: convolution + batch normalisation + ReLU."""
            # padding=1 keeps the tile size unchanged for the default 3x3 kernel
            return nn.Sequential(nn.Conv2d(in_ch, out_ch, kernel_size=k_size,
                                           padding=1, padding_mode='reflect',
                                           bias=conv_bias),
                                 nn.BatchNorm2d(out_ch), nn.ReLU())

        # Encoder layer definitions
        self.c1 = conv_layer_2d(self.n_channels, self.size_e[0])
        self.c2 = conv_layer_2d(self.size_e[0], self.size_e[1])
        self.c3 = conv_layer_2d(self.size_e[1], self.size_e[2])
        self.c4 = conv_layer_2d(self.size_e[2], self.size_e[3])
        self.c5 = conv_layer_2d(self.size_e[3], self.size_e[4])
        self.c6 = conv_layer_2d(self.size_e[4], self.size_e[5])
        # Decoder layer definitions (transposed convolutions double the tile size)
        self.trans1 = nn.ConvTranspose2d(self.size_e[5], self.size_d[0],
                                         kernel_size=2, stride=2)
        self.c7 = conv_layer_2d(self.size_d[0], self.size_d[1])
        self.c8 = conv_layer_2d(self.size_d[1], self.size_d[2])
        self.trans2 = nn.ConvTranspose2d(self.size_d[2], self.size_d[3],
                                         kernel_size=2, stride=2)
        self.c9 = conv_layer_2d(self.size_d[3], self.size_d[4])
        self.c10 = conv_layer_2d(self.size_d[4], self.size_d[5])
        # Final classifying layer (1x1 convolution producing one score per class)
        self.classifier = nn.Conv2d(self.size_d[5], self.n_class,
                                    1, padding=0)
        # Weight initialization (element 0 of each Sequential block is the Conv2d)
        for conv_block in [self.c1, self.c2, self.c3, self.c4, self.c5,
                           self.c6, self.c7, self.c8, self.c9, self.c10]:
            conv_block[0].apply(self.init_weights)
        self.classifier.apply(self.init_weights)
        # Put the model on GPU memory
        if torch.cuda.is_available():
            self.cuda()
            torch.backends.cudnn.enabled = True
            torch.backends.cudnn.benchmark = True

    def init_weights(self, layer):
        """Initialise layer weights (Kaiming/He normal initialisation)."""
        nn.init.kaiming_normal_(
            layer.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, input_data):
        """Define model structure."""
        # Encoder
        # level 1
        x1 = self.c2(self.c1(input_data))
        x2 = self.maxpool(x1.clone())
        # level 2
        x3 = self.c4(self.c3(x2.clone()))
        x4 = self.maxpool(x3.clone())
        # level 3
        x5 = self.c6(self.c5(x4.clone()))
        # Decoder
        # level 2
        y4 = self.trans1(x5.clone())
        y3 = self.c8(self.c7(y4.clone()))
        # level 1
        y2 = self.trans2(y3.clone())
        y1 = self.c10(self.c9(y2.clone()))
        # Output (the clone protects y1 from the in-place dropout)
        out = self.classifier(self.dropout(y1.clone()))
        return out
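The following pure-Python sketch shows why tile sides should be divisible by four. It traces the spatial size of a tile through the layers of SpatialNet using the standard output-size formulas for convolution, 2x2 max pooling and transposed convolution (kernel 2, stride 2). The helper functions are ours, written for this illustration.

```python
def conv_out(size, kernel=3, padding=1, stride=1):
    """Output size of a Conv2d / MaxPool2d along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1

def trans_conv_out(size, kernel=2, stride=2, padding=0):
    """Output size of a ConvTranspose2d along one spatial axis (no output_padding)."""
    return (size - 1) * stride - 2 * padding + kernel

def trace_spatialnet(size=128):
    """Follow the tile side length through the SpatialNet layers."""
    sizes = {'input': size}
    size = conv_out(conv_out(size))                       # c1, c2 (3x3, padding 1)
    sizes['level1'] = size
    size = conv_out(size, kernel=2, padding=0, stride=2)  # 2x2 max pooling
    size = conv_out(conv_out(size))                       # c3, c4
    sizes['level2'] = size
    size = conv_out(size, kernel=2, padding=0, stride=2)  # 2x2 max pooling
    size = conv_out(conv_out(size))                       # c5, c6 (bottleneck)
    sizes['bottleneck'] = size
    size = conv_out(conv_out(trans_conv_out(size)))       # trans1, c7, c8
    sizes['decoder_level2'] = size
    size = conv_out(conv_out(trans_conv_out(size)))       # trans2, c9, c10
    sizes['decoder_level1'] = size
    size = conv_out(size, kernel=1, padding=0)            # 1x1 classifier
    sizes['output'] = size
    return sizes

print(trace_spatialnet(128))
```

With 128x128 tiles the output has the same size as the input. A side that is not divisible by four does not come back unchanged: trace_spatialnet(130)['output'] is 128, because each max pooling rounds down.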
2.2. Functions for network training
Training the network is handled by four functions:
- augment - Augments the training data by adding random noise.
- train - Trains the network for one epoch.
- eval - Evaluates the results on a validation set.
- train_full - Performs the full training loop.
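augment takes in a batch of training tiles and the corresponding reference labels. It adds a small random value (drawn from a normal distribution with a standard deviation of 0.005 and clipped to ±0.02) to every pixel in every band. It also rotates both the tiles and the labels by the same random multiple of 90 degrees, which slightly modifies the training data. Change tile_number to see the augmentation effect for different tiles.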
def augment(obs, g_t):
    """Data augmentation: clipped random noise and a random rotation by a multiple of 90 degrees."""
    sigma, clip = 0.005, 0.02
    # Random noise: one value per band and pixel, shared by all tiles of the batch
    rand = torch.clamp(torch.mul(sigma, torch.randn([1, *obs.shape[1:]])), -clip, clip)
    obs = torch.add(obs, rand)
    # Random rotation by 0, 90, 180 or 270 degrees
    n_turn = np.random.randint(4)  # number of 90 degree turns, random int between 0 and 3
    obs = torch.rot90(obs, n_turn, dims=(2, 3))  # imagery: (batch, bands, H, W)
    g_t = torch.rot90(g_t, n_turn, dims=(1, 2))  # labels: (batch, H, W)
    return obs, g_t
tile_number = 60
visualisation_utils.show_augment_spatial(training, tile_number,
augment, ds_name=ds_name)
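For readers without a GPU or PyTorch at hand, the same two augmentation steps can be sketched in NumPy. augment_np is a hypothetical stand-in for augment and assumes the (batch, bands, H, W) layout used above. The key point is that imagery and labels are rotated by the same number of quarter turns, so the labels stay aligned with their pixels.

```python
import numpy as np

def augment_np(obs, g_t, sigma=0.005, clip=0.02, rng=None):
    """NumPy sketch of the augmentation: clipped Gaussian noise + random 90 degree turns.

    obs: (batch, bands, H, W) imagery; g_t: (batch, H, W) labels.
    """
    rng = np.random.default_rng() if rng is None else rng
    # One noise value per band and pixel, shared across the batch
    noise = np.clip(sigma * rng.standard_normal((1,) + obs.shape[1:]), -clip, clip)
    obs = obs + noise
    # Same number of quarter turns for imagery and labels keeps them aligned
    n_turn = rng.integers(4)
    obs = np.rot90(obs, n_turn, axes=(2, 3))
    g_t = np.rot90(g_t, n_turn, axes=(1, 2))
    return obs, g_t

# Synthetic check: band 0 stores the label value, so alignment can be verified
obs = np.zeros((2, 3, 4, 4))
gt = np.arange(32).reshape(2, 4, 4)
obs[:, 0] = gt
aug_obs, aug_gt = augment_np(obs, gt, rng=np.random.default_rng(1))
print(np.array_equal(np.round(aug_obs[:, 0]), aug_gt))
```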
train trains the network for one epoch. This function contains a for loop, which loads the training data in individual batches. Each batch of training data is augmented and passed through the network, after which we compute the loss function (weighted cross-entropy). The last step for each batch is an optimiser step, which updates the network's weights.
eval evaluates the results on a validation set; this should be done periodically during training to check for overfitting.
train_full performs the full training loop. It first initialises the model and optimiser. Then the train function is called in a loop, with periodic evaluation on the validation set.
def train(model, optimizer, args):
    """Train for one epoch."""
    model.train()  # switch the model to training mode
    # the loader takes care of batching; the sampler selects the training tiles of this fold
    loader = torch.utils.data.DataLoader(dataset, batch_size=args['batch_size'],
                                         sampler=args['train_subsampler'])
    loader = tqdm.tqdm(loader, ncols=500)
    loss_meter = tnt.meter.AverageValueMeter()  # keeps track of the mean loss
    for index, (tiles, gt) in enumerate(loader):
        optimizer.zero_grad()  # reset the gradients
        tiles, gt = augment(tiles, gt)
        if torch.cuda.is_available():
            pred = model(tiles.cuda())  # compute the prediction
        else:
            pred = model(tiles)
        loss = nn.functional.cross_entropy(pred.cpu(), gt, weight=args['class_weights'])
        loss.backward()  # compute gradients
        # clip every gradient element to [-1, 1] (clipping by value, not by norm)
        nn.utils.clip_grad_value_(model.parameters(), 1)
        optimizer.step()  # one Adam step
        loss_meter.add(loss.item())
    return loss_meter.value()[0]


def eval(model, sampler):
    """Evaluate on the validation set (note: this shadows Python's built-in eval)."""
    model.eval()  # switch to evaluation mode (disables dropout, freezes batch norm statistics)
    loader = torch.utils.data.DataLoader(dataset, batch_size=1, sampler=sampler)
    loader = tqdm.tqdm(loader, ncols=500)
    loss_meter = tnt.meter.AverageValueMeter()
    with torch.no_grad():  # no gradients are needed for evaluation
        for index, (tiles, gt) in enumerate(loader):
            if torch.cuda.is_available():
                pred = model(tiles.cuda())  # compute the prediction
            else:
                pred = model(tiles)
            loss = nn.functional.cross_entropy(pred.cpu(), gt)
            loss_meter.add(loss.item())
    return loss_meter.value()[0]


def train_full(args):
    """The full training loop."""
    # initialize the model
    model = SpatialNet(args)
    print(f'Total number of parameters: {sum([p.numel() for p in model.parameters()])}')
    # define the Adam optimizer and the learning rate scheduler
    optimizer = optim.Adam(model.parameters(), lr=args['lr'])
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args['scheduler_milestones'],
                                               gamma=args['scheduler_gamma'])
    train_loss = np.empty(args['n_epoch'])
    test_epochs = []
    test_loss = []
    for i_epoch in range(args['n_epoch']):
        # train one epoch
        print(f'Epoch #{str(i_epoch+1)}')
        train_loss[i_epoch] = train(model, optimizer, args)
        scheduler.step()
        # Periodic testing on the validation set
        if (i_epoch == args['n_epoch'] - 1) or ((i_epoch + 1) % args['n_epoch_test'] == 0):
            print('Evaluation')
            loss_test = eval(model, args['test_subsampler'])
            test_epochs.append(i_epoch + 1)
            test_loss.append(loss_test)
    # Plot the training and validation loss curves
    plt.figure(figsize=(10, 10))
    plt.subplot(1, 1, 1, ylim=(0, 5), xlabel='Epoch #', ylabel='Loss')
    plt.plot([i+1 for i in range(args['n_epoch'])], train_loss, label='Training loss')
    plt.plot(test_epochs, test_loss, label='Validation loss')
    plt.legend()
    plt.show()
    print(train_loss)
    print(test_loss)
    args['loss_test'] = test_loss[-1]
    return model
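train passes args['class_weights'] to the cross-entropy loss. The sketch below re-implements in NumPy the weighted "mean" reduction that PyTorch documents for cross_entropy with a weight argument. Each per-pixel loss is multiplied by the weight of its true class, and the sum is divided by the sum of those weights. weighted_cross_entropy is a hypothetical helper. For (batch, classes, H, W) predictions, the class axis would first be moved last and the pixels flattened.

```python
import numpy as np

def weighted_cross_entropy(logits, target, weight):
    """Weighted mean cross-entropy.

    logits: (N, C) raw scores; target: (N,) integer labels; weight: (C,) class weights.
    """
    shifted = logits - logits.max(axis=1, keepdims=True)       # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    nll = -log_probs[np.arange(len(target)), target]            # per-pixel loss
    w = weight[target]                                          # weight of each pixel's class
    return (w * nll).sum() / w.sum()

# Uniform scores over 3 classes give a loss of log(3), whatever the weights
print(weighted_cross_entropy(np.zeros((3, 3)), np.array([1, 2, 1]), np.array([0.0, 1.0, 1.0])))
```

Pixels labelled 0 get weight 0, so they contribute to neither the numerator nor the denominator. Tiles can therefore contain unlabelled pixels without affecting the loss.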
2.3. Hyperparameter definition
Training networks requires first setting several hyperparameters, please feel free to play around with them and try different values for the number of training epochs, learning rate or batch size.
- n_channel - number of input channels, i.e. spectral bands (54 for the Bílá Louka imagery)
- n_class - number of classification classes
- size_e - number of filters in each convolutional layer of the encoder
- size_d - number of filters in each convolutional layer of the decoder
- crossval_nfolds - number of folds for cross-validation (one model is trained per fold)
- n_epoch_test - after how many training epochs to evaluate on the validation set
- scheduler_milestones - epochs after which the learning rate is reduced
- scheduler_gamma - factor by which the learning rate is multiplied at each milestone
- class_weights - weights of individual classes in the loss function, used to offset the imbalanced class distribution (class 0, i.e. pixels without reference, gets weight 0)
- n_epoch - how many epochs are performed during training
- lr - learning rate, i.e. the size of the parameter updates in each optimiser step
- batch_size - how many tiles are included in each gradient descent step
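The class weights in this notebook were set by hand. One common heuristic for choosing them, shown here purely as an illustration, makes each weight inversely proportional to the class's pixel count and gives the nodata class a weight of 0. The sketch assumes a counts array indexed by class label, such as the counts printed in Section 1.6 if that array includes label 0. inverse_frequency_weights is a hypothetical helper.

```python
import numpy as np

def inverse_frequency_weights(counts, ignore_index=0):
    """Class weights proportional to 1 / pixel count, normalised to sum to 1.

    counts: pixel count per class label (index = label). The class at
    ignore_index (nodata) and classes without pixels get weight 0.
    """
    counts = np.asarray(counts, dtype=float)
    weights = np.zeros_like(counts)
    valid = (counts > 0) & (np.arange(len(counts)) != ignore_index)
    weights[valid] = 1.0 / counts[valid]
    return weights / weights.sum()

print(inverse_frequency_weights([500, 100, 300]))  # rarer classes get larger weights
```

The result could be wrapped with torch.tensor(..., dtype=torch.float32) and stored under args['class_weights'].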
args = { # Dict to store all model parameters
'n_channel': 54,
'n_class': len(unique),
'size_e': [64,64,128,128,256,256],
'size_d': [256,128,128,128,64,64],
'crossval_nfolds': 3,
'n_epoch_test': 2,
'scheduler_milestones': [50,75,90],
'scheduler_gamma': 0.3,
#'class_weights': torch.tensor([0.0, 0.2, 0.34, 0.033, 0.16, 0.14, 0.03, 0.014, 0.023, 0.06]),
'class_weights': torch.tensor([.0, .15, .05, .3, .05, .05, .1, .3]),
'n_epoch': 50,
'lr': 1e-6,
'batch_size': 4,
}
print(f'''Number of models to be trained:
{args['crossval_nfolds']}
Number of spectral channels:
{args['n_channel']}
Initial learning rate:
{args['lr']}
Batch size:
{args['batch_size']}
Number of training epochs:
{args['n_epoch']}''')
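MultiStepLR multiplies the learning rate by scheduler_gamma each time a milestone epoch is passed. The following sketch reproduces that rule in plain Python (multistep_lr is our helper, written for illustration). It also shows a consequence of the default settings. With n_epoch = 50, the last epoch has index 49, so the milestones [50, 75, 90] are never reached and the learning rate stays at 1e-6. With the 100 epochs of the sample solution, all three reductions apply.

```python
from bisect import bisect_right

def multistep_lr(epoch, base_lr, milestones, gamma):
    """Learning rate used during a 0-based epoch under a MultiStepLR-style schedule.

    The rate is multiplied by gamma once for every milestone <= epoch, which
    corresponds to calling scheduler.step() once at the end of each epoch.
    """
    return base_lr * gamma ** bisect_right(sorted(milestones), epoch)

milestones, gamma = [50, 75, 90], 0.3
for n_epoch in (50, 100):
    print(n_epoch, 'epochs -> learning rate in the last epoch:',
          multistep_lr(n_epoch - 1, 1e-6, milestones, gamma))
```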
2.4. Network training
Run the training procedure with the hyperparameters set above, then inspect the training/validation loss curves produced during training to adjust the hyperparameter values.
## Training a 2D network
import os
os.makedirs(model_save_folder, exist_ok=True)  # torch.save fails if the folder does not exist
kfold = KFold(n_splits=args['crossval_nfolds'], shuffle=True)
trained_models = []
for fold, (train_ids, test_ids) in enumerate(kfold.split(dataset)):
    print(f'Training starts for model number {str(fold+1)}')
    a = perf_counter()
    # Random samplers restricted to the training / validation tiles of this fold
    args['train_subsampler'] = torch.utils.data.SubsetRandomSampler(train_ids)
    args['test_subsampler'] = torch.utils.data.SubsetRandomSampler(test_ids)
    trained_models.append((train_full(args), args['loss_test']))
    state_dict_path = join(model_save_folder, f'2D_CNN_fold_{str(fold)}.pt')
    torch.save(trained_models[fold][0].state_dict(), state_dict_path)
    print(f'Model saved to: {state_dict_path}')
    print(f'Training finished in {str(perf_counter()-a)}s')
    print('\n\n')
print(f'Resulting loss for individual folds: \n{[i for _, i in trained_models]}')
print(f'Mean loss across all folds: \n{np.mean([i for _, i in trained_models])}')
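KFold with n_splits = 3 partitions the training tiles into three folds. Each model is validated on one fold and trained on the other two. Below is a NumPy sketch of that idea; kfold_indices is a hypothetical helper, and scikit-learn's KFold remains the tool used above. The sketch shows that every tile ends up in exactly one validation fold.

```python
import numpy as np

def kfold_indices(n_samples, n_splits=3, seed=None):
    """Yield (train_ids, val_ids) pairs; every sample is used for validation exactly once."""
    rng = np.random.default_rng(seed)
    shuffled = rng.permutation(n_samples)
    folds = np.array_split(shuffled, n_splits)
    for k in range(n_splits):
        val_ids = folds[k]
        train_ids = np.concatenate([folds[j] for j in range(n_splits) if j != k])
        yield train_ids, val_ids

for train_ids, val_ids in kfold_indices(10, n_splits=3, seed=0):
    print(len(train_ids), 'training tiles,', len(val_ids), 'validation tiles')
```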
3. Applying the network
3.1. Loading a trained model
# Parameters for model definition
args = {
'n_class': 8,
'n_channel': 54,
'size_e': [64,64,128,128,256,256],
'size_d': [256,128,128,128,64,64]
}
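# Select which model to load by using a different filename index
# (to use the pretrained model instead, point state_dict_path to its file in sample_results/;
# size_e and size_d above must then match the values that model was trained with)
state_dict_path = join(model_save_folder, '2D_CNN_fold_0.pt')
# Load the model; map_location='cpu' also allows loading GPU-trained weights on a CPU-only machine
model = SpatialNet(args)
model.load_state_dict(torch.load(state_dict_path, map_location='cpu'))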
model.eval()
3.2. Loading and preprocessing the data
Load a raster to classify. This can be the one that we used for training, but it can also be a different raster with the same number of bands.
By default, the training raster (imagery_path) is used.
tile_shape = (128, 128)
overlap = 64
start = perf_counter()
# Load raster
raster_orig = image_preprocessing.read_gdal_with_geoinfo(imagery_path)
#raster_orig = image_preprocessing.read_pavia_centre(imagery_path, out_shape=(1088, 1088, 102))
# Split raster into tiles
dataset_full_tiles = image_preprocessing.run_tiling_dims(raster_orig['imagery'],
out_shape=tile_shape, out_overlap=overlap)
# Normalize tiles
dataset_full = image_preprocessing.normalize_tiles(dataset_full_tiles,
nodata_vals=[0])
# Convert to Pytorch TensorDataset
dataset = tnt.dataset.TensorDataset(dataset_full['imagery'])
print('')
print(f'Loading and preprocessing the imagery took {visualisation_utils.sec_to_hms(perf_counter() - start)}.')
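3.3. Applying the CNN
The following snippet applies the CNN to all tiles and combines the predictions of the overlapping tiles into a single classified raster with the dimensions of the original image.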
timer_start = perf_counter()
arr_class = inference_utils.combine_tiles_2d(model, dataset,
tile_shape, overlap,
dataset_full_tiles['dimensions'])
print(f'Model applied to all tiles in {visualisation_utils.sec_to_hms(perf_counter() - timer_start)}.')
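This notebook does not show how combine_tiles_2d merges the overlapping predictions internally. A common approach, sketched here under that assumption, is to add up the class scores of all tiles covering a pixel, average them, and take the class with the highest mean score. combine_tile_scores is a hypothetical helper.

```python
import numpy as np

def combine_tile_scores(tile_scores, positions, out_shape):
    """Average the class scores of overlapping tiles and take the arg-max per pixel.

    tile_scores: (N, C, h, w) scores per tile; positions: (row, col) of each
    tile's top-left corner; out_shape: (H, W) of the full image.
    """
    n_class, tile_h, tile_w = tile_scores.shape[1:]
    total = np.zeros((n_class,) + tuple(out_shape))
    hits = np.zeros(out_shape)                       # how many tiles cover each pixel
    for scores, (row, col) in zip(tile_scores, positions):
        total[:, row:row + tile_h, col:col + tile_w] += scores
        hits[row:row + tile_h, col:col + tile_w] += 1
    return (total / np.maximum(hits, 1)).argmax(axis=0)

# Two 2x2 tiles on a 2x3 image, overlapping in the middle column
tile_a = np.zeros((2, 2, 2)); tile_a[0] = 1.0       # strongly favours class 0
tile_b = np.zeros((2, 2, 2)); tile_b[1] = 0.4       # weakly favours class 1
print(combine_tile_scores(np.stack([tile_a, tile_b]), [(0, 0), (0, 1)], (2, 3)))
```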
You can also visualise the result:
visualisation_utils.show_classified(raster_orig['imagery'][:, :, [25, 15, 5]],
loaded_raster['reference'], arr_class,
ds_name=ds_name)
3.4. Export Resulting Raster
Export the resulting classified raster into out_path for distribution or further analysis (e.g. validation in GIS).
print(f'The classified result gets saved to {out_path}')
inference_utils.export_result(out_path, arr_class, raster_orig['geoinfo'])
4. Evaluate Classification Result
4.1. Apply classifier to the test dataset
The classifier is applied to the test tiles, and its predictions are compared to the test reference data. The test data are stored in the test dictionary under the keys imagery and reference.
predicted_arr = inference_utils.classify_tiles(model, test['imagery'], mode='2D')
# Reshaping both rasters to 1D
test_flat = test['reference'].reshape(-1)
pred_flat = predicted_arr.reshape(-1)
# Filtering to only include pixels with reference data
pred_filtered = pred_flat[test_flat > 0]
test_filtered = test_flat[test_flat > 0]
4.2. Compute Accuracy Metrics
print(classification_report(test_filtered, pred_filtered, zero_division=0,
target_names=class_names))
4.3. Show Confusion Matrix
visualisation_utils.show_confusion_matrix(test_filtered, pred_filtered,
ds_name=ds_name)
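The classification report and the confusion matrix are both derived from counts of (reference, predicted) label pairs. The NumPy sketch below builds the matrix with reference classes as rows and predictions as columns. From it, the sketch derives overall accuracy, recall (producer's accuracy) and precision (user's accuracy). confusion_and_scores is a hypothetical helper that assumes integer labels 0 to n_class - 1. It could be called, for example, as confusion_and_scores(test_filtered, pred_filtered, 8).

```python
import numpy as np

def confusion_and_scores(y_true, y_pred, n_class):
    """Confusion matrix (rows = reference, columns = prediction) and summary scores."""
    y_true = np.asarray(y_true, dtype=np.int64)
    y_pred = np.asarray(y_pred, dtype=np.int64)
    # Encode each (reference, prediction) pair as one integer and count occurrences
    cm = np.bincount(n_class * y_true + y_pred,
                     minlength=n_class * n_class).reshape(n_class, n_class)
    overall_accuracy = np.trace(cm) / cm.sum()
    with np.errstate(invalid='ignore', divide='ignore'):
        recall = np.diag(cm) / cm.sum(axis=1)      # producer's accuracy, NaN for absent classes
        precision = np.diag(cm) / cm.sum(axis=0)   # user's accuracy, NaN if never predicted
    return cm, overall_accuracy, recall, precision

cm, oa, recall, precision = confusion_and_scores([1, 1, 2, 2], [1, 2, 2, 2], n_class=3)
print(cm)
print(oa)
```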
5. Sample Solution
We have generated this result using the following training parameters. Please note that using the same training parameters will not yield exactly the same result, as the network is randomly initialised:
- Number of epochs: 100
- Batch Size: 4
- Learning Rate: 1e-6
- Learning Rate reduced after Epochs #: [50,75,90]
- Learning Rate reduced by (scheduler_gamma): 0.3
- Class Weights: [0, 0.197, 0.343, 0.032, 0.166, 0.141, 0.029, 0.013, 0.022, 0.057]
- Number of Convolutional filters in the Encoder (size_e): [256,256,512,512,1024,1024]
- Number of Convolutional filters in the Decoder (size_d): [1024,512,512,512,256,256]
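# Read test reference
# NOTE: test_path is not defined earlier in this notebook; set it to the path of your test reference raster first
test_arr = image_preprocessing.read_gdal(imagery_path, test_path)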
test_flat = test_arr['reference'].reshape(
test_arr['reference'].shape[0]*test_arr['reference'].shape[1])
test_filtered = test_flat[test_flat > 0]
# Read sample result
sample_arr = image_preprocessing.read_gdal(imagery_path, sample_result_path)
sample_flat = sample_arr['reference'].reshape(sample_arr['reference'].shape[0] * sample_arr['reference'].shape[1])
sample_filtered = sample_flat[test_flat > 0]
# Visualise the sample result
visualisation_utils.show_classified(loaded_raster['imagery'][:, :, [25, 15, 5]],
loaded_raster['reference'],
sample_arr['reference'],
ds_name=ds_name)
# Print a classification report for the sample result
print(classification_report(test_filtered, sample_filtered,
target_names=class_names[1:]))
# Show a Confusion matrix for the sample result
visualisation_utils.show_confusion_matrix(test_filtered, sample_filtered,
ds_name=ds_name)
Optional: Classify an urban scene
Try using this notebook to classify an urban scene (Pavia City Centre). Reflect on how the different landscape structure and clearer class definitions influence the classification result.
Pavia City Centre is a common benchmark for hyperspectral data classification and can be obtained from http://www.ehu.eus/ccwintco/index.php/Hyperspectral_Remote_Sensing_Scenes#Pavia_Centre_and_University. You will need to change the paths to the input data (see the commented-out Pavia paths), set ds_name to 'pavia_centre', adjust the hyperparameters (e.g. n_channel = 102), and use the read_pavia_centre function to load the MATLAB matrices into NumPy.