blob: 0ce6b29f20c082f62f17f31f22d7f5bf680ae4d8 (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
|
from blocks.bricks import application, Identity
import error
from model.mlp import FFMLP, Stream
class Model(FFMLP):
    """Feed-forward MLP that predicts a ``duration`` from the model inputs.

    The network is an ``FFMLP`` built with an ``Identity`` output layer.
    ``predict`` uses the raw network output as an exponent of
    ``config.exp_base`` and adds the result to the ``input_time`` input.
    ``cost`` scores that prediction against ``travel_time`` with
    ``error.rmsle``.
    """

    def __init__(self, config, **kwargs):
        """Build the underlying MLP and declare ``input_time`` as an input.

        :param config: configuration forwarded to ``FFMLP``; ``predict`` also
            reads ``config.exp_base`` from it.
        :param kwargs: forwarded unchanged to ``FFMLP.__init__``.
        """
        super(Model, self).__init__(config, output_layer=Identity, **kwargs)
        # predict() reads kwargs['input_time'], so it must be registered on
        # top of the inputs FFMLP already declares.
        self.inputs.append('input_time')

    @application(outputs=['duration'])
    def predict(self, **kwargs):
        """Return ``input_time + config.exp_base ** mlp_output``.

        Because the MLP output is used as an exponent, the network is
        effectively trained on the base-``exp_base`` logarithm of the amount
        added to ``input_time``.
        """
        # NOTE(review): assumes the FFMLP output broadcasts cleanly against
        # kwargs['input_time'] (cost() flattens the result afterwards) --
        # confirm the output shape of FFMLP.predict.
        exponent = super(Model, self).predict(**kwargs)
        start = kwargs['input_time']
        return start + self.config.exp_base ** exponent

    @predict.property('inputs')
    def predict_inputs(self):
        """Inputs consumed by ``predict``: the model's full input list."""
        return self.inputs

    @application(outputs=['cost'])
    def cost(self, **kwargs):
        """Return ``error.rmsle`` of the prediction versus ``travel_time``.

        Both the prediction and the target are flattened before being passed
        to ``error.rmsle`` (presumably root-mean-squared log error -- see the
        ``error`` module).
        """
        prediction = self.predict(**kwargs).flatten()
        target = kwargs['travel_time'].flatten()
        return error.rmsle(prediction, target)

    @cost.property('inputs')
    def cost_inputs(self):
        """Inputs consumed by ``cost``: the model inputs plus ``travel_time``."""
        extra = ['travel_time']
        return self.inputs + extra
|