https://github.com/mravanelli/pytorch-kaldi/

run_exp.py

https://shiweipku.gitbooks.io/chinese-doc-of-kaldi/content/data_preparation.html
有部分章节空白
https://zhuanlan.zhihu.com/p/82967574
kaldi/语音识别ASR/声纹识别SRE/资源汇总

def _run_forwarding_in_subprocesses(config):
    """Decide whether forwarding runs in subprocesses.

    Forwarding is done in subprocesses only when CUDA is disabled; with
    use_cuda=True it stays in the main process.
    """
    return not strtobool(config["exp"]["use_cuda"])

def _is_first_validation(ep, ck, N_ck_tr, config):
    """Return True only for the very first validation chunk of epoch 0."""

    def _valids_per_epoch(cfg):
        # One validation pass per epoch unless the option says otherwise.
        if "nr_of_valid_per_epoch" in cfg["exp"]:
            return int(cfg["exp"]["nr_of_valid_per_epoch"])
        return 1

    if ep > 0:
        return False
    val_chunks = get_chunks_after_which_to_validate(N_ck_tr, _valids_per_epoch(config))
    return ck == val_chunks[0]

def _max_nr_of_parallel_forwarding_processes(config):
    if "max_nr_of_parallel_forwarding_processes" in config["forward"]:
        return int(config["forward"]["max_nr_of_parallel_forwarding_processes"])
    return -1

初始配置,cuda等其他

# Reading global cfg file (first argument-mandatory file)
cfg_file = sys.argv[1]
if not os.path.exists(cfg_file):
    sys.stderr.write("ERROR: The config file %s does not exist!\n" % (cfg_file))
    # Exit with a non-zero status so shell scripts/callers can detect the
    # failure (the original exited with 0, signalling success on this
    # error path).
    sys.exit(1)
else:
    config = configparser.ConfigParser()
    config.read(cfg_file)

读取global cfg文件,初始化参数

# Reading and parsing optional arguments from command line (e.g.,--optimization,lr=0.002)
[section_args, field_args, value_args] = read_args_command_line(sys.argv, config)

# Output folder creation
out_folder = config["exp"]["out_folder"]
if not os.path.exists(out_folder):
    # os.makedirs creates out_folder and its exp_files subdir in one call.
    # NOTE(review): if out_folder exists but exp_files does not, nothing is
    # created here -- presumably exp_files is guaranteed elsewhere; verify.
    os.makedirs(out_folder + "/exp_files")

# Log file path
log_file = config["exp"]["out_folder"] + "/log.log"

# Read, parse, and check the config file against its prototype
cfg_file_proto = config["cfg_proto"]["cfg_proto"]
[config, name_data, name_arch] = check_cfg(cfg_file, config, cfg_file_proto)

# Read cfg file options
is_production = strtobool(config["exp"]["production"])
cfg_file_proto_chunk = config["cfg_proto"]["cfg_proto_chunk"]

cmd = config["exp"]["cmd"]
N_ep = int(config["exp"]["N_epochs_tr"])
# Zero-padded integer format (e.g. "02d") sized to the digit count of N_ep;
# used to build epoch-numbered file names.
N_ep_str_format = "0" + str(max(math.ceil(np.log10(N_ep)), 1)) + "d"
tr_data_lst = config["data_use"]["train_with"].split(",")
valid_data_lst = config["data_use"]["valid_with"].split(",")
forward_data_lst = config["data_use"]["forward_with"].split(",")
max_seq_length_train = config["batches"]["max_seq_length_train"]
# One boolean per forward output: whether to keep the generated output files.
forward_save_files = list(map(strtobool, config["forward"]["save_out_file"].split(",")))

读取参数,创建文件路径,读取cfg文件

# Copy the global cfg file into the output folder
cfg_file = out_folder + "/conf.cfg"
with open(cfg_file, "w") as configfile:
    config.write(configfile)

# Load the run_nn function from the core library.
# run_nn is a function that processes a single chunk of data.
run_nn_script = config["exp"]["run_nn_script"].split(".py")[0]
module = importlib.import_module("core")
run_nn = getattr(module, run_nn_script)

# Splitting data into chunks (see out_folder/additional_files)
create_lists(config)

# Writing the chunk-specific config files
create_configs(config)

将原始数据划分为不同的chunks

# Create (or truncate) the results file; the decoding stage appends WER
# lines to it later.
res_file_path = out_folder + "/res.res"
# A context manager guarantees the handle is closed deterministically
# (the original opened and closed the file by hand).
with open(res_file_path, "w"):
    pass

创建res文件

# Learning rates and architecture-specific optimization parameters
arch_lst = get_all_archs(config)
lr = {}
auto_lr_annealing = {}
improvement_threshold = {}
halving_factor = {}
pt_files = {}

