Introduction
In this work, we utilize the Temporal Fusion Transformer (TFT) to achieve state-of-the-art performance in forecasting US county-level infections while enabling new forms of interpretability through analyzing complex spatiotemporal patterns. The proposed model (1) outperforms other popular deep learning models on all evaluation metrics for multivariate multi-horizon forecasting, (2) exhibits robust performance in predicting non-stationary infection trends across different waves of the COVID-19 pandemic, (3) interprets temporal patterns, such as weekly and holiday seasonality in reported cases, through multi-head attention weights, and (4) reveals spatial patterns using attention weights that are correlated with the infection spread. The model performs consistently across counties despite the large variation in infection rates, and can be easily extended to other community-level datasets that provide similar characteristics (e.g., population, health status, and socioeconomic factors).
Interactive Graph
Click a county on the U.S. map below to see how our model captures spatial patterns across urban and rural counties. Alternatively, type in the state and county name to jump to the one you are interested in. Use the format {State, County}, e.g., Virginia, Madison.
This interactive map is based on the findings of the paper "Interpreting County Level COVID-19 Infection and Feature Sensitivity using Deep Learning Time Series Models." The map visualizes COVID-19 case predictions at the county level across the United States, generated using a state-of-the-art machine learning model called the Temporal Fusion Transformer (TFT). The TFT model was trained on over two years of data, incorporating both static factors (e.g., demographics, socioeconomic conditions) and dynamic factors (e.g., vaccination rates, mobility patterns) to forecast daily infection trends. By selecting a county on the map or entering its name, you can explore how the model's predictions compare to actual reported cases over time. This visualization not only highlights the accuracy of the model but also provides insights into spatial patterns of COVID-19 spread across urban and rural areas. The tool aims to demonstrate how advanced machine learning techniques can be applied to public health challenges, offering both predictive power and interpretability for data-driven decision-making.
Folder Structure
- Archives: Unused code.
- dataset_raw: Contains the collected raw dataset and the supporting files. To update it, use the Update dynamic dataset notebook. The static dataset is already updated up to the onset of COVID-19 using the Update static dataset notebook.
- papers: Related papers.
- Related Works: Contains the models and results used to compare the TFT performance with related works.
- TFT-PyTorch: Contains all code and merged feature files used during the TFT experimentation setup and interpretation. For more details, check the README.md file inside it. The primary results are highlighted in results.md.
How to Reproduce
For detailed instructions on how to reproduce, follow the Reproduce.md file. In summary, it includes the following steps:
- Getting the environment ready:
  - From scratch using Anaconda or pip, installing libraries from requirements.txt.
  - Creating the containers (Singularity or Docker). Definitions are given in singularity.def and Dockerfile. An already created Singularity container is hosted here.
- To reproduce the TFT experiments, run the scripts in the TFT-PyTorch folder.
- For the related works comparison, run the scripts from the Related Works folder.
- Note that, for Python path management, the scripts have to be run from their corresponding folder (not from this root).
- There are also notebooks available for most scripts for easier debugging.
Results
Ground Truth and Benchmark
Temporal Patterns
Time series data typically exhibit various temporal patterns, such as trend, seasonal, and cyclic patterns. Here we investigate how well our TFT model can learn and interpret these patterns by conducting experiments on data with these patterns.
Cyclic holiday patterns (Thanksgiving, Christmas). During holidays, hospitals and COVID-19 test centers often have reduced staffing and operating hours, leading to fewer tests and reported cases, and in turn to a drop in attention on those days.
Trend: the TFT model's test performance on all US counties for additional data splits that cover different infection trends.
Spatial Patterns
Features
Note that past values of the target and of the known future inputs are also used as observed inputs by TFT.
Feature | Type | Update Frequency | Description/Rationale | Source(s) |
---|---|---|---|---|
Age Distribution (% age 65 and over) | Static | Once | Aged 65 or older, from the 2016-2020 American Community Survey (ACS). Older ages have been associated with more severe outcomes from COVID-19 infection. | 2020 SVI |
Health Disparities (Uninsured) | Static | Once | Percentage uninsured in the total civilian noninstitutionalized population estimate, 2016-2020 ACS. Individuals without insurance are more likely to be undercounted in infection statistics, and may have more severe outcomes due to lack of treatment. | 2020 SVI |
Transmissible Cases | Observed | Daily | Cases from the last 14 days per 100k population. Because of the 14-day incubation period, the cases identified in that time period are the most likely to be transmissible. This metric is the number of such "contagious" individuals relative to the population, so a greater number indicates a more likely continued spread of disease. | USA Facts, 2020 SVI (for population estimate) |
Disease Spread | Observed | Daily | Cases from the last 14 days (one incubation period) divided by cases from the last 28 days. Because COVID-19 is thought to have an incubation period of about 14 days, only a sustained decline in new infections over 2 weeks is sufficient to signal a reduction in disease spread. This metric is always between 0 and 1, with values near 1 during the exponential growth phase, and declining linearly to zero over 14 days if there are no new infections. | USA Facts |
Social Distancing | Observed | Daily | The Unacast social distancing scoreboard grade is assigned by looking at the change in overall distance traveled and the change in nonessential visits relative to baseline (previous year), based on cell phone mobility data. The grade is converted to a numerical score, with higher values indicating less social distancing (a worse score). Less social distancing is expected to increase the spread of infection because more people are interacting with each other. | Unacast |
Vaccination Full Dose (Series_Complete_Pop_Pct) | Observed | Daily | Percent of people who are fully vaccinated (have a second dose of a two-dose vaccine or one dose of a single-dose vaccine) based on the jurisdiction and county where the recipient lives. | CDC |
SinWeekly | Known Future | Daily | Sin (day of the week / 7). | Date |
CosWeekly | Known Future | Daily | Cos (day of the week / 7). | Date |
Case | Target | Daily | COVID-19 infection at county level. | USA Facts |
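
The derived features in the table can be sketched with pandas as below. This is an illustrative sketch, not the repository's actual preprocessing code: the function name `add_derived_features`, the output column names, and the single-county input layout are assumptions. The table gives the weekly encoding argument only as "day of the week / 7"; the sketch additionally scales it by 2π, the usual convention for cyclical encodings, which is also an assumption.

```python
import numpy as np
import pandas as pd


def add_derived_features(df: pd.DataFrame, population: float) -> pd.DataFrame:
    """Add derived columns for one county (hypothetical helper).

    Assumes `df` has a daily DatetimeIndex and a `Case` column holding
    the number of new cases reported on each day.
    """
    out = df.copy()

    # Rolling case totals over one and two 14-day windows.
    cases_14 = out["Case"].rolling(14, min_periods=1).sum()
    cases_28 = out["Case"].rolling(28, min_periods=1).sum()

    # Transmissible Cases: cases from the last 14 days per 100k population.
    out["TransmissibleCases"] = cases_14 / population * 100_000

    # Disease Spread: last-14-day cases divided by last-28-day cases.
    # The ratio is undefined when there were no cases in 28 days;
    # this sketch sets it to 0 in that case (an assumption).
    out["DiseaseSpread"] = (cases_14 / cases_28.where(cases_28 > 0)).fillna(0.0)

    # Weekly cyclical encoding of the day of the week (Monday = 0).
    # The 2*pi scaling is an assumption; see the note above.
    dow = out.index.dayofweek.to_numpy()
    out["SinWeekly"] = np.sin(2 * np.pi * dow / 7)
    out["CosWeekly"] = np.cos(2 * np.pi * dow / 7)
    return out


if __name__ == "__main__":
    # Toy data: 28 days of 10 new cases per day in a county of 100,000 people.
    idx = pd.date_range("2021-01-04", periods=28, freq="D")
    toy = pd.DataFrame({"Case": [10] * 28}, index=idx)
    print(add_derived_features(toy, population=100_000).tail(3))
```

With a constant daily case count, the Disease Spread ratio settles at 0.5 once 28 days of history are available, which sits between the two regimes described in the table (near 1 during exponential growth, falling to 0 after 14 days without new infections).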
Contribute
- Please do not add temporarily generated files to this repository.
- Make sure to clean your temp files before pushing any commits.
- In the .gitignore file you will find some paths in this directory are excluded from git tracking. So if you create anything in those folders, they won't be tracked by git.
  - To check which files git reports as untracked: git status -u.
  - If you have folders you want to exclude, add the path to .gitignore, then run git add .gitignore. Check again with git status -u to see whether it is still being tracked.
Citation
@INPROCEEDINGS{islam2023interpreting,
  author={Islam, Md Khairul and Liu, Yingzheng and Erkelens, Andrej and Daniello, Nick and Marathe, Aparna and Fox, Judy},
  booktitle={2023 IEEE International Conference on Digital Health (ICDH)},
  title={Interpreting County-Level COVID-19 Infections using Transformer and Deep Learning Time Series Models},
  year={2023},
  pages={266-277},
  doi={10.1109/ICDH60066.2023.00046}
}