Browse Source

annotated cases tornado plot

main
parent
commit
e3e1a2cfa0
  1. 38
      tornado/tornado_plot.py

38
tornado/tornado_plot.py

@ -23,11 +23,6 @@ def cm2inch(*tupl):
# %% Countries
countries = {
'cn': {'name': 'China'},
'it': {'name': 'Italy'},
'fr': {'name': 'France'},
'de': {'name': 'Germany'},
'us': {'name': 'USA'},
'uk': {'name': 'UK'}
}
@ -41,9 +36,9 @@ for country_code, country_data in countries.items():
df.index = pd.to_datetime(df.index)
# Process data
df['daily_deaths'] = df['deaths'].diff().abs() # .abs() dirty trick to prevent negative outliers
df['daily_deaths_avg'] = df['daily_deaths'].rolling(7).mean()
df['death_change'] = df['daily_deaths_avg'].diff()
df['daily_cases'] = df['cases'].diff().abs() # .abs() dirty trick to prevent negative outliers
df['daily_cases_avg'] = df['daily_cases'].rolling(7).mean()
df['cases_change'] = df['daily_cases_avg'].diff()
# Smoothing
df = df.resample('4H').asfreq()
@ -57,22 +52,37 @@ fig, ax = plt.subplots(figsize=cm2inch(15,15))
for country_code, country_data in countries.items():
df = country_data['dataframe']
line, = ax.plot(df['death_change'], df['daily_deaths_avg'], lw=0.5, label=country_data['name'])
#plt.plot(df['death_change'], df['daily_deaths_avg'], 'ob') # dots for debugging
line, = ax.plot(
df['cases_change'],
df['daily_cases_avg'],
lw=0.5, label=country_data['name'])
# select which dates to label
df['month'] = df.index.month
df['month_change'] = df['month'].diff()
dates = df['month_change'] == 1
for index, row in df[dates].iterrows():
month_start = df['month_change'] == 1
mar_onwards = df['month'] >= 3
labeldates = pd.concat([df[month_start & mar_onwards], df.tail(1)])
# date labels
for index, row in labeldates.iterrows():
date_text = row.name.strftime(format='%d %b')
ax.annotate(date_text,
(row['death_change'], row['daily_deaths_avg']),
color=line.get_color())
xy=(row['cases_change'], row['daily_cases_avg']),
xycoords='data',
xytext=(0,10),
textcoords='offset points',
ha='center',
color=line.get_color(),
bbox=dict(boxstyle='square, pad=0.5', alpha=0.7, fc='white', ec='white'))
# date markers
ax.scatter(row['cases_change'], row['daily_cases_avg'], color=line.get_color(), s=4)
plt.axvline(x=0, c='black', lw=1, ls=':')
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.set(ylabel="Daily COVID-19 Cases", xlabel="Increase or decrease in cases per day")
ax.legend(loc='upper center', ncol=3, bbox_to_anchor=(0.5,1.15))
plt.show()

Loading…
Cancel
Save