Skip to content

Commit

Permalink
Process minibatch result with squeeze.
Browse files Browse the repository at this point in the history
  • Loading branch information
huaidong.xiong committed Oct 19, 2021
1 parent edf19c8 commit 64327e3
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
9 changes: 8 additions & 1 deletion python/mindalpha/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,12 +409,19 @@ def preprocess_minibatch(self, minibatch):
def process_minibatch_result(self, minibatch, result):
import pandas as pd
if result is None:
result = [0.0] * len(minibatch)
result = pd.Series([0.0] * len(minibatch))
if len(result) != len(minibatch):
message = "result length (%d) and " % len(result)
message += "minibatch size (%d) mismatch" % len(minibatch)
raise RuntimeError(message)
if not isinstance(result, pd.Series):
if len(result.reshape(-1)) == len(minibatch):
result = result.reshape(-1)
else:
message = "result can not be converted to pandas series; "
message += "result.shape: {}, ".format(result.shape)
message += "minibatch_size: {}".format(len(minibatch))
raise RuntimeError(message)
result = pd.Series(result)
return result

Expand Down
9 changes: 8 additions & 1 deletion python/mindalpha/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,12 +204,19 @@ def process_minibatch_result(self, minibatch, result):
import pandas as pd
minibatch_size = len(minibatch[self.input_label_column_index])
if result is None:
result = [0.0] * minibatch_size
result = pd.Series([0.0] * minibatch_size)
if len(result) != minibatch_size:
message = "result length (%d) and " % len(result)
message += "minibatch size (%d) mismatch" % minibatch_size
raise RuntimeError(message)
if not isinstance(result, pd.Series):
if len(result.reshape(-1)) == minibatch_size:
result = result.reshape(-1)
else:
message = "result can not be converted to pandas series; "
message += "result.shape: {}, ".format(result.shape)
message += "minibatch_size: {}".format(minibatch_size)
raise RuntimeError(message)
result = pd.Series(result)
return result

Expand Down

0 comments on commit 64327e3

Please sign in to comment.