通过字符串的方式调用模型——python

最近又搞了一遍PGGAN,tensorflow版本的(progressive_growing_of_gans),里面分装的太好了,我在本地又没有对应的运行环境,通过在服务器上+大脑调试,简直头都大了,最后差不多摸清楚了整个网络的过程,其中通过字符串的方式调用相应的网络比较由意思,即可方便实现对网络进行扩展(都不用在文件中新增import [networks])。

比如有文件netowrks.py如下:

class Net:
    def __init__(self,):
        pass

如需要使用netowrks.py包里的Net类,一般在使用前需要from networks import Net
如果使用字符串实现调用,如以下几步:

  1. 根据定义的字符串解析出包和类名
  2. 导入包
  3. 在导入的包中找到相应的类

如对于输入func='netwokrs.Net'的输入处理如下:

# 将字符串解析为包名和类名
def import_module(module_or_obj_name):
    parts = module_or_obj_name.split('.')
    parts[0] = {'np': 'numpy', 'tf': 'tensorflow'}.get(parts[0], parts[0])
    # 使用for循环,处理多个文件夹情况,如(from networks.network import Net)
    for i in range(len(parts), 0, -1):
        try:
            module = importlib.import_module('.'.join(parts[:i]))
            relative_obj_name = '.'.join(parts[i:])
            return module, relative_obj_name
        except ImportError:
            pass
    raise ImportError(module_or_obj_name)

# 在对应的包中找到对应的类
def find_obj_in_module(module, relative_obj_name):
    obj = module
    for part in relative_obj_name.split('.'):
        obj = getattr(obj, part)
    return obj

# 导入类的主调用
def import_obj(obj_name):
    module, relative_obj_name = import_module(obj_name)
    return find_obj_in_module(module, relative_obj_name)

# 导入类,并实现函数调用, **kwargs为调用类的参数
def call_func_by_name(*args, func=None, **kwargs):
    assert func is not None
    return import_obj(func)(*args, **kwargs)

这样,就可以通过输入network.Net的方式实现Net类的导入和调用了。

得到函数的输入变量名

在隐式调用函数时,需要知道需要函数的输入标量名称,方式如下:

input_names = []
 for param in inspect.signature(FUNC).parameters.values():
            if param.kind == param.POSITIONAL_OR_KEYWORD and param.default is param.empty:
                input_names.append(param.name)

其中FUNC为需要调用的函数

Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