for arch in arch_lst:
    arch_cfg = config[arch]
    # Per-epoch learning-rate schedule expanded from the "|"-separated spec.
    lr[arch] = expand_str_ep(arch_cfg["arch_lr"], "float", N_ep, "|", "*")
    # A "|" in arch_lr means the schedule was fixed by hand, so automatic
    # annealing is disabled; a single value enables auto-annealing.
    auto_lr_annealing[arch] = len(arch_cfg["arch_lr"].split("|")) <= 1
    improvement_threshold[arch] = float(arch_cfg["arch_improvement_threshold"])
    halving_factor[arch] = float(arch_cfg["arch_halving_factor"])
    pt_files[arch] = arch_cfg["arch_pretrain_file"]

根据不同的框架,初始化不同的学习率

# If production, skip training and forward directly from last saved models
if is_production:
    # Pretend the last epoch just finished so the forward/decoding stages
    # pick up the final models; N_ep = 0 makes the training loop a no-op.
    ep = N_ep - 1
    N_ep = 0
    model_files = {}

    for arch in pt_files.keys():
        model_files[arch] = out_folder + "/exp_files/final_" + arch + ".pkl"
# NOTE(review): the lines below are an excerpt whose indentation was mangled
# when pasted; in the original run_exp.py this loop presumably sits inside an
# epoch loop (for ep in range(N_ep)) that is not shown here -- confirm
# against upstream before reusing this code.
 for tr_data in tr_data_lst:

        # Compute the total number of chunks for each training epoch
        N_ck_tr = compute_n_chunks(out_folder, tr_data, ep, N_ep_str_format, "train")
        # Zero-padded integer format sized to the digit count of N_ck_tr.
        N_ck_str_format = "0" + str(max(math.ceil(np.log10(N_ck_tr)), 1)) + "d"
         # ***Epoch training***
        for ck in range(N_ck_tr):

首先进入train的for循环,计算了chunk的数量,进入chunk的循环,计算路径,更新learning rate

 	 	if not (os.path.exists(info_file)):
            # NOTE(review): excerpt with copy-paste-mangled indentation;
            # info_file, cfg_file_list, op_counter, config_chunk_file and
            # processed_first are defined in lines omitted from this fragment.

             print("Training %s chunk = %i / %i" % (tr_data, ck + 1, N_ck_tr))

                # getting the next chunk
             next_config_file = cfg_file_list[op_counter]

                # run chunk processing
            [data_name, data_set, data_end_index, fea_dict, lab_dict, arch_dict] = run_nn(
                    data_name,
                    data_set,
                    data_end_index,
                    fea_dict,
                    lab_dict,
                    arch_dict,
                    config_chunk_file,
                    processed_first,
                    next_config_file,
                )

                # update the first_processed variable
            processed_first = False

            if not (os.path.exists(info_file)):
                    sys.stderr.write(
                        "ERROR: training epoch %i, chunk %i not done! File %s does not exist.\nSee %s \n"
                        % (ep, ck, info_file, log_file)
                    )
                # NOTE(review): exits with status 0 on an error path; a
                # non-zero code would let callers detect the failure --
                # confirm against upstream before changing.
                sys.exit(0)

如果这个chunk还没计算,使用run_nn计算,结果保存在一个大list里,更新pt_file(各架构的预训练/断点模型 .pkl 文件路径,供下一个chunk继续训练),删除之前的pkl file

            # Run validation after selected training chunks. Keep the
            # previous validation results so the LR-annealing logic (in
            # omitted lines) can compare successive scores.
            if do_validation_after_chunk(ck, N_ck_tr, config):
                if not _is_first_validation(ep,ck, N_ck_tr, config):
                    # "peformance" (sic) is the spelling used throughout the
                    # original run_exp.py; kept for consistency.
                    valid_peformance_dict_prev = valid_peformance_dict
                valid_peformance_dict = {}
                for valid_data in valid_data_lst:
                    N_ck_valid = compute_n_chunks(out_folder, valid_data, ep, N_ep_str_format, "valid")
                    # Zero-padded format sized to the digit count of N_ck_valid.
                    N_ck_str_format_val = "0" + str(max(math.ceil(np.log10(N_ck_valid)), 1)) + "d"
                    for ck_val in range(N_ck_valid):

对每个chunks做validation
然后是计算error和loss

# --------FORWARD--------#
for forward_data in forward_data_lst:

    # Number of chunks this data set was split into for the forward pass
    N_ck_forward = compute_n_chunks(out_folder, forward_data, ep, N_ep_str_format, "forward")
    # Zero-padded integer format sized to the digit count of N_ck_forward.
    N_ck_str_format = "0" + str(max(math.ceil(np.log10(N_ck_forward)), 1)) + "d"

    processes = []
    info_files = []
    for ck in range(N_ck_forward):

        # In production mode the message reads "Forwarding", otherwise "Testing".
        stage_label = "Testing" if not is_production else "Forwarding"
        print("%s %s chunk = %i / %i" % (stage_label, forward_data, ck + 1, N_ck_forward))
