ML2021 - HW4
HW4: Speaker Prediction
- Goal: learn how to use the transformer.
- Competition: https://www.kaggle.com/competitions/ml2021spring-hw4/overview

Baselines:
- Easy: Run the sample code and know how to use a transformer.
- Medium: Know how to adjust the parameters of the transformer.
- Hard: Construct a Conformer, which is a variant of the transformer.
Other links
1. Data
Dataset
```python
import os
```
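A minimal sketch of a Dataset for this task. The file layout (`mapping.json`, `metadata.json`, per-utterance `.pt` mel-spectrograms) and names such as `SpeakerDataset` and `segment_len` are assumptions for illustration, not taken verbatim from the sample code:

```python
import os
import json
import random
import torch
from torch.utils.data import Dataset

class SpeakerDataset(Dataset):
    """Mel-spectrogram segments paired with speaker ids.

    Assumed layout (illustrative): data_dir holds mapping.json
    ({"speaker2id": ...}), metadata.json (utterances per speaker),
    and one .pt tensor of shape (length, 40) per utterance.
    """
    def __init__(self, data_dir, segment_len=128):
        self.data_dir = data_dir
        self.segment_len = segment_len
        with open(os.path.join(data_dir, "mapping.json")) as f:
            speaker2id = json.load(f)["speaker2id"]
        with open(os.path.join(data_dir, "metadata.json")) as f:
            metadata = json.load(f)["speakers"]
        self.data = [
            (utt["feature_path"], speaker2id[speaker])
            for speaker, utts in metadata.items()
            for utt in utts
        ]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        feat_path, speaker = self.data[index]
        mel = torch.load(os.path.join(self.data_dir, feat_path))
        # Randomly crop a fixed-length segment so batch items share a length.
        if len(mel) > self.segment_len:
            start = random.randint(0, len(mel) - self.segment_len)
            mel = mel[start:start + self.segment_len]
        return mel.float(), torch.tensor(speaker, dtype=torch.long)
```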
DataLoader
Split:
- 90% train
- 10% validation

Create DataLoaders to iterate over the data.
```python
import torch
```
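A sketch of the 90/10 split and the DataLoaders, reusing the `SpeakerDataset` sketch above; the padding value and batch size here are illustrative assumptions:

```python
import torch
from torch.utils.data import DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence

def collate_batch(batch):
    # Pad every mel in a batch to the longest one (useful at inference,
    # where segments are not cropped to a fixed length).
    mels, speakers = zip(*batch)
    mels = pad_sequence(mels, batch_first=True, padding_value=-20)  # -20: assumed log-mel "silence"
    return mels, torch.stack(speakers)

def get_dataloaders(data_dir, batch_size=32):
    dataset = SpeakerDataset(data_dir)
    train_len = int(0.9 * len(dataset))              # 90% train
    lengths = [train_len, len(dataset) - train_len]  # 10% validation
    train_set, valid_set = random_split(dataset, lengths)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                              drop_last=True, collate_fn=collate_batch)
    valid_loader = DataLoader(valid_set, batch_size=batch_size,
                              collate_fn=collate_batch)
    return train_loader, valid_loader
```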
2. Model
- TransformerEncoderLayer:
  - The base transformer encoder layer from Attention Is All You Need.
  - Parameters:
    - d_model: the number of expected features in the input (required).
    - nhead: the number of heads in the multi-head attention models (required).
    - dim_feedforward: the dimension of the feedforward network model (default=2048).
    - dropout: the dropout value (default=0.1).
    - activation: the activation function of the intermediate layer, relu or gelu (default=relu).
- TransformerEncoder:
  - TransformerEncoder is a stack of N transformer encoder layers.
  - Parameters:
    - encoder_layer: an instance of the TransformerEncoderLayer() class (required).
    - num_layers: the number of sub-encoder-layers in the encoder (required).
    - norm: the layer normalization component (optional).
```python
import torch
```
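A sketch of a classifier built from these two modules. The input width (40 mel bins) and the 600 speaker classes follow the data assumptions above; `d_model`, `nhead`, `dim_feedforward`, and `num_layers` are the knobs the Medium baseline asks you to tune:

```python
import torch
import torch.nn as nn

class Classifier(nn.Module):
    def __init__(self, d_model=80, n_spks=600, dropout=0.1):
        super().__init__()
        # Project the 40-dim mel features up to d_model.
        self.prenet = nn.Linear(40, d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=2,
            dim_feedforward=256,
            dropout=dropout,
        )
        # Stack of N encoder layers (Medium baseline: adjust these parameters).
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
        self.pred_layer = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, n_spks),
        )

    def forward(self, mels):
        # mels: (batch, length, 40)
        out = self.prenet(mels)        # (batch, length, d_model)
        out = out.permute(1, 0, 2)     # encoder expects (length, batch, d_model)
        out = self.encoder(out)
        out = out.transpose(0, 1)      # back to (batch, length, d_model)
        stats = out.mean(dim=1)        # mean pooling over the time axis
        return self.pred_layer(stats)  # (batch, n_spks)
```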
3. Learning Rate Schedule
- For the transformer architecture, the learning rate schedule is designed differently from before.
- Warmup is important:
  - At the start, lr is set to 0.
  - lr then increases linearly from 0 to the initial lr.
```python
import math
```
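A sketch of this schedule in the style of HuggingFace's `get_cosine_schedule_with_warmup`: linear warmup from 0 to the initial lr, then cosine decay back toward 0. The step counts are chosen by the caller:

```python
import math
from torch.optim.lr_scheduler import LambdaLR

def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps,
                                    num_cycles=0.5, last_epoch=-1):
    """Linear warmup from 0 to the initial lr, then cosine decay to 0."""
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            # Warmup: scale the lr linearly from 0 up to 1x the initial lr.
            return float(current_step) / float(max(1, num_warmup_steps))
        # Cosine decay from 1x down to 0 over the remaining steps.
        progress = float(current_step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
    return LambdaLR(optimizer, lr_lambda, last_epoch)
```

Create the scheduler once after the optimizer and call `scheduler.step()` after every `optimizer.step()`.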
4. Model Function
The model's forward function: run a batch through the model and compute the loss and accuracy.
```python
import torch
```
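A sketch of the per-batch forward function; the name `model_fn` and the cross-entropy criterion are assumptions:

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

def model_fn(batch, model, device):
    """Forward a batch through the model; return loss and accuracy."""
    mels, labels = batch
    mels, labels = mels.to(device), labels.to(device)
    outs = model(mels)            # (batch, n_spks)
    loss = criterion(outs, labels)
    preds = outs.argmax(dim=-1)   # most likely speaker per utterance
    accuracy = (preds == labels).float().mean()
    return loss, accuracy
```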
5. Validate
Compute the accuracy on the validation set.
```python
from tqdm import tqdm
```
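A sketch of the validation loop, reusing `model_fn` from above:

```python
import torch
from tqdm import tqdm

def validate(dataloader, model, device):
    """Average accuracy over the whole validation set."""
    model.eval()
    running_accuracy = 0.0
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Valid"):
            _, accuracy = model_fn(batch, model, device)
            running_accuracy += accuracy.item()
    model.train()
    return running_accuracy / len(dataloader)
```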
6. Main Function
```python
from tqdm import tqdm
```

Sample output:

```
[Info]: Use cpu now!
[Info]: Finish loading data!
[Info]: Finish creating model!
Train: 0% 0/2000 [00:00<?, ? step/s]
```
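A sketch of a main loop tying the pieces above together. The step counts (`total_steps`, `valid_steps`, warmup length), the learning rate, and the checkpoint path are illustrative assumptions; the progress bar total of 2000 matches the log shown above:

```python
import torch
from tqdm import tqdm
from torch.optim import AdamW

def main(data_dir="./Dataset", total_steps=70000, valid_steps=2000,
         warmup_steps=1000, save_path="model.ckpt"):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[Info]: Use {device} now!")

    train_loader, valid_loader = get_dataloaders(data_dir)
    train_iterator = iter(train_loader)
    print("[Info]: Finish loading data!")

    model = Classifier().to(device)
    optimizer = AdamW(model.parameters(), lr=1e-3)
    scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)
    print("[Info]: Finish creating model!")

    best_accuracy = -1.0
    pbar = tqdm(total=valid_steps, desc="Train", unit=" step")
    for step in range(total_steps):
        try:
            batch = next(train_iterator)
        except StopIteration:
            train_iterator = iter(train_loader)  # epoch finished; restart
            batch = next(train_iterator)

        loss, accuracy = model_fn(batch, model, device)
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        pbar.update()

        # Periodically validate and keep the best checkpoint.
        if (step + 1) % valid_steps == 0:
            pbar.close()
            valid_accuracy = validate(valid_loader, model, device)
            if valid_accuracy > best_accuracy:
                best_accuracy = valid_accuracy
                torch.save(model.state_dict(), save_path)
            pbar = tqdm(total=valid_steps, desc="Train", unit=" step")
    pbar.close()

if __name__ == "__main__":
    main()
```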
7. Inference
Dataset for inference.
```python
import os
```
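A sketch of the inference side, reusing `Classifier` from above. The `testdata.json` index, the `id2speaker` mapping in `mapping.json`, and the Kaggle-style `Id,Category` CSV columns are all assumptions:

```python
import os
import json
import csv
import torch
from torch.utils.data import Dataset, DataLoader

class InferenceDataset(Dataset):
    """Test-time dataset: returns the full (uncropped) mel per utterance."""
    def __init__(self, data_dir):
        self.data_dir = data_dir
        with open(os.path.join(data_dir, "testdata.json")) as f:
            self.data = json.load(f)["utterances"]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        utt = self.data[index]
        mel = torch.load(os.path.join(self.data_dir, utt["feature_path"]))
        return utt["feature_path"], mel.float()

def inference(data_dir="./Dataset", ckpt="model.ckpt", output="output.csv"):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    with open(os.path.join(data_dir, "mapping.json")) as f:
        id2speaker = json.load(f)["id2speaker"]

    # batch_size=1: utterances have different lengths and are not cropped.
    loader = DataLoader(InferenceDataset(data_dir), batch_size=1)
    model = Classifier().to(device)
    model.load_state_dict(torch.load(ckpt, map_location=device))
    model.eval()

    results = [["Id", "Category"]]
    with torch.no_grad():
        for feat_path, mel in loader:
            pred = model(mel.to(device)).argmax(dim=-1).item()
            results.append([feat_path[0], id2speaker[str(pred)]])

    with open(output, "w", newline="") as f:
        csv.writer(f).writerows(results)
```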