-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvae_main.py
More file actions
57 lines (48 loc) · 2.36 KB
/
vae_main.py
File metadata and controls
57 lines (48 loc) · 2.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import sys
import torch
from datasets import Reuters8Dataset, IMDBDataset, Vocabulary, EmbeddingDataset
from training.pretrain_vae import pretrain
from models import VAEEncoder, Generator2
# Progress messages printed (when verbose) after each split is wrapped,
# in train / validation / test order.
_SPLIT_BUILT_MESSAGES = (
    ' [*] Train dataset built.',
    ' [*] Validation dataset built.',
    ' [*] Test dataset built.',
)


def _build_reuters8_datasets(opts, verbose):
    """Build the Reuters8 train/val/test splits wrapped in ``EmbeddingDataset``.

    A single ``Vocabulary`` is built from all three split files and the same
    vocabulary object is passed to every ``Reuters8Dataset``.

    Args:
        opts: Option dict; reads 'train_dataset_path', 'val_dataset_path',
            'test_dataset_path' and 'label_path'.
        verbose: If truthy, print a progress line after each step.

    Returns:
        Tuple ``(train_dataset, val_dataset, test_dataset)``.
    """
    split_paths = (
        opts['train_dataset_path'],
        opts['val_dataset_path'],
        opts['test_dataset_path'],
    )
    vocab = Vocabulary.from_files(list(split_paths))
    if verbose:
        print(' [*] Vocabulary built.')

    wrapped = []
    # Build, wrap and report each split in turn, as the original did.
    for path, message in zip(split_paths, _SPLIT_BUILT_MESSAGES):
        raw = Reuters8Dataset(path, opts['label_path'], vocab)
        wrapped.append(EmbeddingDataset(raw, raw.vocab))
        if verbose:
            print(message)
    return tuple(wrapped)


def _build_imdb_datasets(opts, verbose):
    """Build the IMDB train/val/test splits wrapped in ``EmbeddingDataset``.

    All three splits come from ``IMDBDataset.full_split`` applied to
    ``opts['train_dataset_path']``; the val/test path options are not read.

    Args:
        opts: Option dict; reads 'train_dataset_path'.
        verbose: If truthy, print a progress line after each split is wrapped.

    Returns:
        Tuple ``(train_dataset, val_dataset, test_dataset)``.
    """
    # Explicit 3-way unpack: fails loudly if full_split returns another count.
    train_raw, val_raw, test_raw = IMDBDataset.full_split(
        opts['train_dataset_path'])

    wrapped = []
    for raw, message in zip((train_raw, val_raw, test_raw),
                            _SPLIT_BUILT_MESSAGES):
        wrapped.append(EmbeddingDataset(raw, raw.vocab))
        if verbose:
            print(message)
    return tuple(wrapped)


def run_vae_pretrain(opts):
    """Build the configured dataset splits and run VAE pretraining on them.

    Args:
        opts: Option dict. Keys read here:
            'dataset' ('reuters8' or 'imdb'), 'verbose',
            dataset paths ('train_dataset_path'; for reuters8 also
            'val_dataset_path', 'test_dataset_path', 'label_path'),
            'disable_tensorboard', 'tensorboard_dir', 'resume',
            'save_checkpoint_only', and the training settings forwarded to
            ``pretrain``: 'emb_size', 'lr', 'num_epochs', 'num_workers',
            'batch_size', 'device'.

    Raises:
        ValueError: If ``opts['dataset']`` is not a supported dataset name.
            Previously this surfaced as an ``UnboundLocalError`` at the
            ``pretrain`` call because no split variables were ever assigned.
    """
    builders = {
        'reuters8': _build_reuters8_datasets,
        'imdb': _build_imdb_datasets,
    }
    dataset_name = opts['dataset']
    if dataset_name not in builders:
        raise ValueError(
            f"Unsupported dataset {dataset_name!r}; expected one of: "
            f"{', '.join(sorted(builders))}")

    verbose = opts['verbose']
    train_dataset, val_dataset, test_dataset = builders[dataset_name](
        opts, verbose)

    tensorboard = not opts['disable_tensorboard']
    # 'resume' doubles as the checkpoint directory: a truthy value requests
    # loading from it, unless 'save_checkpoint_only' overrides that.
    resume_requested = bool(opts['resume'])
    save_only = bool(opts['save_checkpoint_only'])
    load_ckpt = resume_requested and not save_only

    pretrain(train_dataset, val_dataset=val_dataset, test_dataset=test_dataset,
             emb_size=opts['emb_size'], lr=opts['lr'],
             num_epochs=opts['num_epochs'], num_workers=opts['num_workers'],
             batch_size=opts['batch_size'], device=opts['device'],
             verbose=verbose, tensorboard=tensorboard,
             tensorboard_dir=opts['tensorboard_dir'],
             should_load_ckpt=load_ckpt, ckpt_dir=opts['resume'])