代码为:
a = tf.ones([3, 2])
b = tf.fill([2, 3], 3.)
print(tf.matmul(a, b))
在网上搜下了,发现这是由于显存的问题导致的,因为默认情况下在代码中使用GPU时,有把内存占满的趋势;即使有时候计算的数据量并不足以占用整个GPU。
我们可以将GPU设置为memory_growth模式,即运行时初始化将不会在设备上分配所有显存,需要多少显存就用多少显存,实现方法如下:
physical_devices = tf.config.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], True)
如果有多个GPU:
physical_devices = tf.config.list_physical_devices('GPU')
for physical_device in physical_devices:
tf.config.experimental.set_memory_growth(physical_device, True)
完整代码(成功运行):
import tensorflow as tf
physical_devices = tf.config.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], True)
a = tf.ones([3, 2])
b = tf.fill([2, 3], 3.)
print(tf.matmul(a, b))
最后,更优雅的方式:
Tensorflow2.1 -> 2.3,解决!
pip install --upgrade tensorflow -i https://pypi.org/simple
为什么要加 -i https://pypi.org/simple
,因为清华源在我这直接速度是kb级且才开始下一会儿就提示超时,所以:
pip config set global.index-url https://pypi.org/simple
不过还是备注一下,清华源:https://pypi.tuna.tsinghua.edu.cn/simple
参考链接:
- https://blog.csdn.net/qq_39507748/article/details/104273521
- https://www.tensorflow.org/api_docs/python/tf/config/experimental/set_memory_growth
此处评论已关闭