summaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-07-20 13:57:49 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-07-20 13:57:49 -0400
commit99c37ecef356c20673b4ecd5030749e3a6abcf7a (patch)
tree3289721d7882d811f2584ee9701255cb2cd694c1 /train.py
parent7bf692d9ae344ccef044923f131f5ce8de85b0b4 (diff)
downloadtext-rnn-99c37ecef356c20673b4ecd5030749e3a6abcf7a.tar.gz
text-rnn-99c37ecef356c20673b4ecd5030749e3a6abcf7a.zip
New model CCHLSTM ; other models broken.
Please enter the commit message for your changes. Lines starting
Diffstat (limited to 'train.py')
-rwxr-xr-xtrain.py51
1 files changed, 28 insertions, 23 deletions
diff --git a/train.py b/train.py
index 61f6663..58bff1e 100755
--- a/train.py
+++ b/train.py
@@ -14,13 +14,18 @@ from theano.tensor.shared_randomstreams import RandomStreams
from blocks.serialization import load_parameter_values, secure_dump, BRICK_DELIMITER
from blocks.extensions import Printing, SimpleExtension
from blocks.extensions.monitoring import DataStreamMonitoring, TrainingDataMonitoring
-from blocks.extras.extensions.plot import Plot
from blocks.extensions.saveload import Checkpoint, Load
from blocks.graph import ComputationGraph
from blocks.main_loop import MainLoop
from blocks.model import Model
from blocks.algorithms import GradientDescent, StepRule, CompositeRule
+try:
+ from blocks.extras.extensions.plot import Plot
+ plot_avail = True
+except ImportError:
+ plot_avail = False
+
import datastream
from paramsaveload import SaveLoadParams
from gentext import GenText
@@ -63,18 +68,34 @@ class ResetStates(SimpleExtension):
def train_model(m, train_stream, dump_path=None):
# Define the model
- model = Model(m.cost)
+ model = Model(m.sgd_cost)
- cg = ComputationGraph(m.cost_reg)
- algorithm = GradientDescent(cost=m.cost_reg,
+ cg = ComputationGraph(m.sgd_cost)
+ algorithm = GradientDescent(cost=m.sgd_cost,
step_rule=CompositeRule([
ElementwiseRemoveNotFinite(),
config.step_rule]),
- params=cg.parameters)
+ parameters=cg.parameters)
algorithm.add_updates(m.states)
- extensions = []
+ monitor_vars = [v for p in m.monitor_vars for v in p]
+ extensions = [
+ TrainingDataMonitoring(
+ monitor_vars,
+ prefix='train', every_n_epochs=1),
+ Printing(every_n_epochs=1, after_epoch=False),
+
+ ResetStates([v for v, _ in m.states], after_epoch=True)
+ ]
+ if plot_avail:
+ plot_channels = [['train_' + v.name for v in p] for p in m.monitor_vars]
+ extensions.append(
+ Plot(document='text_'+model_name+'_'+config.param_desc,
+ channels=plot_channels,
+ server_url='http://eos6:5006/',
+ every_n_epochs=1, after_epoch=False)
+ )
if config.save_freq is not None and dump_path is not None:
extensions.append(
SaveLoadParams(path=dump_path+'.pkl',
@@ -105,19 +126,7 @@ def train_model(m, train_stream, dump_path=None):
model=model,
data_stream=train_stream,
algorithm=algorithm,
- extensions=extensions + [
- TrainingDataMonitoring(
- [m.cost_reg, m.error_rate_reg, m.cost, m.error_rate],
- prefix='train', every_n_epochs=1),
- Printing(every_n_epochs=1, after_epoch=False),
- Plot(document='text_'+model_name+'_'+config.param_desc,
- channels=[['train_cost', 'train_cost_reg'],
- ['train_error_rate', 'train_error_rate_reg']],
- server_url='http://eos21:4201/',
- every_n_epochs=1, after_epoch=False),
-
- ResetStates([v for v, _ in m.states], after_epoch=True)
- ]
+ extensions=extensions
)
main_loop.run()
@@ -131,10 +140,6 @@ if __name__ == "__main__":
# Build model
m = config.Model()
- m.cost.name = 'cost'
- m.cost_reg.name = 'cost_reg'
- m.error_rate.name = 'error_rate'
- m.error_rate_reg.name = 'error_rate_reg'
m.pred.name = 'pred'
# Train the model