aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Auvolat <alex.auvolat@ens.fr>2015-07-23 10:13:28 -0400
committerAlex Auvolat <alex.auvolat@ens.fr>2015-07-23 10:13:28 -0400
commit9313c8bb7f62b14f8c15d2119ae678641f751dbd (patch)
tree646be5562b40d1203fc32764f956c03e04c99dde
parentdd8ae5ea8ed0c7cb1a7880b1e1887c6e23cdf910 (diff)
downloadtaxi-9313c8bb7f62b14f8c15d2119ae678641f751dbd.tar.gz
taxi-9313c8bb7f62b14f8c15d2119ae678641f751dbd.zip
Can do test with batches of size >1
-rw-r--r--ext_test.py13
-rw-r--r--model/mlp.py2
2 files changed, 8 insertions, 7 deletions
diff --git a/ext_test.py b/ext_test.py
index 3af637b..9e64223 100644
--- a/ext_test.py
+++ b/ext_test.py
@@ -61,12 +61,13 @@ class RunOnTest(SimpleExtension):
for d in self.test_stream.get_epoch_iterator(as_dict=True):
input_values = [d[k.name] for k in self.inputs]
output_values = self.function(*input_values)
- if 'destination' in self.outputs:
- destination = output_values[self.outputs.index('destination')]
- dest_outcsv.writerow([d['trip_id'][0], destination[0, 0], destination[0, 1]])
- if 'duration' in self.outputs:
- duration = output_values[self.outputs.index('duration')]
- time_outcsv.writerow([d['trip_id'][0], int(round(duration[0]))])
+ for i in range(d['trip_id'].shape[0]):
+ if 'destination' in self.outputs:
+ destination = output_values[self.outputs.index('destination')]
+ dest_outcsv.writerow([d['trip_id'][i], destination[i, 0], destination[i, 1]])
+ if 'duration' in self.outputs:
+ duration = output_values[self.outputs.index('duration')]
+ time_outcsv.writerow([d['trip_id'][i], int(round(duration[i]))])
if 'destination' in self.outputs:
dest_outfile.close()
diff --git a/model/mlp.py b/model/mlp.py
index 7d04c82..d24b2cc 100644
--- a/model/mlp.py
+++ b/model/mlp.py
@@ -106,7 +106,7 @@ class Stream(object):
stream = transformers.taxi_add_first_last_len(stream, self.config.n_begin_end_pts)
stream = transformers.taxi_remove_test_only_clients(stream)
- return Batch(stream, iteration_scheme=ConstantScheme(1))
+ return Batch(stream, iteration_scheme=ConstantScheme(self.config.batch_size))
def inputs(self):
return {'call_type': tensor.bvector('call_type'),