HW1: COVID-19 Cases Prediction (Regression)

Platform: Kaggle

Sample Code: Google Colab

Objectives:

  • Solve a regression problem with a deep neural network (DNN).
  • Understand basic DNN training tips.
  • Get familiar with PyTorch.

0. Prepare

import warnings
warnings.filterwarnings('ignore')

Preview the data

tr_path = 'input/covid.train.csv'
tt_path = 'input/covid.test.csv'
# Preview the data
import pandas as pd

data = pd.read_csv(tr_path)
print("Shape:", data.shape)
data.iloc[:, 40:].describe()  # columns from index 40 onward
Shape: (2700, 95)
(Output: describe() summary of the 55 columns from index 40 onward: the WI state indicator plus 18 survey/target features (cli, ili, hh_cmnty_cli, nohh_cmnty_cli, wearing_mask, travel_outside_state, work_outside_home, shop, restaurant, spent_time, large_event, public_transit, anxious, depressed, felt_isolated, worried_become_ill, worried_finances, tested_positive), each repeated with .1 and .2 suffixes for days 2 and 3. Full table omitted.)
data.head(10).iloc[:, 40:]  # first 10 rows of the same columns
(Output: the first 10 rows of those columns. Full table omitted.)
# Sklearn
from sklearn.feature_selection import SelectKBest
from sklearn.feature_selection import f_regression

x = data[data.columns[1:94]]  # the 93 input features (columns 1-93; column 0 is id)
y = data[data.columns[94]]    # target: tested_positive.2

# Feature scaling: min-max normalization to [0, 1]
x = (x - x.min()) / (x.max() - x.min())

# Feature selection
bestfeatures = SelectKBest(score_func=f_regression)  # score features by the F-statistic
fit = bestfeatures.fit(x, y)                         # fit to the data and compute scores

# Collect the scores into a DataFrame
dfscores = pd.DataFrame(fit.scores_)
dfcolumns = pd.DataFrame(x.columns)
featureScores = pd.concat([dfcolumns, dfscores], axis=1)
featureScores.columns = ['Specs', 'Score']

# Show the top-20 features
print(featureScores.nlargest(20, 'Score'))
# Keep the indices of the top 17
top_rows = featureScores.nlargest(20, 'Score').index.tolist()[:17]
print(top_rows)
                 Specs          Score
75   tested_positive.1  148069.658278
57     tested_positive   69603.872591
42        hh_cmnty_cli    9235.492094
60      hh_cmnty_cli.1    9209.019558
78      hh_cmnty_cli.2    9097.375172
43      nohh_cmnty_cli    8395.421300
61    nohh_cmnty_cli.1    8343.255927
79    nohh_cmnty_cli.2    8208.176435
40                 cli    6388.906849
58               cli.1    6374.548000
76               cli.2    6250.008702
41                 ili    5998.922880
59               ili.1    5937.588576
77               ili.2    5796.947672
92  worried_finances.2     833.613191
74  worried_finances.1     811.916460
56    worried_finances     788.076931
87    public_transit.2     686.736539
69    public_transit.1     681.562902
51      public_transit     678.834789
[75, 57, 42, 60, 78, 43, 61, 79, 40, 58, 76, 41, 59, 77, 92, 74, 56]
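The two tested_positive columns dominate because the day-3 target is strongly autocorrelated with the same rate on days 1 and 2. A quick sanity check on that reading (a sketch; assumes data, x, y, and featureScores are still in memory):

# Pearson correlation of each selected feature with the target
top_cols = featureScores.nlargest(20, 'Score')['Specs'].tolist()[:17]
print(x[top_cols].corrwith(y).sort_values(ascending=False))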

Import packages

# PyTorch
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Data preprocessing
import numpy as np
import csv
import os

# Plotting
import matplotlib.pyplot as plt

# Reproducibility
my_seed = 65472
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(my_seed)
torch.manual_seed(my_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(my_seed)
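This seeds NumPy and PyTorch but not Python's built-in random module; if anything samples through it, seed it too. A small optional addition (the use_deterministic_algorithms line is an assumption, available in PyTorch 1.8+):

import random
random.seed(my_seed)                        # Python's own RNG, not covered above
# torch.use_deterministic_algorithms(True)  # stricter determinism; may error on ops without deterministic kernels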
def get_device():
    ''' Use GPU if available, otherwise CPU '''
    return 'cuda' if torch.cuda.is_available() else 'cpu'

def plot_learning_curve(loss_record, title=''):
    ''' Plot the loss recorded during training '''
    all_steps = len(loss_record['train'])
    dev_steps = len(loss_record['dev'])
    x_1 = range(all_steps)
    x_2 = x_1[::all_steps // dev_steps]  # dev loss is recorded once per epoch
    plt.figure(figsize=(6, 4))
    plt.plot(x_1, loss_record['train'], c='tab:red', label='train')
    plt.plot(x_2, loss_record['dev'], c='tab:cyan', label='dev')
    plt.ylim(0.0, 5.0)
    plt.xlabel('Training steps')
    plt.ylabel('MSE loss')
    plt.title('Learning curve of {}'.format(title))
    plt.legend()
    plt.show()

def plot_pred(dev_set, model, device, lim=35., preds=None, targets=None):
    ''' Plot the DNN's predictions against the ground truth '''
    if preds is None or targets is None:
        model.eval()
        preds, targets = [], []
        for x, y in dev_set:
            x, y = x.to(device), y.to(device)
            with torch.no_grad():
                pred = model(x)
                preds.append(pred.detach().cpu())
                targets.append(y.detach().cpu())
        preds = torch.cat(preds, dim=0).numpy()
        targets = torch.cat(targets, dim=0).numpy()

    plt.figure(figsize=(5, 5))
    plt.scatter(targets, preds, c='r', alpha=0.5)
    plt.plot([-0.2, lim], [-0.2, lim], c='b')
    plt.xlim(-0.2, lim)
    plt.ylim(-0.2, lim)
    plt.xlabel('ground truth value')
    plt.ylabel('predicted value')
    plt.title('Ground Truth v.s. Prediction')
    plt.show()

1. Dataset

class COVID19Dataset(Dataset):
    '''
    Load and preprocess the COVID-19 data
    (path, mode=, target_only=)
    '''
    def __init__(self,
                 path,
                 mode='train',
                 target_only=False):
        # Mode: 'train', 'dev', or 'test'
        self.mode = mode

        # Read the data
        with open(path, 'r') as fp:
            data = list(csv.reader(fp))
            data = np.array(data[1:])[:, 1:].astype(float)  # drop header row and id column

        # Feature engineering
        if not target_only:
            features = list(range(93))
        else:
            # TODO: indices of the top-17 features selected with SelectKBest above
            features = [75, 57, 42, 60, 78, 43, 61, 79, 40, 58, 76, 41, 59, 77, 92, 74, 56]

        # Prepare x and y
        if mode == 'test':
            data = data[:, features]
            self.data = torch.FloatTensor(data)
        else:
            target = data[:, -1]
            data = data[:, features]
            # 90/10 train/dev split: every 10th sample goes to dev
            if mode == 'train':
                indices = [i for i in range(len(data)) if i % 10 != 0]
            elif mode == 'dev':
                indices = [i for i in range(len(data)) if i % 10 == 0]

            self.target = torch.FloatTensor(target[indices])
            self.data = torch.FloatTensor(data[indices])

        # Normalize the continuous columns (the first 40 are one-hot state indicators).
        # Caveat: with target_only=True only 17 columns remain, so this slice is empty
        # and no normalization actually happens.
        self.data[:, 40:] = \
            (self.data[:, 40:] - self.data[:, 40:].mean(dim=0, keepdim=True)) \
            / self.data[:, 40:].std(dim=0, keepdim=True)

        # Feature dimension
        self.dim = self.data.shape[1]

        print('Finished reading the {} set of COVID19 Dataset ({} samples found, each dim = {})'
              .format(mode, len(self.data), self.dim))

    def __getitem__(self, index):
        # Return one sample
        if self.mode in ['train', 'dev']:
            return self.data[index], self.target[index]
        else:
            return self.data[index]

    def __len__(self):
        # Dataset size
        return len(self.data)
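A further point tied to the "mistakes in code" hint in section 9: each split is normalized with its own mean and std, so dev and test inputs are not on exactly the scale the model saw during training. A sketch of the usual fix, assuming data is the NumPy feature matrix in __init__ before the split (for the test set, the statistics would have to be computed from the training file):

# Sketch: compute normalization statistics on the training rows only,
# then apply them to train, dev, and test alike.
train_rows = data[[i for i in range(len(data)) if i % 10 != 0]]
mu = train_rows.mean(axis=0, keepdims=True)
std = train_rows.std(axis=0, keepdims=True)
data = (data - mu) / std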

2. DataLoader

def prep_dataloader(path, mode, batch_size, n_jobs=0, target_only=False):
    ''' Build the dataset and wrap it in a DataLoader '''
    dataset = COVID19Dataset(path, mode, target_only)
    dataloader = DataLoader(dataset, batch_size, num_workers=n_jobs,
                            shuffle=(mode == 'train'), pin_memory=True, drop_last=False)
    return dataloader

3. Deep Neural Network

class NeuralNet(nn.Module):
    '''
    A simple fully-connected DNN
    '''
    def __init__(self, input_dim):
        '''
        input_dim: dimension of the input features
        '''
        # Initialize the parent class
        super(NeuralNet, self).__init__()

        # Network layers
        # TODO
        self.net = nn.Sequential(
            nn.Linear(input_dim, 16),
            nn.BatchNorm1d(16),
            nn.Dropout(p=0.2),
            nn.ReLU(),
            nn.Linear(16, 1)
        )

        # Loss function
        self.criterion = nn.MSELoss(reduction='mean')

    def forward(self, x):
        ''' Forward pass '''
        return self.net(x).squeeze(1)

    def cal_loss(self, pred, target):
        ''' Compute loss: MSE plus an L2 penalty on all parameters '''
        regularization_loss = 0
        for param in self.parameters():  # bug fix: was the global `model`, should be `self`
            regularization_loss += torch.sum(param ** 2)
        return self.criterion(pred, target) + 0.00075 * regularization_loss
        # return self.criterion(pred, target)
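The penalty above adds λ·Σθ² with λ = 0.00075 to the MSE, including the BatchNorm affine parameters and biases. PyTorch optimizers implement L2 by adding weight_decay·θ to each gradient; since d(λθ²)/dθ = 2λθ, the matching setting would be weight_decay = 2λ. A minimal sketch of that alternative (not what produced the score below; note torch.optim.AdamW applies a decoupled decay instead):

# Sketch: L2 via the optimizer instead of cal_loss.
# The gradient of 0.00075 * θ² is 0.0015 * θ, hence weight_decay=0.0015.
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=0.0015)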

4. Train, Validation, Test

def dev(dv_set, model, device):
    ''' Compute the loss on the validation set '''
    model.eval()
    total_loss = 0
    for x, y in dv_set:
        x, y = x.to(device), y.to(device)
        with torch.no_grad():
            pred = model(x)
            mse_loss = model.cal_loss(pred, y)
        # Accumulate per-batch loss, weighted by the batch size
        total_loss += mse_loss.detach().cpu().item() * len(x)

    # Average over the whole dev set
    total_loss /= len(dv_set.dataset)
    return total_loss


def train(tr_set, dv_set, model, config, device):
    ''' Train the model '''
    # Set up the optimizer
    optimizer = getattr(torch.optim, config['optimizer'])(
        model.parameters(), **config['optim_hparas'])

    # Training loop
    n_epochs = config['n_epochs']
    epoch = 0
    early_stop_cnt = 0
    min_loss = 1000.0
    loss_record = {'train': [], 'dev': []}
    while epoch < n_epochs:
        # One epoch of training
        model.train()
        for x, y in tr_set:
            optimizer.zero_grad()           # clear gradients
            x, y = x.to(device), y.to(device)
            pred = model(x)                 # forward pass
            loss = model.cal_loss(pred, y)  # compute loss
            loss.backward()                 # backpropagation
            optimizer.step()                # update parameters
            loss_record['train'].append(loss.detach().cpu().item())

        # Validate once per epoch
        dev_loss = dev(dv_set, model, device)
        loss_record['dev'].append(dev_loss)
        if dev_loss < min_loss:
            # Save the model whenever the dev loss improves
            min_loss = dev_loss
            torch.save(model.state_dict(), config['save_path'])
            early_stop_cnt = 0
            print('Saving model (epoch = {:4d}, loss = {:4f})'.format(epoch, dev_loss))
        else:
            early_stop_cnt += 1

        # Progress report
        epoch += 1
        if epoch % 100 == 0:
            print('Epoch {} finished.'.format(epoch))

        # Stop early after too many epochs without improvement
        if early_stop_cnt > config['early_stop']:
            break

    # Summary
    print('Finished training after {} epochs, min_loss = {}'.format(epoch, min_loss))
    return min_loss, loss_record


def test(tt_set, model, device):
    ''' Predict on the test set '''
    model.eval()
    preds = []
    for x in tt_set:
        x = x.to(device)
        with torch.no_grad():
            pred = model(x)
            preds.append(pred.detach().cpu())
    preds = torch.cat(preds, dim=0).numpy()
    return preds
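On the training side, one knob the hints point at is the learning rate. A sketch of plateau-based decay, assuming a scheduler is created next to the optimizer in train() and stepped once per epoch after dev() (not part of the run logged below):

# Sketch: reduce the learning rate when the dev loss stops improving.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=100)
# ... at the end of each epoch, after computing dev_loss:
# scheduler.step(dev_loss)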

5. Setup Hyper-parameters

device = get_device()
os.makedirs('models', exist_ok=True)
target_only = True  # TODO: use only the 17 selected features

# TODO: tune these hyper-parameters
config = {
    'n_epochs': 10000,
    'batch_size': 200,
    'optimizer': 'Adam',
    'optim_hparas': {
        'lr': 0.0005,
        # 'momentum': 0.9
    },
    'early_stop': 1000,
    'save_path': 'models/model.pth'
}
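The commented momentum entry only takes effect with SGD; Adam does not accept it and would raise on the unexpected keyword. A sketch of the SGD variant, with hypothetical values:

# Sketch: swap in SGD with momentum (hypothetical values, untested here).
config_sgd = dict(config, optimizer='SGD',
                  optim_hparas={'lr': 0.001, 'momentum': 0.9})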

6. Load Data and Model

# Load the data
tr_set = prep_dataloader(tr_path, 'train', config['batch_size'], target_only=target_only)
dv_set = prep_dataloader(tr_path, 'dev', config['batch_size'], target_only=target_only)
tt_set = prep_dataloader(tt_path, 'test', config['batch_size'], target_only=target_only)
Finished reading the train set of COVID19 Dataset (2430 samples found, each dim = 17)
Finished reading the dev set of COVID19 Dataset (270 samples found, each dim = 17)
Finished reading the test set of COVID19 Dataset (893 samples found, each dim = 17)
# Instantiate the model
model = NeuralNet(tr_set.dataset.dim).to(device)
# Start training
model_loss, model_loss_record = train(tr_set, dv_set, model, config, device)
Saving model (epoch =    0, loss = 340.917373)
Saving model (epoch =    1, loss = 322.165056)
Saving model (epoch =    2, loss = 313.825131)
Saving model (epoch =    3, loss = 309.852996)
Saving model (epoch =    4, loss = 306.339627)
Saving model (epoch =    5, loss = 306.296543)
Saving model (epoch =    6, loss = 305.443566)
Saving model (epoch =    7, loss = 300.847960)
Saving model (epoch =    8, loss = 289.878357)
Saving model (epoch =    9, loss = 283.428931)
Saving model (epoch =   11, loss = 277.324653)
Saving model (epoch =   19, loss = 275.598712)
Saving model (epoch =   20, loss = 274.431580)
Saving model (epoch =   21, loss = 271.097685)
Saving model (epoch =   22, loss = 266.430409)
Saving model (epoch =   23, loss = 261.281929)
Saving model (epoch =   24, loss = 260.683929)
Saving model (epoch =   25, loss = 258.624722)
Saving model (epoch =   26, loss = 253.948218)
Saving model (epoch =   27, loss = 247.472347)
Saving model (epoch =   29, loss = 240.454434)
Saving model (epoch =   31, loss = 237.730461)
Saving model (epoch =   32, loss = 233.930249)
Saving model (epoch =   33, loss = 228.276315)
Saving model (epoch =   34, loss = 224.006245)
Saving model (epoch =   35, loss = 219.322504)
Saving model (epoch =   36, loss = 207.611922)
Saving model (epoch =   37, loss = 200.482134)
Saving model (epoch =   38, loss = 196.714348)
Saving model (epoch =   41, loss = 192.946156)
Saving model (epoch =   42, loss = 187.776669)
Saving model (epoch =   43, loss = 187.133918)
Saving model (epoch =   44, loss = 186.150606)
Saving model (epoch =   45, loss = 183.542729)
Saving model (epoch =   46, loss = 177.969274)
Saving model (epoch =   47, loss = 165.223580)
Saving model (epoch =   48, loss = 161.321226)
Saving model (epoch =   49, loss = 155.550834)
Saving model (epoch =   53, loss = 149.358119)
Saving model (epoch =   54, loss = 142.696106)
Saving model (epoch =   55, loss = 139.488191)
Saving model (epoch =   56, loss = 137.709078)
Saving model (epoch =   57, loss = 132.763100)
Saving model (epoch =   60, loss = 123.616633)
Saving model (epoch =   61, loss = 121.146074)
Saving model (epoch =   64, loss = 112.364984)
Saving model (epoch =   65, loss = 100.750797)
Saving model (epoch =   72, loss = 92.428780)
Saving model (epoch =   75, loss = 78.373004)
Saving model (epoch =   80, loss = 68.048881)
Saving model (epoch =   82, loss = 53.388715)
Saving model (epoch =   85, loss = 42.509547)
Saving model (epoch =   94, loss = 39.354231)
Saving model (epoch =   98, loss = 26.928447)
Epoch 100 finished.
Saving model (epoch =  101, loss = 22.667163)
Saving model (epoch =  106, loss = 9.036212)
Saving model (epoch =  113, loss = 7.133554)
Saving model (epoch =  141, loss = 5.964343)
Saving model (epoch =  160, loss = 5.780281)
Saving model (epoch =  184, loss = 5.398957)
Epoch 200 finished.
Saving model (epoch =  203, loss = 4.893500)
Saving model (epoch =  216, loss = 4.655919)
Saving model (epoch =  218, loss = 4.601365)
Saving model (epoch =  227, loss = 4.375972)
Saving model (epoch =  233, loss = 4.232011)
Saving model (epoch =  238, loss = 4.214206)
Saving model (epoch =  245, loss = 3.868370)
Saving model (epoch =  261, loss = 3.665027)
Saving model (epoch =  266, loss = 3.496048)
Saving model (epoch =  276, loss = 3.405258)
Saving model (epoch =  282, loss = 3.251506)
Saving model (epoch =  284, loss = 3.161817)
Epoch 300 finished.
Saving model (epoch =  301, loss = 3.101231)
Saving model (epoch =  306, loss = 2.750548)
Saving model (epoch =  307, loss = 2.730497)
Saving model (epoch =  318, loss = 2.705987)
Saving model (epoch =  321, loss = 2.625461)
Saving model (epoch =  326, loss = 2.563876)
Saving model (epoch =  342, loss = 2.365493)
Saving model (epoch =  343, loss = 2.225936)
Saving model (epoch =  348, loss = 2.199931)
Saving model (epoch =  359, loss = 2.137739)
Saving model (epoch =  360, loss = 2.051806)
Saving model (epoch =  362, loss = 1.994729)
Saving model (epoch =  364, loss = 1.951313)
Saving model (epoch =  370, loss = 1.903423)
Saving model (epoch =  372, loss = 1.817203)
Saving model (epoch =  386, loss = 1.771863)
Saving model (epoch =  394, loss = 1.617613)
Epoch 400 finished.
Saving model (epoch =  411, loss = 1.475305)
Saving model (epoch =  414, loss = 1.455347)
Saving model (epoch =  415, loss = 1.413618)
Saving model (epoch =  421, loss = 1.404032)
Saving model (epoch =  425, loss = 1.342632)
Saving model (epoch =  433, loss = 1.317732)
Saving model (epoch =  440, loss = 1.289960)
Saving model (epoch =  450, loss = 1.251974)
Saving model (epoch =  451, loss = 1.186455)
Saving model (epoch =  465, loss = 1.124176)
Saving model (epoch =  480, loss = 1.115070)
Saving model (epoch =  482, loss = 1.063177)
Saving model (epoch =  495, loss = 1.020159)
Epoch 500 finished.
Saving model (epoch =  505, loss = 0.998222)
Saving model (epoch =  525, loss = 0.996696)
Saving model (epoch =  527, loss = 0.974490)
Saving model (epoch =  528, loss = 0.946199)
Saving model (epoch =  548, loss = 0.934515)
Saving model (epoch =  549, loss = 0.930782)
Saving model (epoch =  552, loss = 0.918216)
Saving model (epoch =  562, loss = 0.888980)
Saving model (epoch =  586, loss = 0.885096)
Epoch 600 finished.
Saving model (epoch =  605, loss = 0.879943)
Saving model (epoch =  619, loss = 0.876455)
Saving model (epoch =  652, loss = 0.871618)
Saving model (epoch =  657, loss = 0.862073)
Saving model (epoch =  664, loss = 0.855746)
Epoch 700 finished.
Saving model (epoch =  736, loss = 0.848672)
Epoch 800 finished.
Saving model (epoch =  838, loss = 0.833972)
Epoch 900 finished.
Epoch 1000 finished.
Epoch 1100 finished.
Epoch 1200 finished.
Epoch 1300 finished.
Epoch 1400 finished.
Epoch 1500 finished.
Epoch 1600 finished.
Epoch 1700 finished.
Epoch 1800 finished.
Finished training after 1840 epochs, min_loss = 0.833972383428503
# Plot the loss recorded during training
plot_learning_curve(model_loss_record, 'deep model')


(Figure: learning curve of the deep model, showing train and dev MSE loss over training steps.)

7. Predict

del model  # discard the in-memory model; the best checkpoint is reloaded below
# Re-instantiate the model and load the saved parameters
model = NeuralNet(tr_set.dataset.dim).to(device)
ckpt = torch.load(config['save_path'], map_location='cpu', weights_only=True)
model.load_state_dict(ckpt)

# Plot predictions on the dev set
plot_pred(dv_set, model, device)


(Figure: ground truth vs. prediction scatter on the dev set.)

8. Save Results to File

def save_pred(preds, file):
    ''' Save predictions to a csv file '''
    with open(file, 'w', newline='') as fp:  # newline='' avoids blank rows on Windows
        writer = csv.writer(fp)
        writer.writerow(['id', 'tested_positive'])
        for i, p in enumerate(preds):
            writer.writerow([i, p])

preds = test(tt_set, model, device)
save_pred(preds, 'preds.csv')

Score

  • Public: 0.89359
  • Private: 0.90019

9. Hints

Simple baseline

  • Run sample code

Medium baseline

  • Feature selection: 40 states + 2 tested_positive (see the sketch at the end of this section)

Strong baseline

  • Feature selection
  • DNN architecture (layers, dimensions, activation functions)
  • Training (mini-batch size, optimizer, learning rate)
  • L2 regularization
  • There are some mistakes in the sample code.
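For the medium baseline, a plausible reading of the feature-selection hint (an assumption; the hint only names the groups) is the 40 one-hot state columns plus the day-1 and day-2 tested_positive columns:

# Sketch: medium-baseline feature set, using the indices printed in section 0
# (57 and 75 are tested_positive and tested_positive.1 after dropping the id column).
features = list(range(40)) + [57, 75]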