初步学习 tensorflow 预测算法 (线性回归方程)
1. 随机生成 1000 个点, 围绕在 y=0.1x+0.3
- num_points=1000
- vectors_set=[]
- for i in range(num_points):
- x1=np.random.normal(0.0,0.55)
- y1=x1*0.1+0.3+np.random.normal(0.0,0.03)
- vectors_set.append([x1,y1])
- # 样本生成
- x_data=[v[0] for v in vectors_set]
- y_data=[v[1] for v in vectors_set]
2. 构建线性回归方程
- # 初始化 W
- W=tf.Variable(tf.random_uniform([1],-1.0,1.0),name='W')
- # 初始化 b
- b=tf.Variable(tf.zeros([1]),name='b')
- # 构建线性回归方程
- y=W*x_data+b
3. 计算损失
- # 计算损失 (均方差)
- loss=tf.reduce_mean(tf.square(y-y_data),name='loss')
4. 优化参数
- # 采用梯度下降方法优化参数
- optimizer=tf.train.GradientDescentOptimizer(0.5)
- # 训练
- train=optimizer.minimize(loss,name='loss')
5. 开始初始化
- # 初始化全局
- sess=tf.Session()
- init=tf.global_variables_initializer()
- sess.run(init)
6. 开始训练
- # 打印初始化 W,b 和 loss
- print("W=",sess.run(W),"b=",sess.run(b),"loss=",sess.run(loss))
- # 开始训练
- for i in range(20):
- sess.run(train)
- print("W=", sess.run(W), "b=", sess.run(b), "loss=", sess.run(loss))
7. 图表展示
- # 图表展示
- plt.scatter(x_data,y_data,c='r')
- plt.plot(x_data,sess.run(W)*x_data+sess.run(b))
- plt.show()
来源: http://www.bubuko.com/infodetail-3716297.html