Added function show_model_stats().
This commit is contained in:
@@ -240,8 +240,10 @@
|
|||||||
" X = X + x_shift\n",
|
" X = X + x_shift\n",
|
||||||
" y = y + y_shift\n",
|
" y = y + y_shift\n",
|
||||||
" \n",
|
" \n",
|
||||||
" X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,\n",
|
" X_train, X_test, y_train, y_test = train_test_split(\n",
|
||||||
" random_state=random)\n",
|
" X, y, \n",
|
||||||
|
" test_size=0.2, \n",
|
||||||
|
" random_state=random)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" model = LinearRegression()\n",
|
" model = LinearRegression()\n",
|
||||||
" model.fit(X_train, y_train)\n",
|
" model.fit(X_train, y_train)\n",
|
||||||
@@ -249,17 +251,8 @@
|
|||||||
" \n",
|
" \n",
|
||||||
" m = model.coef_[0][0]\n",
|
" m = model.coef_[0][0]\n",
|
||||||
" b = model.intercept_[0]\n",
|
" b = model.intercept_[0]\n",
|
||||||
" print('+----------------------+')\n",
|
"        label = '%s regression line for %s' % (color, self.language)\n",
|
||||||
" print('%s regression line for %s' % (color, self.language))\n",
|
" show_model_stats(m, b, y_test, y_pred, label)\n",
|
||||||
" print('coefficient = %0.2f' % m)\n",
|
|
||||||
" print('intercept = %0.2f' % b)\n",
|
|
||||||
" rmse = root_mean_squared_error(y_test, y_pred)\n",
|
|
||||||
" print('rmse = %0.2f' % rmse)\n",
|
|
||||||
" r2 = r2_score(y_test, y_pred)\n",
|
|
||||||
" print('r2 score = %0.2f' % r2)\n",
|
|
||||||
" print('sample predictions:')\n",
|
|
||||||
" print(y_pred[3:6])\n",
|
|
||||||
" print('+----------------------+')\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
" plt.figure(self.canvas)\n",
|
" plt.figure(self.canvas)\n",
|
||||||
" plt.plot(X_test, y_pred, color=color, label=name, linestyle=style)\n",
|
" plt.plot(X_test, y_pred, color=color, label=name, linestyle=style)\n",
|
||||||
@@ -282,17 +275,9 @@
|
|||||||
"\n",
|
"\n",
|
||||||
" m = model.coef_[0][0]\n",
|
" m = model.coef_[0][0]\n",
|
||||||
" b = model.intercept_[0]\n",
|
" b = model.intercept_[0]\n",
|
||||||
" print('+----------------------+')\n",
|
" label = '%s log regression line for %s' % (color, self.language)\n",
|
||||||
" print('%s log regression line for %s' % (color, self.language))\n",
|
" show_model_stats(m, b, y, y_pred, label)\n",
|
||||||
" print('coefficient = %0.2f' % m)\n",
|
"\n",
|
||||||
" print('intercept = %0.2f' % b)\n",
|
|
||||||
" rmse = root_mean_squared_error(y, y_pred)\n",
|
|
||||||
" print('rmse = %0.2f' % rmse)\n",
|
|
||||||
" r2 = r2_score(y, y_pred)\n",
|
|
||||||
" print('r2 score = %0.2f' % r2)\n",
|
|
||||||
" print('sample predictions:')\n",
|
|
||||||
" print(y_pred[3:6])\n",
|
|
||||||
" print('+----------------------+')\n",
|
|
||||||
" if nodraw:\n",
|
" if nodraw:\n",
|
||||||
" return\n",
|
" return\n",
|
||||||
" plt.plot(X, y_pred, color=color, label=\"Log regression\")\n",
|
" plt.plot(X, y_pred, color=color, label=\"Log regression\")\n",
|
||||||
@@ -303,6 +288,18 @@
|
|||||||
" filename = base_filename % (self.language, self.country)\n",
|
" filename = base_filename % (self.language, self.country)\n",
|
||||||
" plt.savefig(filename.replace(' ', '-'), bbox_inches='tight')\n",
|
" plt.savefig(filename.replace(' ', '-'), bbox_inches='tight')\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"def show_model_stats(coef, intercept, y_test, y_pred, label):\n",
|
||||||
|
" print('+----------------------+')\n",
|
||||||
|
" print(label)\n",
|
||||||
|
" print('coefficient = %0.2f' % coef)\n",
|
||||||
|
" print('intercept = %0.2f' % intercept)\n",
|
||||||
|
" rmse = root_mean_squared_error(y_test, y_pred)\n",
|
||||||
|
" print('rmse = %0.2f' % rmse)\n",
|
||||||
|
" r2 = r2_score(y_test, y_pred)\n",
|
||||||
|
" print('r2 score = %0.2f' % r2)\n",
|
||||||
|
" print('sample predictions:')\n",
|
||||||
|
" print(y_pred[3:6])\n",
|
||||||
|
" print('+----------------------+')\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# the higher a is, the steeper the line gets\n",
|
"# the higher a is, the steeper the line gets\n",
|
||||||
"def log_base_a(x, a=1.07):\n",
|
"def log_base_a(x, a=1.07):\n",
|
||||||
|
Reference in New Issue
Block a user