diff --git a/tf-test.py b/tf-test.py index 2705a85..09cf570 100644 --- a/tf-test.py +++ b/tf-test.py @@ -1,15 +1,5 @@ -# TensorFlow and tf.keras import tensorflow as tf - -# Helper libraries -import numpy as np -import matplotlib.pyplot as plt - -# print(tf.__version__) - -# print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU'))) - -with tf.device("/GPU:0"): - a = tf.random.normal([1, 2]) - - +if tf.test.gpu_device_name(): + print('Default GPU Device: {}'.format(tf.test.gpu_device_name())) +else: + print("Please install GPU version of TF")