使用 TensorFlow 线性回归预测鲍鱼数据

代码如下:

import tensorflow as tf
import csv
import numpy as np
import matplotlib.pyplot as plt
# Learning rate for training.
learning_rate = 1e-2
# Number of training iterations.
train_steps = 1000
# Dataset source: http://archive.ics.uci.edu/ml/datasets/Abalone
# Read the CSV file row by row: the field at index 8 of each row is collected
# into `b` (the prediction target) and the first eight fields into `a`
# (the input features). All values are kept as strings at this stage.
a, b = [], []
# newline='' is how the csv module documentation recommends opening files
# that are handed to csv.reader. The `with` statement closes the file, so no
# explicit close() call is needed.
with open('./data/abalone.data', newline='') as file:
    reader = csv.reader(file)
    for item in reader:
        if not item:
            # csv.reader yields an empty list for blank lines (for example a
            # trailing newline at the end of the file); skip them instead of
            # failing on item[8].
            continue
        b.append(item[8])
        # Slice instead of `del item[8]` so the row is not mutated in place.
        a.append(item[:8])
# Turn the raw string rows into numeric arrays.
#
# Column 0 holds a categorical label ('M', 'F' or 'I') which is mapped to the
# numeric codes 1, 2 and 3. An unexpected label raises KeyError here rather
# than surfacing later as an opaque conversion failure.
sex_codes = {'M': 1, 'F': 2, 'I': 3}

x_data = np.array(a)
new_x_data = [sex_codes[label] for label in x_data[:, 0]]
new_data = np.array(new_x_data, dtype=np.float64)

# Drop the categorical column and convert the remaining columns to floats.
# The conversion must go through astype(): assigning float values element by
# element into a string-dtype array only stores their string form again, so
# the array would never actually become numeric.
x_data = np.delete(x_data, 0, axis=1).astype(np.float64)
print(x_data.shape)
print(new_data.shape)

# Put the encoded column back in front of the numeric feature columns.
x_data = np.column_stack((new_data, x_data))
print(x_data)
print(x_data[:, 0])

# Targets as a 1-D float array, one entry per row of x_data.
y_data = np.array(b).astype(np.float64)
# One weight per input feature, initialised to 1.
weights = tf.Variable(np.ones([8, 1]), dtype=tf.float32)
# Placeholders for a batch of feature rows (8 columns) and their targets.
x_data_ = tf.placeholder(tf.float32, [None, 8])
y_data_ = tf.placeholder(tf.float32, [None, 1])
# Scalar bias term, initialised to 1.
bias = tf.Variable(1.0, dtype=tf.float32)
# Linear model:
#   y_model = w1*x1 + w2*x2 + w3*x3 + w4*x4 + w5*x5 + w6*x6 + w7*x7 + w8*x8 + bias
y_model = tf.add(tf.matmul(x_data_, weights), bias)
# Loss: mean squared error between the prediction and the target.
loss = tf.reduce_mean(tf.pow((y_model - y_data_), 2))
# Minimise the loss with plain gradient descent. Use the `learning_rate`
# hyperparameter defined at the top of the script instead of a second,
# hard-coded copy of the value, so changing it there actually takes effect.
train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
# Train the model with per-sample gradient descent, record the loss after
# every pass over the data, print the learned parameters and plot the loss.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print("Start training!")

    # Cast once, up front, to the dtype and shapes the placeholders expect:
    # features as (N, 8) float32 and targets as (N, 1) float32. astype() also
    # handles arrays that still hold numeric strings.
    features = np.asarray(x_data).astype(np.float32)
    targets = np.asarray(y_data).astype(np.float32).reshape(-1, 1)

    lo = []
    sample = np.arange(train_steps)
    for _ in range(train_steps):
        # One optimizer update per sample.
        for row, target in zip(features, targets):
            sess.run(
                train_op,
                feed_dict={x_data_: row.reshape(1, 8),
                           y_data_: target.reshape(1, 1)},
            )
        # Record the loss over the whole dataset. Evaluating it only on the
        # last sample seen by the inner loop gives a value that says nothing
        # about the overall fit (and is undefined when there are no samples).
        lo.append(sess.run(loss, feed_dict={x_data_: features,
                                            y_data_: targets}))

    print(weights.eval(session=sess))
    print(bias.eval(session=sess))

    # Plot how the training loss changes over the iterations.
    plt.plot(sample, lo, marker="*", linewidth=1, linestyle="--", color="red")
    plt.title("The variation of the loss")
    plt.xlabel("Sampling Point")
    plt.ylabel("Loss")
    plt.grid(True)
    plt.show()

相关推荐