GLADNet 程式碼解釋

2020-08-11 22:04:08

在看影象增強的論文,解析大佬的程式碼,一行一行死磕,先上完整的主函數程式碼。

main.py

from __future__ import print_function
import os
from glob import glob

import tensorflow.compat.v1 as tf

from model import lowlight_enhance
from utils import load_images

# 參數設定
# 執行前一定確定參數設定正確
use_gpu = 0  # 固定數值,不要修改
# gpu_idx = 0
# gpu_mem = 1.5
phase = "test"  # 固定數值,不要修改
save_dir = "./result"  # 想要儲存結果的目錄,按需修改
test_dir = "./pic"  # 想要進行圖片增強處理的圖片位置,按需修改


def lowlight_train(lowlight_enhance):
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    if not os.path.exists(sample_dir):
        os.makedirs(sample_dir)

    train_low_data = []
    train_high_data = []

    train_low_data_names = glob('/mnt/hdd/wangwenjing/FGtraining/low/*.png')  # ./data/train/low/*.png')
    train_low_data_names.sort()
    train_high_data_names = glob('/mnt/hdd/wangwenjing/FGtraining/normal/*.png')  # ./data/train/normal/*.png')
    train_high_data_names.sort()
    assert len(train_low_data_names) == len(train_high_data_names)
    print('[*] Number of training data: %d' % len(train_low_data_names))

    for idx in range(len(train_low_data_names)):
        if (idx + 1) % 1000 == 0:
            print(idx + 1)
        low_im = load_images(train_low_data_names[idx])
        train_low_data.append(low_im)
        high_im = load_images(train_high_data_names[idx])
        train_high_data.append(high_im)

    eval_low_data = []
    eval_high_data = []

    eval_low_data_name = glob('./data/eval/low/*.*')

    for idx in range(len(eval_low_data_name)):
        eval_low_im = load_images(eval_low_data_name[idx])
        eval_low_data.append(eval_low_im)

    lowlight_enhance.train(train_low_data, train_high_data, eval_low_data, batch_size=batch_size, patch_size=patch_size,
                           epoch=epoch, sample_dir=sample_dir, ckpt_dir=ckpt_dir, eval_every_epoch=eval_every_epoch)


def lowlight_test(lowlight_enhance):
    if not test_dir:
        print("[!] please provide --test_dir")
        exit(0)

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    test_low_data_name = glob(os.path.join(test_dir) + '/*.*')
    test_low_data = []
    test_high_data = []
    for idx in range(len(test_low_data_name)):
        test_low_im = load_images(test_low_data_name[idx])
        test_low_data.append(test_low_im)

    lowlight_enhance.test(test_low_data, test_high_data, test_low_data_name, save_dir=save_dir)


def main(_):
    if use_gpu:
        print("[*] GPU\n")
        # os.environ["CUDA_VISIBLE_DEVICES"] = gpu_idx
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_mem)
        with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
            model = lowlight_enhance(sess)
            if phase == 'train':
                lowlight_train(model)
            elif phase == 'test':
                lowlight_test(model)
            else:
                print('[!] Unknown phase')
                exit(0)
    else:
        print("[*] CPU\n")
        with tf.Session() as sess:
            model = lowlight_enhance(sess)
            if phase == 'train':
                lowlight_train(model)
            elif phase == 'test':
                lowlight_test(model)
            else:
                print('[!] Unknown phase')
                exit(0)


if __name__ == '__main__':
    tf.app.run()
  • 開始解析
from __future__ import print_function

在開頭加上這句之後,即使在python2.X,使用print就得像python3.X那樣加括號使用。python2.X中print不需要括號,而在python3.X中則需加括號。如果某個版本中出現了某個新的功能特性,而且這個特性和當前版本中使用的不相容,也就是它在該版本中不是語言標準,那麼我如果想要使用的話就需要從future模組匯入。

import os

os是一種常用模組
python import os模組常用函數感謝