通过字符串的方式调用模型——python
通过字符串的方式调用模型——python最近又搞了一遍PGGAN,tensorflow版本的(progressive_growing_of_gans),里面分装的太好了,我在本地又没有对应的运行环境,通过在服务器上+大脑调试,简直头都大了,最后差不多摸清楚了整个网络的过程,其中通过字符串的方式调用相应的网络比较由意思,即可方便实现对网络进行扩展(都不用在文件中新增import [networks.
·
通过字符串的方式调用模型——python
最近又搞了一遍PGGAN,tensorflow版本的(progressive_growing_of_gans),里面分装的太好了,我在本地又没有对应的运行环境,通过在服务器上+大脑调试,简直头都大了,最后差不多摸清楚了整个网络的过程,其中通过字符串的方式调用相应的网络比较由意思,即可方便实现对网络进行扩展(都不用在文件中新增import [networks]
)。
比如有文件netowrks.py
如下:
class Net:
def __init__(self,):
pass
如需要使用netowrks.py
包里的Net
类,一般在使用前需要from networks import Net
,
如果使用字符串实现调用,如以下几步:
- 根据定义的字符串解析出包和类名
- 导入包
- 在导入的包中找到相应的类
如对于输入func='netwokrs.Net'
的输入处理如下:
# 将字符串解析为包名和类名
def import_module(module_or_obj_name):
parts = module_or_obj_name.split('.')
parts[0] = {'np': 'numpy', 'tf': 'tensorflow'}.get(parts[0], parts[0])
# 使用for循环,处理多个文件夹情况,如(from networks.network import Net)
for i in range(len(parts), 0, -1):
try:
module = importlib.import_module('.'.join(parts[:i]))
relative_obj_name = '.'.join(parts[i:])
return module, relative_obj_name
except ImportError:
pass
raise ImportError(module_or_obj_name)
# 在对应的包中找到对应的类
def find_obj_in_module(module, relative_obj_name):
obj = module
for part in relative_obj_name.split('.'):
obj = getattr(obj, part)
return obj
# 导入类的主调用
def import_obj(obj_name):
module, relative_obj_name = import_module(obj_name)
return find_obj_in_module(module, relative_obj_name)
# 导入类,并实现函数调用, **kwargs为调用类的参数
def call_func_by_name(*args, func=None, **kwargs):
assert func is not None
return import_obj(func)(*args, **kwargs)
这样,就可以通过输入network.Net
的方式实现Net
类的导入和调用了。
得到函数的输入变量名
在隐式调用函数时,需要知道需要函数的输入标量名称,方式如下:
input_names = []
for param in inspect.signature(FUNC).parameters.values():
if param.kind == param.POSITIONAL_OR_KEYWORD and param.default is param.empty:
input_names.append(param.name)
其中FUNC
为需要调用的函数
更多推荐
所有评论(0)