aboutsummaryrefslogtreecommitdiff
path: root/model/rnn.py
diff options
context:
space:
mode:
authorÉtienne Simon <esimon@esimon.eu>2015-07-23 22:15:08 -0400
committerÉtienne Simon <esimon@esimon.eu>2015-07-23 22:15:08 -0400
commit389bfd3637dfb523a3e4194c7281a0c538166546 (patch)
tree9efb9b02fce42cff0b4b22a510203726afeaf076 /model/rnn.py
parent13fc171f60ae1981c7ad4f2a302a8a85c29addc5 (diff)
downloadtaxi-389bfd3637dfb523a3e4194c7281a0c538166546.tar.gz
taxi-389bfd3637dfb523a3e4194c7281a0c538166546.zip
Fix RNN Stream without tvt
Diffstat (limited to 'model/rnn.py')
-rw-r--r--model/rnn.py2
1 files changed, 2 insertions, 0 deletions
diff --git a/model/rnn.py b/model/rnn.py
index 7bdba3b..b4c6550 100644
--- a/model/rnn.py
+++ b/model/rnn.py
@@ -155,6 +155,8 @@ class Stream(object):
stream = transformers.TaxiExcludeEmptyTrips(stream)
stream = transformers.taxi_add_datetime(stream)
+ if not data.tvt:
+ stream = transformers.add_destination(stream)
stream = transformers.Select(stream, tuple(v for v in req_vars if not v.endswith('_mask')))
stream = transformers.balanced_batch(stream, key='latitude', batch_size=self.config.batch_size, batch_sort_size=self.config.batch_sort_size)