#2
太阳井 2019-04-29 09:58
'''
Parallel data loading functions
'''
import sys
import time
import theano
import numpy as np
import traceback
from PIL import Image
from six.moves import queue
from multiprocessing import Process, Event
from lib.config import cfg
from lib.data_augmentation import preprocess_img
from lib.data_io import get_voxel_file, get_rendering_file
from lib.binvox_rw import read_as_3d_array
def print_error(func):
    '''Flush out error messages. Mainly used for debugging separate processes.'''

    def func_wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except:
            traceback.print_exception(*sys.exc_info())
            sys.stdout.flush()

    return func_wrapper
class DataProcess(Process):

    def __init__(self, data_queue, data_paths, repeat=True):
        '''
        data_queue : multiprocessing queue used to hand off loaded minibatches
        data_paths : list of (data, label) pairs used to load data
        repeat     : if True, keep returning data until exit is set
        '''
        super(DataProcess, self).__init__()
        # Queue used to transfer the loaded minibatches
        self.data_queue = data_queue
        self.data_paths = data_paths
        self.num_data = len(data_paths)
        self.repeat = repeat

        self.batch_size = cfg.CONST.BATCH_SIZE
        self.exit = Event()
        self.shuffle_db_inds()
    def shuffle_db_inds(self):
        # Randomly permute the training roidb
        # (the roidb is a list of dicts holding the RoI info for each image index)
        if self.repeat:
            self.perm = np.random.permutation(np.arange(self.num_data))
        else:
            self.perm = np.arange(self.num_data)
        self.cur = 0
    def get_next_minibatch(self):
        if (self.cur + self.batch_size) >= self.num_data and self.repeat:
            self.shuffle_db_inds()
        db_inds = self.perm[self.cur:min(self.cur + self.batch_size, self.num_data)]
        self.cur += self.batch_size
        return db_inds

    def shutdown(self):
        self.exit.set()
    @print_error
    def run(self):
        iteration = 0
        # Run the loop until the exit flag is set
        while not self.exit.is_set() and self.cur <= self.num_data:
            # Ensure that the network sees (almost) all data per epoch
            db_inds = self.get_next_minibatch()

            data_list = []
            label_list = []
            for batch_id, db_ind in enumerate(db_inds):
                datum = self.load_datum(self.data_paths[db_ind])
                label = self.load_label(self.data_paths[db_ind])
                data_list.append(datum)
                label_list.append(label)

            batch_data = np.array(data_list).astype(np.float32)
            batch_label = np.array(label_list).astype(np.float32)

            # The following will block until the queue has a free slot
            self.data_queue.put((batch_data, batch_label), block=True)
            iteration += 1

    def load_datum(self, path):
        pass

    def load_label(self, path):
        pass
class ReconstructionDataProcess(DataProcess):

    def __init__(self, data_queue, category_model_pair, background_imgs=[], repeat=True,
                 train=True):
        self.repeat = repeat
        self.train = train
        self.background_imgs = background_imgs
        super(ReconstructionDataProcess, self).__init__(
            data_queue, category_model_pair, repeat=repeat)
    @print_error  # the decorator's wrapper function executes first
    def run(self):
        # set up constants
        img_h = cfg.CONST.IMG_H
        img_w = cfg.CONST.IMG_W
        n_vox = cfg.CONST.N_VOX

        # This is the maximum number of views
        n_views = cfg.CONST.N_VIEWS

        while not self.exit.is_set() and self.cur <= self.num_data:
            # To ensure that the network sees (almost) all images per epoch
            db_inds = self.get_next_minibatch()

            # Sample the number of views to use for this batch
            if cfg.TRAIN.RANDOM_NUM_VIEWS:
                curr_n_views = np.random.randint(n_views) + 1
            else:
                curr_n_views = n_views

            # This will be fed into the queue; create a new batch every time
            batch_img = np.zeros(
                (curr_n_views, self.batch_size, 3, img_h, img_w), dtype=theano.config.floatX)
            batch_voxel = np.zeros(
                (self.batch_size, n_vox, 2, n_vox, n_vox), dtype=theano.config.floatX)

            # load each data instance
            for batch_id, db_ind in enumerate(db_inds):
                category, model_id = self.data_paths[db_ind]
                image_ids = np.random.choice(cfg.TRAIN.NUM_RENDERING, curr_n_views)

                # load multi-view images
                for view_id, image_id in enumerate(image_ids):
                    im = self.load_img(category, model_id, image_id)
                    # channel, height, width
                    batch_img[view_id, batch_id, :, :, :] = \
                        im.transpose((2, 0, 1)).astype(theano.config.floatX)

                voxel = self.load_label(category, model_id)
                voxel_data = voxel.data

                # channel 0 marks empty voxels, channel 1 marks occupied voxels
                batch_voxel[batch_id, :, 0, :, :] = voxel_data < 1
                batch_voxel[batch_id, :, 1, :, :] = voxel_data

            # The following will block until the queue has a free slot
            self.data_queue.put((batch_img, batch_voxel), block=True)

        print('Exiting')
    def load_img(self, category, model_id, image_id):
        image_fn = get_rendering_file(category, model_id, image_id)
        im = Image.open(image_fn)
        t_im = preprocess_img(im, self.train)
        return t_im

    def load_label(self, category, model_id):
        voxel_fn = get_voxel_file(category, model_id)
        with open(voxel_fn, 'rb') as f:
            voxel = read_as_3d_array(f)
        return voxel
def kill_processes(queue, processes):
    print('Signal processes')
    for p in processes:
        p.shutdown()

    print('Empty queue')
    while not queue.empty():
        time.sleep(0.5)
        queue.get(False)

    print('kill processes')
    for p in processes:
        p.terminate()
def make_data_processes(queue, data_paths, num_workers, repeat=True, train=True):
    '''
    Make a set of data processes for parallel data loading.
    '''
    processes = []
    for i in range(num_workers):
        process = ReconstructionDataProcess(queue, data_paths, repeat=repeat, train=train)
        process.start()
        processes.append(process)
    return processes
def get_while_running(data_process, data_queue, sleep_time=0):
    while True:
        time.sleep(sleep_time)
        try:
            batch_data, batch_label = data_queue.get_nowait()
        except queue.Empty:
            if not data_process.is_alive():
                break
            else:
                continue
        yield batch_data, batch_label
def test_process():
    from multiprocessing import Queue
    from lib.config import cfg
    from lib.data_io import category_model_id_pair

    cfg.TRAIN.PAD_X = 10
    cfg.TRAIN.PAD_Y = 10

    data_queue = Queue(2)
    category_model_pair = category_model_id_pair(dataset_portion=[0, 0.1])
    data_process = ReconstructionDataProcess(data_queue, category_model_pair)
    data_process.start()
    batch_img, batch_voxel = data_queue.get()
    kill_processes(data_queue, [data_process])


if __name__ == '__main__':
    test_process()
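For context, the intended driving pattern outside of test_process is roughly the following sketch. The queue size and worker count are illustrative, not values from the original training script, and it assumes the module above is importable as lib.data_process:

# Usage sketch (not from the original post): drive the workers with
# make_data_processes / get_while_running, assuming the repo's cfg
# constants (BATCH_SIZE, N_VOX, N_VIEWS, ...) are already configured.
from multiprocessing import Queue

from lib.data_io import category_model_id_pair
from lib.data_process import make_data_processes, get_while_running, kill_processes

data_queue = Queue(4)  # queue size is illustrative
pairs = category_model_id_pair(dataset_portion=[0, 0.1])
processes = make_data_processes(data_queue, pairs, num_workers=2, repeat=False)

# Yields batches until the watched worker exits and the queue drains.
for batch_img, batch_voxel in get_while_running(processes[0], data_queue):
    print(batch_img.shape, batch_voxel.shape)

kill_processes(data_queue, processes)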
The error is as follows:

  File "/home/nvidia/Downloads/3D-R2N2/lib/data_process.py", line 178, in kill_processes
    for p in processes:
TypeError: 'NoneType' object is not iterable
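Judging from the traceback, kill_processes received processes=None, so "for p in processes:" has nothing to iterate. The usual cause is calling kill_processes with a variable that was never assigned the list returned by make_data_processes; the actual call site is not shown here, so that is an assumption. A minimal fix sketch:

# Hypothetical fix sketch: kill_processes iterates its processes argument,
# so the caller must pass the list returned by make_data_processes, not None.
from multiprocessing import Queue

from lib.data_io import category_model_id_pair
from lib.data_process import make_data_processes, kill_processes

data_queue = Queue(2)
pairs = category_model_id_pair(dataset_portion=[0, 0.1])

processes = make_data_processes(data_queue, pairs, num_workers=1)  # keep the return value
batch_img, batch_voxel = data_queue.get()

kill_processes(data_queue, processes)  # passing None here reproduces the TypeError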