about summary refs log tree commit diff
path: root/main.py
diff options
context:
space:
mode:
Diffstat (limited to 'main.py')
-rw-r--r-- main.py 31
1 files changed, 21 insertions, 10 deletions
diff --git a/main.py b/main.py
index e288b9b..e505348 100644
--- a/main.py
+++ b/main.py
@@ -8,7 +8,7 @@ from ctc import CTC
from blocks.initialization import IsotropicGaussian, Constant
from fuel.datasets import IterableDataset
from fuel.streams import DataStream
-from blocks.algorithms import (GradientDescent, Scale,
+from blocks.algorithms import (GradientDescent, Scale, AdaDelta, RemoveNotFinite,
StepClipping, CompositeRule)
from blocks.extensions.monitoring import TrainingDataMonitoring, DataStreamMonitoring
from blocks.main_loop import MainLoop
@@ -18,10 +18,12 @@ from blocks.graph import ComputationGraph
from dummy_dataset import setup_datastream
+from edit_distance import batch_edit_distance
+
floatX = theano.config.floatX
-n_epochs = 200
+n_epochs = 10000
num_input_classes = 5
h_dim = 20
rec_dim = 20
@@ -63,6 +65,10 @@ y_hat = tensor.nnet.softmax(
).reshape((y_hat_pre.shape[0], y_hat_pre.shape[1], -1))
y_hat.name = 'y_hat'
+#y_hat = theano.printing.Print('y_hat')(y_hat)
+#y = theano.printing.Print('y')(y)
+#y_mask = theano.printing.Print('y_mask')(y_mask)
+
y_hat_mask = input_mask
# Cost
@@ -71,6 +77,10 @@ cost = CTC().apply(y, y_hat, y_len, y_hat_mask).mean()
cost.name = 'CTC'
dl, dl_length = CTC().best_path_decoding(y_hat, y_hat_mask)
+
+edit_distance = batch_edit_distance(dl.T.astype('int32'), dl_length, y.T.astype('int32'), y_len.astype('int32')).mean()
+edit_distance.name = 'edit_distance'
+
L = y.shape[0]
B = y.shape[1]
dl = dl[:L, :]
@@ -80,6 +90,7 @@ is_error = tensor.switch(is_error.sum(axis=0), tensor.ones((B,)), tensor.neq(y_l
error_rate = is_error.mean()
error_rate.name = 'error_rate'
+
# Initialization
for brick in [input_to_h, pre_lstm, lstm, rec_to_o]:
brick.weights_init = IsotropicGaussian(0.01)
@@ -87,23 +98,23 @@ for brick in [input_to_h, pre_lstm, lstm, rec_to_o]:
brick.initialize()
print('Bulding DataStream ...')
-ds, stream = setup_datastream(batch_size=10,
- nb_examples=1000, rng_seed=123,
- min_out_len=10, max_out_len=20)
-valid_ds, valid_stream = setup_datastream(batch_size=10,
+ds, stream = setup_datastream(batch_size=100,
+ nb_examples=10000, rng_seed=123,
+ min_out_len=5, max_out_len=10)
+valid_ds, valid_stream = setup_datastream(batch_size=100,
nb_examples=1000, rng_seed=456,
- min_out_len=10, max_out_len=20)
+ min_out_len=5, max_out_len=10)
print('Bulding training process...')
algorithm = GradientDescent(cost=cost,
parameters=ComputationGraph(cost).parameters,
- step_rule=CompositeRule([StepClipping(10.0),
- Scale(0.02)]))
+ step_rule=CompositeRule([RemoveNotFinite(), AdaDelta()]))
+ # CompositeRule([StepClipping(10.0), Scale(0.02)]))
monitor_cost = TrainingDataMonitoring([cost, error_rate],
prefix="train",
after_epoch=True)
-monitor_valid = DataStreamMonitoring([cost, error_rate],
+monitor_valid = DataStreamMonitoring([cost, error_rate, edit_distance],
data_stream=valid_stream,
prefix="valid",
after_epoch=True)