# --------DECODING--------#
# Collect every ark file that the forward pass marked for decoding.
dec_lst = glob.glob(out_folder + "/exp_files/*_to_decode.ark")

forward_data_lst = config["data_use"]["forward_with"].split(",")
forward_outs = config["forward"]["forward_out"].split(",")
# One boolean per forward output: whether it must go through the decoder.
forward_dec_outs = [strtobool(flag) for flag in config["forward"]["require_decoding"].split(",")]

进入decode 部分

# Run the Kaldi decoder on every forward output that requires decoding,
# then collect the WER results into res.res.
for data in forward_data_lst:
    for k in range(len(forward_outs)):
        if forward_dec_outs[k]:

            print("Decoding %s output %s" % (data, forward_outs[k]))

            info_file = out_folder + "/exp_files/decoding_" + data + "_" + forward_outs[k] + ".info"

            # Create the decoding config file from the [decoding] section.
            config_dec_file = out_folder + "/decoding_" + data + "_" + forward_outs[k] + ".conf"
            config_dec = configparser.ConfigParser()
            config_dec.add_section("decoding")

            for dec_key in config["decoding"].keys():
                config_dec.set("decoding", dec_key, config["decoding"][dec_key])

            # Add graphdir, data and alidir parsed out of the "lab" field of
            # the data section that matches this data set.
            lab_field = config[cfg_item2sec(config, "data_name", data)]["lab"]

            # Production case, we don't have labels
            if not is_production:
                pattern = "lab_folder=(.*)\nlab_opts=(.*)\nlab_count_file=(.*)\nlab_data_folder=(.*)\nlab_graph=(.*)"
                # Match once instead of re-running re.findall for each field
                # (the original evaluated the same findall three times).
                lab_fields = re.findall(pattern, lab_field)[0]

                alidir = lab_fields[0]
                config_dec.set("decoding", "alidir", os.path.abspath(alidir))

                datadir = lab_fields[3]
                config_dec.set("decoding", "data", os.path.abspath(datadir))

                graphdir = lab_fields[4]
                config_dec.set("decoding", "graphdir", os.path.abspath(graphdir))
            else:
                pattern = "lab_data_folder=(.*)\nlab_graph=(.*)"
                lab_fields = re.findall(pattern, lab_field)[0]

                datadir = lab_fields[0]
                config_dec.set("decoding", "data", os.path.abspath(datadir))

                graphdir = lab_fields[1]
                config_dec.set("decoding", "graphdir", os.path.abspath(graphdir))

                # The ali dir is supposed to be in exp/model/ which is one
                # level above graphdir (e.g. exp/model for exp/model/graph).
                alidir = "/".join(graphdir.split("/")[:-1])
                config_dec.set("decoding", "alidir", os.path.abspath(alidir))

            with open(config_dec_file, "w") as configfile:
                config_dec.write(configfile)

            out_folder = os.path.abspath(out_folder)
            files_dec = out_folder + "/exp_files/forward_" + data + "_ep*_ck*_" + forward_outs[k] + "_to_decode.ark"
            out_dec_folder = out_folder + "/decode_" + data + "_" + forward_outs[k]

            # Skip decoding when the info file shows it already completed.
            if not os.path.exists(info_file):

                # Run the decoder
                cmd_decode = (
                    cmd
                    + config["decoding"]["decoding_script_folder"]
                    + "/"
                    + config["decoding"]["decoding_script"]
                    + " "
                    + os.path.abspath(config_dec_file)
                    + " "
                    + out_dec_folder
                    + ' "'
                    + files_dec
                    + '"'
                )
                run_shell(cmd_decode, log_file)

                # Remove the ark files unless they were requested to be kept.
                if not forward_save_files[k]:
                    for rem_ark in glob.glob(files_dec):
                        os.remove(rem_ark)

            # Print WER results and append them to the results file.
            cmd_res = "./check_res_dec.sh " + out_dec_folder
            wers = run_shell(cmd_res, log_file).decode("utf-8")
            # Fix: the original opened res_file without ever closing it,
            # leaking one file handle per decoded output.
            with open(res_file_path, "a") as res_file:
                res_file.write("%s\n" % wers)
            print(wers)

建立一个decode的config文件,然后把graph,data,ali添加进去,利用正则表达式匹配参数,输出词错误率(WER),写入文件

lattice(词格):解码器输出的一种紧凑图结构,表示多条候选词序列及其得分,其中包含最可能的识别结果

Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