# -*- coding: utf-8 -*-
from numpy import *
import pandas as pd
# ---- Linear regression ----
# Load the advertising data set; the first CSV column is used as the row index.
data = pd.read_csv('Advertising.csv', index_col=0)
# First five rows (the value is only displayed in an interactive session).
data.head()
# Last five rows.
data.tail()
# Plotting libraries for the scatter plots.
import seaborn as sns
import matplotlib
# When running in a Jupyter notebook, execute the `%matplotlib inline` magic
# here so that figures are rendered inline.

# Scatter plots of sales against each advertising channel.
# `height` is the current name of the keyword formerly called `size`
# (renamed in seaborn 0.9; passing `size` is deprecated).
sns.pairplot(data, x_vars=['TV', 'radio', 'newspaper'], y_vars='sales', height=7, aspect=0.8)
# Same plots with a fitted regression line drawn on each panel.
sns.pairplot(data, x_vars=['TV', 'radio', 'newspaper'], y_vars='sales', height=7, aspect=0.8, kind='reg')
# Same plots at seaborn's default figure size.
sns.pairplot(data, x_vars=['TV', 'radio', 'newspaper'], y_vars='sales')
# Correlation matrix of all columns.
data.corr()
# Build the feature matrix X and the target vector y.
X = data[['TV', 'radio', 'newspaper']]
X.head()
y = data['sales']
y.head()
# Solve ordinary least squares through the normal equations,
# computing the coefficients directly from the closed-form matrix formula.
def standRegres(xArr, yArr):
    """Return the least-squares coefficients for ``xArr * ws = yArr``.

    Solves the normal equations ``(X^T X) ws = X^T y``.

    Args:
        xArr: 2-D array-like of shape (n_samples, n_features). No intercept
            column is added here; include a column of ones to fit a bias.
        yArr: 1-D array-like holding the n_samples target values.

    Returns:
        An (n_features, 1) ``numpy.matrix`` of coefficients, or ``None``
        (after printing a message) when ``X^T X`` is singular.
    """
    xMat = asmatrix(xArr)
    yMat = asmatrix(yArr).T  # 1 x n row -> n x 1 column
    xTx = xMat.T * xMat
    # A rank test replaces `det(xTx) == 0.0`: an exact float comparison
    # misses matrices that are singular only up to rounding error.
    if linalg.matrix_rank(xTx) < xTx.shape[0]:
        print("This matrix is singular, cannot do inverse")
        return None
    # Solve the linear system directly instead of forming the explicit inverse.
    return asmatrix(linalg.solve(xTx, xMat.T * yMat))
# Solve for the regression coefficients with the normal equations.
# Work on a copy: `X2 = X` would only alias X, so adding the intercept column
# would also add it to X, which is passed to scikit-learn further down.
X2 = X.copy()
# A scalar broadcasts to every row, so the row count is not hard-coded.
X2['intercept'] = 1
standRegres(X2, y)
## Solve the same regression with scikit-learn.
from sklearn.linear_model import LinearRegression
linreg = LinearRegression()
linreg.fit(X, y)
# Print the intercept (bias term).
print(linreg.intercept_)
# Print the regression coefficients.
print(linreg.coef_)
# Materialize the pairs: on Python 3 a bare zip object prints only its repr,
# not the (feature, coefficient) pairs.
print(list(zip(['TV', 'radio', 'newspaper'], linreg.coef_)))
## Split the data into training and test sets for hold-out evaluation.
# `train_test_split` lives in `sklearn.model_selection`; the old
# `sklearn.cross_validation` module is kept only as a fallback for
# installations that predate the move.
try:
    from sklearn.model_selection import train_test_split
except ImportError:
    from sklearn.cross_validation import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1)
linreg.fit(X_train, y_train)
# Fitted intercept and coefficients.
print(linreg.intercept_)
print(linreg.coef_)
# Materialize the pairs so they are printed on Python 3 as well.
print(list(zip(['TV', 'Radio', 'Newspaper'], linreg.coef_)))
# Predict on the held-out test set.
y_pred = linreg.predict(X_test)
# Error metrics.
from sklearn import metrics
# Mean absolute error.
print("MAE:", metrics.mean_absolute_error(y_test, y_pred))
# Mean squared error.
print("MSE:", metrics.mean_squared_error(y_test, y_pred))
# Root mean squared error.
print("RMSE:", sqrt(metrics.mean_squared_error(y_test, y_pred)))
## Model comparison: refit after dropping the 'newspaper' feature.
feature_cols = ['TV', 'radio']
X = data.loc[:, feature_cols]
y = data['sales']
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1)
linreg.fit(X_train, y_train)
y_pred = linreg.predict(X_test)
# Report MAE, MSE and RMSE on the test set, in that order.
for label, score in (
    ("MAE:", metrics.mean_absolute_error(y_test, y_pred)),
    ("MSE:", metrics.mean_squared_error(y_test, y_pred)),
    ("RMSE:", sqrt(metrics.mean_squared_error(y_test, y_pred))),
):
    print(label, score)
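# --- Web-page residue from the site this script was copied from; not part of the program ---
# Uploader: 米斯特66
# - Followers: 5
# - Resources: 11
# Latest resources
# - Minecraft-flan durability plugin
# - [Java graduation project] Zaozhuang food street website source code (ssm+mysql+documentation).zip
# - jspm JSP-based student club management system v5bo2.zip
# - [Java graduation project] Student information management system source code (ssm+mysql+documentation+LW).zip
# - mysql8.0.40.0-windows installer
# - [Java graduation project] Campus entry/exit system during the COVID-19 pandemic, source code (ssm+mysql+documentation+LW).zip
# - [Java graduation project] Campus second-hand trading system source code (ssm+mysql+documentation).zip
# - mysql5.7.44.0-windows installer
# - [Java graduation project] Olefin plant pressure pipeline management platform source code (ssm+mysql+documentation+LW).zip
# - [Java graduation project] Student grade analysis system source code (ssm+mysql+documentation+LW).zip
# If you have any questions or suggestions about uploading or downloading resources, course study, etc., your feedback is welcome - we will handle it promptly!
# Click here to give feedback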