diff options
author | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-23 10:13:28 -0400 |
---|---|---|
committer | Alex Auvolat <alex.auvolat@ens.fr> | 2015-07-23 10:13:28 -0400 |
commit | 9313c8bb7f62b14f8c15d2119ae678641f751dbd (patch) | |
tree | 646be5562b40d1203fc32764f956c03e04c99dde | |
parent | dd8ae5ea8ed0c7cb1a7880b1e1887c6e23cdf910 (diff) | |
download | taxi-9313c8bb7f62b14f8c15d2119ae678641f751dbd.tar.gz taxi-9313c8bb7f62b14f8c15d2119ae678641f751dbd.zip |
Can do test with batches of size >1
-rw-r--r-- | ext_test.py | 13 | ||||
-rw-r--r-- | model/mlp.py | 2 |
2 files changed, 8 insertions, 7 deletions
diff --git a/ext_test.py b/ext_test.py index 3af637b..9e64223 100644 --- a/ext_test.py +++ b/ext_test.py @@ -61,12 +61,13 @@ class RunOnTest(SimpleExtension): for d in self.test_stream.get_epoch_iterator(as_dict=True): input_values = [d[k.name] for k in self.inputs] output_values = self.function(*input_values) - if 'destination' in self.outputs: - destination = output_values[self.outputs.index('destination')] - dest_outcsv.writerow([d['trip_id'][0], destination[0, 0], destination[0, 1]]) - if 'duration' in self.outputs: - duration = output_values[self.outputs.index('duration')] - time_outcsv.writerow([d['trip_id'][0], int(round(duration[0]))]) + for i in range(d['trip_id'].shape[0]): + if 'destination' in self.outputs: + destination = output_values[self.outputs.index('destination')] + dest_outcsv.writerow([d['trip_id'][i], destination[i, 0], destination[i, 1]]) + if 'duration' in self.outputs: + duration = output_values[self.outputs.index('duration')] + time_outcsv.writerow([d['trip_id'][i], int(round(duration[i]))]) if 'destination' in self.outputs: dest_outfile.close() diff --git a/model/mlp.py b/model/mlp.py index 7d04c82..d24b2cc 100644 --- a/model/mlp.py +++ b/model/mlp.py @@ -106,7 +106,7 @@ class Stream(object): stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts) stream = transformers.taxi_remove_test_only_clients(stream) - return Batch(stream, iteration_scheme=ConstantScheme(1)) + return Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size)) def inputs(self): return {'call_type': tensor.bvector('call_type'), |