Explainable Deep Learning for Alzheimer's Disease Diagnosis and Mild-Cognitive Impairment Prognosis

Poster No:

1391 

Submission Type:

Abstract Submission 

Authors:

Sophie Martin1, Francesca Biondo2, Florence Townend3, An Zhao1, Frederik Barkhof4, James Cole1

Institutions:

1University College London, London, London, 2UCL, London, London, 3University College London, London, Greater London, 4Amsterdam University Medical Centre, Amsterdam, Noord-Holland

First Author:

Sophie Martin, MRes  
University College London
London, London

Co-Author(s):

Francesca Biondo, PhD  
UCL
London, London
Florence Townend  
University College London
London, Greater London
An Zhao  
University College London
London, London
Frederik Barkhof, MD, Ph. D  
Amsterdam University Medical Centre
Amsterdam, Noord-Holland
James Cole, PhD  
University College London
London, London

Introduction:

Deep learning provides innovative solutions for predicting Alzheimer's disease (AD) before significant cognitive decline, a key clinical challenge. In this work, we address two limitations of existing studies. First, many studies rely only on large, curated, homogeneous research datasets to train and evaluate models, which biases results and limits generalisability (Martin et al., 2023). Second, models often lack transparency, hindering integration into existing diagnostic pipelines. Here, we leveraged a non-harmonised, multi-site dataset and explainable AI methods to overcome these issues.

Methods:

3,060 3D T1-weighted MRI scans from the National Alzheimer's Coordinating Center were used to train and evaluate the models on held-out test sets. Minimal pre-processing was applied: each scan was affinely registered to the MNI152 template using EasyReg (Hoffmann et al., 2022; Iglesias et al., 2021). We used MRIQC image quality metrics and visual assessment to remove poor-quality scans (Esteban et al., 2017). The dataset is described in Figure 1a.

We implemented two state-of-the-art neural networks, a ResNet (He et al., 2015; Wightman, Touvron, et al., 2021) and a Vision Transformer (ViT) (Dosovitskiy et al., 2020; Steiner et al., 2021), to classify AD dementia patients versus healthy controls (diagnosis). To overcome the challenges of training 3D ViTs on small datasets, we investigated an ensemble approach: fine-tuning 2D models pretrained on ImageNet (Wightman, Ha, et al., 2021) using single slices along three planes (axial, sagittal and coronal). A triplanar ensemble model was produced by training a multi-layer perceptron on the outputs from each single-plane model. Additionally, we explored combining imaging-model probabilities with non-imaging features in a Random Forest (RF) to assess whether this improved performance. These models were benchmarked against a 3D ResNet, as well as an RF with only non-imaging features.
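The triplanar stacking step can be sketched as follows. This is a minimal illustration, not the exact configuration used: the per-plane probabilities are synthetic stand-ins for the outputs of the axial, sagittal and coronal models, and scikit-learn's `MLPClassifier` is one plausible choice of meta-learner.

```python
import numpy as np
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
n = 600
y = rng.integers(0, 2, n)  # 0 = healthy control, 1 = AD dementia

# Hypothetical class probabilities from the three single-plane models
# (axial, sagittal, coronal), made to correlate with the true label.
planes = np.clip(y[:, None] * 0.5 + 0.25 + rng.normal(0, 0.15, (n, 3)), 0, 1)

X_tr, X_te, y_tr, y_te = train_test_split(planes, y, random_state=0)

# Small MLP combines the three per-plane probabilities into one prediction
mlp = MLPClassifier(hidden_layer_sizes=(16,), max_iter=2000, random_state=0)
mlp.fit(X_tr, y_tr)
acc = mlp.score(X_te, y_te)
```

In the actual pipeline, `planes` would hold the held-out predictions of the three fine-tuned 2D networks rather than synthetic values.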

We also evaluated each model on predicting whether participants diagnosed with mild cognitive impairment (MCI) at baseline would progress to AD within 3 years (prognosis), to examine whether the learned features transfer to an unseen, more challenging task. To explain the models, we applied several post-hoc explanation methods to visualise the most important features.
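One common family of post-hoc explanation methods is occlusion sensitivity (the abstract does not specify which methods were used, so this is only an illustrative example). The sketch below slides a zeroed patch across a 2D image and records how much each patch's removal lowers a toy model's score; the `toy` scoring function is purely hypothetical.

```python
import numpy as np

def occlusion_map(model_fn, image, patch=4):
    """Zero out patches of the image and record the drop in model score."""
    base = model_fn(image)
    h, w = image.shape
    heat = np.zeros_like(image, dtype=float)
    for i in range(0, h, patch):
        for j in range(0, w, patch):
            occluded = image.copy()
            occluded[i:i + patch, j:j + patch] = 0.0
            # Large drop => this region was important to the prediction
            heat[i:i + patch, j:j + patch] = base - model_fn(occluded)
    return heat

# Toy "model": score is the mean intensity of the top-left quadrant,
# so only occlusions there should register in the heatmap.
toy = lambda img: img[:8, :8].mean()
img = np.ones((16, 16))
heat = occlusion_map(toy, img)
```

For a real 3D network, `model_fn` would return the class probability for a full MRI volume, and the resulting heatmap could be overlaid on the scan as in Figure 2.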

Results:

Figure 1c shows that triplanar models provide a small improvement in diagnostic balanced accuracy (BACC): +0.8% and +0.1% for the ViT and ResNet, respectively. However, the 3D ResNet achieved a BACC of 83.4%, highlighting the importance of contextual information across the brain volume.

Combining the output of the imaging models with non-imaging data increased the diagnostic BACC to 95.3%, whilst balancing precision and recall better than using only non-imaging features. The advantage of including imaging data is most apparent in the absence of a bedside screening tool for Alzheimer's dementia (MMSE score). Moreover, for MCI prognosis, the multi-modal RF model with ResNet outputs showed a clear improvement in performance compared to non-imaging features alone, increasing the BACC from 65.5% to 73.1%.
Supporting Image: Figure1.png
Supporting Image: FigureTwo.png
 

Conclusions:

Deep learning models trained on more representative and heterogeneous datasets can still produce high diagnostic performance. A model trained only on healthy controls and AD patients was able to predict MCI conversion with a BACC of 75.3%, which is comparable with previous studies trained directly on this task. Vision Transformers achieved comparable diagnostic performance to ResNets with fewer parameters (5M vs 11M, respectively) but did not generalise as well to the prognosis task. Explanation methods can be used to identify salient brain regions and highlight the utility of neuroanatomical information over non-imaging features (Figure 2). Heatmaps were most consistent across methods for the ResNet model and highlighted relevant features such as ventricular atrophy.

Disorders of the Nervous System:

Neurodegenerative/ Late Life (eg. Parkinson’s, Alzheimer’s) 2

Modeling and Analysis Methods:

Classification and Predictive Modeling 1

Keywords:

Degenerative Disease
Machine Learning
MRI
Other - Explainable AI

1|2Indicates the priority used for review

References:

Dosovitskiy, A. (2020). An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale, arXiv.
Esteban, O. (2017). MRIQC: Advancing the automatic prediction of image quality in MRI from unseen sites, PLoS One, vol. 12, no. 9, e0184661.
He, K. (2015). Deep Residual Learning for Image Recognition, arXiv.
Hoffmann, M. (2022). SynthMorph: Learning Contrast-Invariant Registration Without Acquired Images, IEEE Transactions on Medical Imaging, vol. 41, no. 3, pp. 543-558.
Iglesias, J. E. (2021). Joint super-resolution and synthesis of 1 mm isotropic MP-RAGE volumes from clinical MRI exams with scans of different orientation, resolution and contrast. NeuroImage, vol. 237, pp. 118206-118206.
Martin, S. A. (2023). Interpretable machine learning for dementia: A systematic review, Alzheimer's & Dementia, vol. 19, no. 5, pp. 2135-2149.
Steiner, A. (2021). How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers, arXiv.
Wightman, R. (2021). rwightman/pytorch-image-models: Minor release, Zenodo.
Wightman, R. (2021). ResNet strikes back: An improved training procedure in timm, arXiv.