From 1f7fe3391533951577f13e719be6f9be38788f92 Mon Sep 17 00:00:00 2001 From: scuti Date: Sun, 27 Apr 2025 16:23:33 -0700 Subject: [PATCH] Added function show_model_stats(). --- stackoverflow-survey.ipynb | 45 ++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 24 deletions(-) diff --git a/stackoverflow-survey.ipynb b/stackoverflow-survey.ipynb index 0277e01..41a9311 100644 --- a/stackoverflow-survey.ipynb +++ b/stackoverflow-survey.ipynb @@ -240,8 +240,10 @@ " X = X + x_shift\n", " y = y + y_shift\n", " \n", - " X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2,\n", - " random_state=random)\n", + " X_train, X_test, y_train, y_test = train_test_split(\n", + " X, y, \n", + " test_size=0.2, \n", + " random_state=random)\n", "\n", " model = LinearRegression()\n", " model.fit(X_train, y_train)\n", @@ -249,17 +251,8 @@ " \n", " m = model.coef_[0][0]\n", " b = model.intercept_[0]\n", - " print('+----------------------+')\n", - " print('%s regression line for %s' % (color, self.language))\n", - " print('coefficient = %0.2f' % m)\n", - " print('intercept = %0.2f' % b)\n", - " rmse = root_mean_squared_error(y_test, y_pred)\n", - " print('rmse = %0.2f' % rmse)\n", - " r2 = r2_score(y_test, y_pred)\n", - " print('r2 score = %0.2f' % r2)\n", - " print('sample predictions:')\n", - " print(y_pred[3:6])\n", - " print('+----------------------+')\n", + " label = '%s log regression line for %s' % (color, self.language)\n", + " show_model_stats(m, b, y_test, y_pred, label)\n", "\n", " plt.figure(self.canvas)\n", " plt.plot(X_test, y_pred, color=color, label=name, linestyle=style)\n", @@ -282,17 +275,9 @@ "\n", " m = model.coef_[0][0]\n", " b = model.intercept_[0]\n", - " print('+----------------------+')\n", - " print('%s log regression line for %s' % (color, self.language))\n", - " print('coefficient = %0.2f' % m)\n", - " print('intercept = %0.2f' % b)\n", - " rmse = root_mean_squared_error(y, y_pred)\n", - " print('rmse = %0.2f' % rmse)\n", - " r2 = r2_score(y, y_pred)\n", - " print('r2 score = %0.2f' % r2)\n", - " print('sample predictions:')\n", - " print(y_pred[3:6])\n", - " print('+----------------------+')\n", + " label = '%s log regression line for %s' % (color, self.language)\n", + " show_model_stats(m, b, y, y_pred, label)\n", + "\n", " if nodraw:\n", " return\n", " plt.plot(X, y_pred, color=color, label=\"Log regression\")\n", @@ -303,6 +288,18 @@ " filename = base_filename % (self.language, self.country)\n", " plt.savefig(filename.replace(' ', '-'), bbox_inches='tight')\n", "\n", + "def show_model_stats(coef, intercept, y_test, y_pred, label):\n", + " print('+----------------------+')\n", + " print(label)\n", + " print('coefficient = %0.2f' % coef)\n", + " print('intercept = %0.2f' % intercept)\n", + " rmse = root_mean_squared_error(y_test, y_pred)\n", + " print('rmse = %0.2f' % rmse)\n", + " r2 = r2_score(y_test, y_pred)\n", + " print('r2 score = %0.2f' % r2)\n", + " print('sample predictions:')\n", + " print(y_pred[3:6])\n", + " print('+----------------------+')\n", "\n", "# the higher a is, the steeper the line gets\n", "def log_base_a(x, a=1.07):\n",