Added function show_model_stats().

This commit is contained in:
2025-04-27 16:23:33 -07:00
parent 3c3e804251
commit 1f7fe33915

View File

@@ -240,8 +240,10 @@
" X = X + x_shift\n",
" y = y + y_shift\n",
" \n",
" X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,\n",
" random_state=random)\n",
" X_train, X_test, y_train, y_test = train_test_split(\n",
" X, y, \n",
" test_size=0.2, \n",
" random_state=random)\n",
"\n",
" model = LinearRegression()\n",
" model.fit(X_train, y_train)\n",
@@ -249,17 +251,8 @@
" \n",
" m = model.coef_[0][0]\n",
" b = model.intercept_[0]\n",
" print('+----------------------+')\n",
" print('%s regression line for %s' % (color, self.language))\n",
" print('coefficient = %0.2f' % m)\n",
" print('intercept = %0.2f' % b)\n",
" rmse = root_mean_squared_error(y_test, y_pred)\n",
" print('rmse = %0.2f' % rmse)\n",
" r2 = r2_score(y_test, y_pred)\n",
" print('r2 score = %0.2f' % r2)\n",
" print('sample predictions:')\n",
" print(y_pred[3:6])\n",
" print('+----------------------+')\n",
" label = '%s log regression line for %s' % (color, self.language)\n",
" show_model_stats(m, b, y_test, y_pred, label)\n",
"\n",
" plt.figure(self.canvas)\n",
" plt.plot(X_test, y_pred, color=color, label=name, linestyle=style)\n",
@@ -282,17 +275,9 @@
"\n",
" m = model.coef_[0][0]\n",
" b = model.intercept_[0]\n",
" print('+----------------------+')\n",
" print('%s log regression line for %s' % (color, self.language))\n",
" print('coefficient = %0.2f' % m)\n",
" print('intercept = %0.2f' % b)\n",
" rmse = root_mean_squared_error(y, y_pred)\n",
" print('rmse = %0.2f' % rmse)\n",
" r2 = r2_score(y, y_pred)\n",
" print('r2 score = %0.2f' % r2)\n",
" print('sample predictions:')\n",
" print(y_pred[3:6])\n",
" print('+----------------------+')\n",
" label = '%s log regression line for %s' % (color, self.language)\n",
" show_model_stats(m, b, y, y_pred, label)\n",
"\n",
" if nodraw:\n",
" return\n",
" plt.plot(X, y_pred, color=color, label=\"Log regression\")\n",
@@ -303,6 +288,18 @@
" filename = base_filename % (self.language, self.country)\n",
" plt.savefig(filename.replace(' ', '-'), bbox_inches='tight')\n",
"\n",
"def show_model_stats(coef, intercept, y_test, y_pred, label):\n",
" print('+----------------------+')\n",
" print(label)\n",
" print('coefficient = %0.2f' % coef)\n",
" print('intercept = %0.2f' % intercept)\n",
" rmse = root_mean_squared_error(y_test, y_pred)\n",
" print('rmse = %0.2f' % rmse)\n",
" r2 = r2_score(y_test, y_pred)\n",
" print('r2 score = %0.2f' % r2)\n",
" print('sample predictions:')\n",
" print(y_pred[3:6])\n",
" print('+----------------------+')\n",
"\n",
"# the higher a is, the steeper the line gets\n",
"def log_base_a(x, a=1.07):\n",