TensorFlow 調用預訓練好的模型—— Python 實現

1. 準備預訓練好的模型

  • TensorFlow 預訓練好的模型被保存爲以下四個文件

模型文件

  • data 文件是訓練好的參數值,meta 文件是定義的神經網絡圖,checkpoint 文件是所有模型的保存路徑,如下所示,爲簡單起見只保留了一個模型。
model_checkpoint_path: "/home/senius/python/c_python/test/model-40"
all_model_checkpoint_paths: "/home/senius/python/c_python/test/model-40"

2. 導入模型圖、參數值和相關變量

import tensorflow as tf import numpy as np sess = tf.Session() X = None # input yhat = None # output def load_model(): """ Loading the pre-trained model and parameters. """ global X, yhat modelpath = r'/home/senius/python/c_python/test/' saver = tf.train.import_meta_graph(modelpath + 'model-40.meta') saver.restore(sess, tf.train.latest_checkpoint(modelpath)) graph = tf.get_default_graph() X = graph.get_tensor_by_name("X:0") yhat = graph.get_tensor_by_name("tanh:0") print('Successfully load the pre-trained model!') import tensorflow as tf import numpy as np sess = tf.Session() X = None # input yhat = None # output def load_model(): """ Loading the pre-trained model and parameters. """ global X, yhat modelpath = r'/home/senius/python/c_python/test/' saver = tf.train.import_meta_graph(modelpath + 'model-40.meta') saver.restore(sess, tf.train.latest_checkpoint(modelpath)) graph = tf.get_default_graph() X = graph.get_tensor_by_name("X:0") yhat = graph.get_tensor_by_name("tanh:0") print('Successfully load the pre-trained model!')
  • 通過 saver.restore 我們可以得到預訓練的所有參數值,然後再通過 graph.get_tensor_by_name 得到模型的輸入張量和我們想要的輸出張量。

3. 運行前向傳播過程得到預測值

def predict(txtdata): """ Convert data to Numpy array which has a shape of (-1, 41, 41, 41 3). Test a single example. Arg: txtdata: Array in C. Returns: Three coordinates of a face normal. """ global X, yhat data = np.array(txtdata) data = data.reshape(-1, 41, 41, 41, 3) output = sess.run(yhat, feed_dict={X: data}) # (-1, 3) output = output.reshape(-1, 1) ret = output.tolist() return ret def predict(txtdata): """ Convert data to Numpy array which has a shape of (-1, 41, 41, 41 3). Test a single example. Arg: txtdata: Array in C. Returns: Three coordinates of a face normal. """ global X, yhat data = np.array(txtdata) data = data.reshape(-1, 41, 41, 41, 3) output = sess.run(yhat, feed_dict={X: data}) # (-1, 3) output = output.reshape(-1, 1) ret = output.tolist() return ret
  • 通過 feed_dict 喂入測試數據,然後 run 輸出的張量我們就可以得到預測值。

4. 測試

load_model() testdata = np.fromfile('/home/senius/python/c_python/test/04t30t00.npy', dtype=np.float32) testdata = testdata.reshape(-1, 41, 41, 41, 3) # (150, 41, 41, 41, 3) testdata = testdata[0:2, ...] # the first two examples txtdata = testdata.tolist() output = predict(txtdata) print(output) # [[-0.13345889747142792], [0.5858198404312134], [-0.7211828231811523],  # [-0.03778800368309021], [0.9978875517845154], [0.06522832065820694]] load_model() testdata = np.fromfile('/home/senius/python/c_python/test/04t30t00.npy', dtype=np.float32) testdata = testdata.reshape(-1, 41, 41, 41, 3) # (150, 41, 41, 41, 3) testdata = testdata[0:2, ...] # the first two examples txtdata = testdata.tolist() output = predict(txtdata) print(output) # [[-0.13345889747142792], [0.5858198404312134], [-0.7211828231811523],  # [-0.03778800368309021], [0.9978875517845154], [0.06522832065820694]]
  • 本例輸入是一個三維網格模型處理後的 [41, 41, 41, 3] 的數據,輸出一個表面法向量座標 (x, y, z)。

獲取更多精彩,請關注「seniusen」!
seniusen