summaryrefslogtreecommitdiff
path: root/train.py
blob: 2c4be184f439dfe0ce95f451abdcf45eeb42f932 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#!/usr/bin/env python2

import logging
import sys
import importlib

logging.basicConfig(level='INFO')
logger = logging.getLogger(__name__)

import theano

from blocks.extensions import Printing, SimpleExtension, FinishAfter, ProgressBar
from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring

from blocks.graph import ComputationGraph
from blocks.main_loop import MainLoop
from blocks.model import Model
from blocks.algorithms import GradientDescent

# Plotting is an optional dependency: remember whether the extension could
# be imported so the script can leave it out when it is missing.
try:
    from blocks.extras.extensions.plot import Plot
except ImportError:
    plot_avail = False
    logger.warning('Plotting extension not available')
else:
    plot_avail = True


import datastream
from paramsaveload import SaveLoadParams
from gentext import GenText


# Raise the interpreter's recursion limit far above the default.
# NOTE(review): presumably needed because large Theano graphs are
# traversed/pickled recursively -- confirm before lowering this value.
sys.setrecursionlimit(500000)


class ResetStates(SimpleExtension):
    """Extension that overwrites a set of Theano variables with zeros.

    Parameters
    ----------
    state_vars : iterable
        Variables to reset. Each one is used as an update target of
        ``theano.function`` (NOTE(review): presumably shared variables
        holding recurrent states -- confirm against the caller).
    **kwargs
        Forwarded unchanged to ``SimpleExtension``.
    """

    def __init__(self, state_vars, **kwargs):
        super(ResetStates, self).__init__(**kwargs)

        # Pair every state variable with its own ``zeros_like()`` expression.
        zero_updates = []
        for state in state_vars:
            zero_updates.append((state, state.zeros_like()))

        # Compiled once; takes no inputs and returns nothing, so calling it
        # only applies the updates above.
        self.f = theano.function(inputs=[], outputs=[], updates=zero_updates)

    def do(self, which_callback, *args):
        """Run the compiled reset; the callback name and extra args are unused."""
        self.f()

if __name__ == "__main__":
    # Command line: train.py [options] config
    #   config     name of a module inside the ``config`` package; it must be
    #              the last argument
    #   --nosave   do not install the parameter save/load extension
    #   --noload   install it, but do not trigger it before training
    model_name = sys.argv[-1] if len(sys.argv) >= 2 else None
    if model_name is None or model_name.startswith('--'):
        # Also covers the case where only options were given: an option
        # string can never be a config module name, so report usage instead
        # of failing later with an ImportError.
        sys.stderr.write('Usage: %s [options] config\n' % sys.argv[0])
        sys.exit(1)
    config = importlib.import_module('.%s' % model_name, 'config')

    # Build datastream
    train_stream = datastream.setup_datastream(config.dataset,
                                               config.num_seqs,
                                               config.seq_len,
                                               config.seq_div_size)

    # Build model
    m = config.Model(config)

    # Train the model
    # NOTE: despite its name, ``cg`` is a blocks ``Model`` wrapping the cost.
    cg = Model(m.sgd_cost)
    algorithm = GradientDescent(cost=m.sgd_cost,
                                step_rule=config.step_rule,
                                parameters=cg.parameters)

    # ``m.states`` is a sequence of pairs whose first item is the state
    # variable (see the unpacking for ResetStates below); the pairs are
    # registered as additional updates on the training algorithm.
    algorithm.add_updates(m.states)

    # ``m.monitor_vars`` is a list of groups of variables; flatten it and
    # drop duplicates for monitoring.
    monitor_vars = list(set(v for p in m.monitor_vars for v in p))
    extensions = [
            ProgressBar(),
            TrainingDataMonitoring(
                monitor_vars,
                prefix='train', every_n_batches=config.monitor_freq),
            Printing(every_n_batches=config.monitor_freq, after_epoch=False),

            # Zero the state variables at the end of every epoch.
            ResetStates([v for v, _ in m.states], after_epoch=True)
    ]
    if plot_avail:
        # One list of channels per group in ``m.monitor_vars``; names carry
        # the 'train_' prefix (presumably matching the monitoring prefix
        # used above).
        plot_channels = [['train_' + v.name for v in p] for p in m.monitor_vars]
        extensions.append(
            Plot(document='text_'+model_name,
                 channels=plot_channels,
                 # server_url='http://localhost:5006',
                 every_n_batches=config.monitor_freq)
        )

    if config.save_freq is not None and '--nosave' not in sys.argv:
        extensions.append(
            SaveLoadParams(path='params/%s.pkl'%model_name,
                           model=cg,
                           before_training=('--noload' not in sys.argv),
                           after_training=True,
                           every_n_batches=config.save_freq)
        )

    if config.sample_freq is not None:
        extensions.append(
            GenText(m, config.sample_init,
                    config.sample_len, config.sample_temperature,
                    before_training=True,
                    every_n_batches=config.sample_freq)
        )

    main_loop = MainLoop(
        model=cg,
        data_stream=train_stream,
        algorithm=algorithm,
        extensions=extensions
    )
    main_loop.run()
    # Emit the main loop's profiling report once training has finished.
    main_loop.profile.report()



#  vim: set sts=4 ts=4 sw=4 tw=0 et :