Skip to content

Commit

Permalink
MSE floating values fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
nischalhp committed Jul 31, 2016
1 parent e456c30 commit c818a4b
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions poget/analytics/ml/linear_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import shutil
from pyspark import SparkConf, SparkContext,SQLContext
from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD
from decimal import Decimal

from poget import LOGGER

Expand Down Expand Up @@ -30,17 +31,16 @@ def test_train(self, df, target, train_split, test_split):
zipped = y_train.zip(X_train)
train_data = zipped.map(lambda x: LabeledPoint(x[0], x[1]))

linear_model = LinearRegressionWithSGD.train(train_data)
linear_model = LinearRegressionWithSGD.train(train_data, intercept=True)

X_test = test.select(*feature_columns).map(lambda x: list(x))
y_test = test.select(target).map(lambda x: x[0])

prediction = X_test.map(lambda lp: (float(linear_model.predict(lp))))
prediction_and_label = prediction.zip(y_test)
label_and_prediction = prediction.zip(y_test)
val = label_and_prediction.map(lambda vp: (Decimal(vp[0]) - Decimal(vp[1])) ** 2).reduce(lambda x, y: x + y)/label_and_prediction.count()

MSE = prediction_and_label.map(lambda (v, p): (v - p) ** 2).reduce(lambda x, y: x + y) / prediction_and_label.count()

LOGGER.info(prediction_and_label.map(lambda (labelAndPred[0], labelAndPred[1]): labelAndPred[0] == labelAndPred[1]).mean())
LOGGER.info(val)
except Exception as e:
raise e

Expand All @@ -59,7 +59,7 @@ def train(self, df, target):
zipped = y_train.zip(X_train)
train_data = zipped.map(lambda x: LabeledPoint(x[0], x[1]))

linear_model = LinearRegressionWithSGD.train(train_data)
linear_model = LinearRegressionWithSGD.train(train_data, intercept=True)

self.model = linear_model

Expand Down

0 comments on commit c818a4b

Please sign in to comment.