代码为:

a = tf.ones([3, 2])
b = tf.fill([2, 3], 3.)
print(tf.matmul(a, b))

在网上搜下了,发现这是由于显存的问题导致的,因为默认情况下在代码中使用GPU时,有把内存占满的趋势;即使有时候计算的数据量并不足以占用整个GPU。

我们可以将GPU设置为memory_growth模式,即运行时初始化将不会在设备上分配所有显存,需要多少显存就用多少显存,实现方法如下:

physical_devices = tf.config.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], True)

如果有多个GPU:

physical_devices = tf.config.list_physical_devices('GPU')
for physical_device in physical_devices:
    tf.config.experimental.set_memory_growth(physical_device, True)

完整代码(成功运行):

import tensorflow as tf

physical_devices = tf.config.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], True)

a = tf.ones([3, 2])
b = tf.fill([2, 3], 3.)
print(tf.matmul(a, b))

最后,更优雅的方式:

Tensorflow2.1 -> 2.3,解决!

pip install --upgrade tensorflow -i https://pypi.org/simple

为什么要加 -i https://pypi.org/simple,因为清华源在我这直接速度是kb级且才开始下一会儿就提示超时,所以:

pip config set global.index-url https://pypi.org/simple

不过还是备注一下,清华源:https://pypi.tuna.tsinghua.edu.cn/simple

参考链接:

  1. https://blog.csdn.net/qq_39507748/article/details/104273521
  2. https://www.tensorflow.org/api_docs/python/tf/config/experimental/set_memory_growth
最后修改:2023 年 08 月 26 日
如果觉得我的文章对你有用,请随意赞赏