今天跟着莫凡Python教程里学习的时候,跟着视频敲的这段代码,为什么得不到该有的效果呢?
代码:
```python
from __future__ import print_function
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
def add_layer(inputs, in_size, out_size, activation_function=None):
Weights = tf.Variable(tf.random_normal([in_size, out_size]))
biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
Wx_plus_b = tf.matmul(inputs, Weights) + biases
if activation_function is None:
outputs = Wx_plus_b
else:
outputs = activation_function(Wx_plus_b)
return outputs
# Make up some real data
x_data = np.linspace(-1, 1, 300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise
##plt.scatter(x_data, y_data)
##plt.show()
# define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])
# add hidden layer
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)
# the error between prediction and real data
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys-prediction), reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
# important step
sess = tf.Session()
# tf.initialize_all_variables() no long valid from
# 2017-03-02 if using tensorflow >= 0.12
if int((tf.__version__).split('.')[1]) < 12 and int((tf.__version__).split('.')[0]) < 1:
init = tf.initialize_all_variables()
else:
init = tf.global_variables_initializer()
sess.run(init)
# plot the real data
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax.scatter(x_data, y_data)
plt.ion()
plt.show()
for i in range(1000):
# training
sess.run(train_step, feed_dict={xs: x_data, ys: y_data})
if i % 50 == 0:
# to visualize the result and improvement
try:
ax.lines.remove(lines[0])
except Exception:
pass
prediction_value = sess.run(prediction, feed_dict={xs: x_data})
# plot the prediction
lines = ax.plot(x_data, prediction_value, 'r-', lw=5)
plt.pause(1)
预期效果:
运行结果:
你得设置个颜色吧
```python
from matplotlib import markers
import matplotlib.pyplot as plt
import numpy as np
"""
Pyplot 是 Matplotlib 的子库,提供了和 MATLAB 类似的绘图 API。
Pyplot 是常用的绘图模块,能很方便让用户绘制 2D 图表。
Pyplot 包含一系列绘图函数的相关函数,每个函数会对当前的图像进行一些修改,例如:给图像加上标记,生新的图像,在图像中产生新的绘图区域等等。
"""
# 多个点的绘制
ypoints = np.array([1,3,4,5,8,9,6,1,3,4,5,2,4])
ypoints_one = np.array(range(2,17))
"""
plot() 用于画图它可以绘制点和线,语法格式如下:
# 画单条线
plot([x], y, [fmt], *, data=None, **kwargs)
# 画多条线
plot([x], y, [fmt], [x2], y2, [fmt2], ..., **kwargs)
参数说明:
x, y:点或线的节点 x 为 x 轴数据 y 为 y 轴数据,数据可以列表或数组。
fmt:可选定义基本格式(如颜色、标记和线条样式)。
**kwargs:可选,用在二维平面图上 设置指定属性,如标签,线的宽度等。
"""
"""
'solid' (默认) '-' 实线
'dotted' ':' 点虚线
'dashed' '--' 破折线
'dashdot' '-.' 点划线
'None' '' 或 ' ' 不画线
"""
## fmt = '[marker][line][color]' 例如 o:r,o 表示实心圆标记,: 表示虚线,r 表示颜色为红色。
# plt.plot(ypoints,marker="o")
## 点虚线显示
plt.plot(ypoints,linestyle="dotted",color="g")
plt.plot(ypoints_one,linestyle="dashdot",color="r")
plt.show()
```