[Most-ai-contest] Multi-span簡易說明

闍怵羅 s2w81234於gmail.com
Thu 1月 2 16:09:18 CST 2020


各位好,

我是負責Multi-span Extraction的人-羅上堡。

本次的做法是全基於Rule-based去實作的,所以Step會有點多,可能有些也不是必要呈現的。

以下會附上大概的整體流程共分為16步驟,由於過久沒有寫類似這樣流程的東西,所以會附上一張極簡的流程圖,來補充說明他們之間的關係。


整體流程如下:

Step1:提取Passage與Question之文本和NER。

Step2:將P的特殊符號全部清除。

          ※《『?「》」』:~@#¥%……&*():]+...

Step3:將NER的Begin與End位置重算。

          ※因為Step2,Begin與End位置會有偏移錯誤。

Step4:創造BERT輸入矩陣:[CLS]*Q*[SEP]*P*[SEP]。

Step5:提取Question的最後一句。

Step6:依照Step5的結果,提取關鍵字眼以獲得應回答幾個答案,如果沒有則視為非指定數量題目。

Step7:將Step4的矩陣丟給BERT產生出結果。

Step8:依照Step7的結果產生top-k的Begin與End。

Step9:去top-k裡面尋找答案,同時檢查是否超過20的長度,如果超過則繼續取下一個top-k+1的結果,直到數量滿足或是沒有候選答案為止。

Step10:依照Step9所選出的所有答案,進行內含(Within)與交疊(Overrap)的答案處理。

※Within condition

Answer1: 今天是總統大選

Answer2:是總統大

Result:今天是總統大選

※Overrap condition

Answer1: 今天是總統大選

Answer2:總統大選的日子

Result:總統大選

※由於這部分的code有莫名不好處理的地方,所以在此琢磨的地方比較久。

Step11:尋找候選答案裡面,是否有『、』字眼,如果有執行Step12;如果沒有則執行Step13。

Step12-1:尋找擁有『、』字眼的答案,往後擴充到句號。,並依照jieba的斷詞結果,來取得、後面的答案。

Step12-2:如果遇到『等』字眼時,需要再往後延伸找到等後面的字詞,來延伸擴充答案。

Step13:再度檢查Within與Overrap的情況判斷,如果有發生,則執行類Step10的結果判斷後,執行Step15:如果沒有,則進行
Step14。

Step14:由於沒有遇到特殊情況,會將每個篩選後的答案,進行最簡單的Rule串接。

Step15:將選出來的答案,透過NER的資訊,去將有包含到該NER的部分字元全部擴充回來,讓答案更加完整。

※Example condition

Answer: 統大選

NER:總統

Result:總統大選

Step16:輸出最終結果。


謝謝。
-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://www.iis.sinica.edu.tw/pipermail/most-ai-contest/attachments/20200102/68b587e7/attachment-0001.html>
-------------- next part --------------
import warnings
warnings.filterwarnings("ignore")
import os,re
import json,jieba
import numpy as np
import torch
from tqdm.autonotebook import tqdm
from itertools import combinations
from .bao_bert_lin import BertTokenizer,BertForQuestionAnswering

class Multi_Span_Extraction_Layer():
    def __init__(self):
        path                = os.path.dirname(os.path.abspath(__file__))
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.BERT_model     = BertForQuestionAnswering.from_pretrained(path).to(self.device).eval()
        self.BERT_tokenizer = BertTokenizer.from_pretrained(path)
        
        
    def trunc_pair(self,A, B, max_length=512):
        while True:
            total_length = len(A) + len(B)+3
            if total_length <= max_length:
                break
            B.pop()
    def create_data_PQ(self,P_token,Q_token,max_length=512):
        Q_t = [x for x in Q_token[:]]
        P_t = [x for x in P_token[:]]
        self.trunc_pair(Q_t,P_t,512)
        token    = ["[CLS]"] + [i for i in Q_t] + ["[SEP]"] + [i for i in P_t] + ["[SEP]"] 
        seg_ids  = [0] * (len(Q_t) + 2) + [1] * (len(P_t) + 1)
        inp_ids,token  = self.BERT_tokenizer.convert_tokens_to_ids(token)

        inp_mask = [1] * len(inp_ids)
        padding  = [0] *(max_length - len(inp_ids))
        data_feature = [inp_ids,inp_mask,seg_ids]
        data_f = [torch.tensor(i,dtype=torch.long).unsqueeze(0).to(self.device) for i in data_feature]
        return data_f,token
    
    def remove_blank(self,input_str):
        input_str = input_str.lstrip()
        input_str = input_str.rstrip()
        input_str = input_str.replace(',','。')
        input_str = input_str.replace(',','。')
        input_str = input_str.replace('!','。')
        input_str = input_str.replace('!','。')
        input_str = input_str.replace('?','。')
        input_str = input_str.replace('?','。')
        input_str = input_str.replace(';','。')
        input_str = input_str.replace(';','。')
        input_str = re.sub("[\s+\.\!\/_$%^*(+\"\']+|[+——《『?「》」』:~@#¥%……&*():]+",'',input_str)
        return input_str
    def remove_blank_Q(self,input_str):
        input_str = input_str.lstrip()
        input_str = input_str.rstrip()
        return input_str
    
    def check_within_overrap(self,a1_s,a1_e,a2_s,a2_e):
        Rule2_flag=[0,0]
        # Check within
        if a1_s<=a2_s and a1_e>=a2_e:
            Rule2_flag[0]=1
        if a2_s<=a1_s and a2_e>=a1_e:
            Rule2_flag[0]=2
        if a1_s==a2_s and a1_e==a2_e:
            Rule2_flag[0]=3
        # Check overrap
        if a1_s>a2_s and a1_e>a2_e and a2_e>a1_s:
            Rule2_flag[1]=1
        if a2_s>a1_s and a2_e>a1_e and a1_e>a2_s:
            Rule2_flag[1]=2
        return Rule2_flag[0],Rule2_flag[1]

    
    def extract(self,data):
        Text_P = data["DTEXT"]
        P_IE   = data["DIE"]
        MSPE_json = []
        for Data_Q in data["QUESTIONS"]:
            Text_Q = Data_Q["QTEXT"]
            Q_IE   = Data_Q['QIE']
            plain_p = self.remove_blank(Text_P)
            
            speical_flag=0
            try:
                base_s=0
                base_e=0
                for i in P_IE['NER']:
                    ie_ner = i['string']
                    ie_s = i['char_b']
                    ie_e = i['char_e']
                    re_s = plain_p[base_e:].find(i['string'])+base_e
                    re_e = plain_p[base_e:].find(i['string'])+len(i['string'])+base_e
                    base_s = re_s
                    base_e = re_e

                    i['char_b']=re_s
                    i['char_e']=re_e
            except:
                speical_flag=1
            
            
            
            plain_q = Text_Q.split('?')[0]
            plain_q = plain_q.split('?')[0]
            data_f,token = self.create_data_PQ(plain_p,plain_q,512)
            q_len = len(plain_q)+2
            plain_q_last = plain_q.split(',')[-1]
            counter_answer=2
            if plain_q_last.find('哪些')!=-1 :
                counter_answer=50
            elif plain_q_last.find('两')!=-1:
                counter_answer=2
            elif plain_q_last.find('三')!=-1:
                counter_answer=3
            elif plain_q_last.find('六')!=-1:
                counter_answer=6
            elif plain_q_last.find('五')!=-1:
                counter_answer=5
            elif plain_q_last.find('四')!=-1:
                counter_answer=4
                
            ## Generate Top-k result
            s,e = self.BERT_model(data_f[0],data_f[2],data_f[1])
            s = s[:,q_len:]
            e = e[:,q_len:]
            arg_s = torch.topk(s,counter_answer)
            arg_e = torch.topk(e,counter_answer)

            ## refix end > start condition
            for idx in range(counter_answer):
                if arg_s[1][0][idx]>arg_e[1][0][idx]:
                    temp_s = int(arg_s[1][0][idx])
                    temp_e = int(arg_e[1][0][idx])
                    arg_s[1][0][idx]=temp_e
                    arg_e[1][0][idx]=temp_s
            
            ## Stage1_preprocessing_span_answer
            ans_pool=[]
            temp_ans=[]
            pre_num=0
            err_num=0
            for idx,(ss,ee) in enumerate(zip(arg_s[1][0],arg_e[1][0])):
                if ss>ee:
                    continue
                if ee>ss+20:
                    if plain_q_last.find('哪些')!=-1:
                        counter_answer=pre_num
                        if pre_num<2:
                            continue
                        else:
                            break
                    else:
                        if pre_num<counter_answer:
                            continue
                        else:
                            break
                else:
                    arg_s[1][0][pre_num] = int(arg_s[1][0][idx])
                    arg_e[1][0][pre_num] = int(arg_e[1][0][idx])
                    pre_num+=1
                ans = plain_p[ss:ee+1]
                ans_pool.append(ans)
            ## Rule_2_2_span_within_span
            flag=0
            if len(ans_pool)>2:
                list_pair = list(combinations(range(len(ans_pool)),2))
                temp_within = []
                temp_overrap = []
                for i,j in list_pair:
                    fw,fo = self.check_within_overrap(arg_s[1][0][i],arg_e[1][0][i],arg_s[1][0][j],arg_e[1][0][j])
                    temp_within.append(fw)
                    temp_overrap.append(fo)
                for iidx,within_flag in enumerate(temp_within):
                    if within_flag==1 or within_flag==3:
                        new_ans_pool =[]
                        flag=0
                        for iiidx,ans_pooll in enumerate(ans_pool):
                            if iiidx==list_pair[iidx][1]:
                                flag=1
                            else:
                                new_ans_pool.append(ans_pooll)
                            if flag==1:
                                try:
                                    arg_s[1][0][iiidx] = int(arg_s[1][0][iiidx+1])
                                    arg_e[1][0][iiidx] = int(arg_e[1][0][iiidx+1])
                                except:
                                    '''
                                    '''
                        ans_pool=new_ans_pool
                        break
                    elif within_flag==2:
                        new_ans_pool =[]
                        flag=0
                        for iiidx,ans_pooll in enumerate(ans_pool):
                            if iiidx==list_pair[iidx][0]:
                                flag=1
                            else:
                                new_ans_pool.append(ans_pooll)
                            if flag==1:
                                try:
                                    arg_s[1][0][iiidx] = int(arg_s[1][0][iiidx+1])
                                    arg_e[1][0][iiidx] = int(arg_e[1][0][iiidx+1])
                                except:
                                    '''
                                    '''
                        ans_pool=new_ans_pool
                        break
                    else:
                        '''
                        '''
            if len(ans_pool)==2:
                Rule2_flag = self.check_within_overrap(arg_s[1][0][0],arg_e[1][0][0],arg_s[1][0][1],arg_e[1][0][1])
            else:
                Rule2_flag = [0,0]
            ## Rule_1_only_one_span_with、Stopword
            check_state = [x.find('、') for x in ans_pool]
            check_state = [1 if num!=-1 else 0 for idx,num in enumerate(check_state)]
            if sum(check_state)==1:
                find_location = check_state.index(1)
                chose_answer = ans_pool[find_location]
                ca_sentence_s = arg_s[1][0][find_location]
                ca_sentence_e = arg_e[1][0][find_location]
                if ca_sentence_s> ca_sentence_e :
                    ca_sentence_s,ca_sentence_e = ca_sentence_e,ca_sentence_s
                ca_sentence = plain_p[ca_sentence_s:].split('。')[0]
                try:
                    ca_next = ca_sentence[ca_sentence.index('等'):]
                except:
                    ca_next = ca_sentence[len(chose_answer):]
                ## Extra_rule_1
                if len(ans_pool)==2:
                    other_s = arg_s[1][0][find_location^1]
                    other_e = arg_e[1][0][find_location^1]
                    if abs(other_e - ca_sentence_s)<=5:
                        chose_answer = plain_p[other_s:ca_sentence_e]
                real_ans_pool =''
                for idx,split_anwer in enumerate(chose_answer.split('、')):
                    if idx == len(chose_answer.split('、'))-2:
                        real_ans_pool += split_anwer+'與'
                    elif idx != len(chose_answer.split('、'))-1:
                        real_ans_pool += split_anwer+'、'
                    else:
                        real_ans_pool += split_anwer
                ca_jieba = jieba.lcut(ca_next)
                if ca_jieba[0]=='等':
                    real_ans = real_ans_pool+ca_jieba[0]+ca_jieba[1]
                    if real_ans_pool.find('等')!=-1:
                        real_ans = real_ans_pool[:real_ans_pool.index('等')]+ca_jieba[0]+ca_jieba[1]
                else:
                    real_ans = real_ans_pool
            elif Rule2_flag[0]!=0:
                if Rule2_flag[0]==1 or Rule2_flag[0]==3:
                    with_s = arg_s[1][0][0]
                    with_e = arg_e[1][0][0]
                elif Rule2_flag[0]==2:
                    with_s = arg_s[1][0][1]
                    with_e = arg_e[1][0][1]
                real_ans=plain_p[with_s:with_e+1]
                if len(real_ans)==2:
                    real_ans=real_ans[0]+'與'+real_ans[1]
            elif Rule2_flag[1]!=0:
                if Rule2_flag[1]==1:
                    over_s = arg_s[1][0][0]
                    over_e = arg_e[1][0][1]
                else:
                    over_s = arg_s[1][0][1]
                    over_e = arg_e[1][0][0]
                real_ans=plain_p[over_s:over_e+1]
            else:
                real_ans = ""
                if len(ans_pool)==3:
                    if ans_pool[0]==ans_pool[1]:
                        ans_pool=[ans_pool[0],ans_pool[2]]
                    elif ans_pool[0]==ans_pool[2]:
                        ans_pool=[ans_pool[0],ans_pool[1]]
                    elif ans_pool[1]==ans_pool[2]:
                        ans_pool=[ans_pool[0],ans_pool[1]]
                if len(ans_pool)==2:
                    if ans_pool[0]==ans_pool[1]:
                        ans_pool=[ans_pool[0]]
                for idx,ans_p in enumerate(ans_pool):
                    if len(ans_pool)==2:
                        if idx==0:
                            real_ans+=ans_p+'和'
                        else:
                            real_ans+=ans_p
                    else:
                        if idx == len(ans_pool)-2:
                            real_ans+=ans_p+'及'
                        elif idx!=len(ans_pool)-1:
                            real_ans+=ans_p+'、'
                        else:
                            real_ans+=ans_p
            find_top = plain_p.find(plain_p[arg_s[1][0][0]:arg_e[1][0][0]+1]+'、')
            top_sentence = plain_p[arg_e[1][0][0]+1:].split('。')[0]
            top_jieba = jieba.lcut(top_sentence)
            try:
                if top_jieba[0]=='与' or top_jieba[0]=='及':
                    real_ans = plain_p[arg_s[1][0][0]:arg_e[1][0][0]+1]+top_jieba[0]+top_jieba[1]
                if sum(temp_within)>4:
                    real_ans=plain_p[arg_s[1][0][0]:arg_e[1][0][0]+1]
            except:
                '''
                '''
            flaggggg=0
            if find_top!=-1:
                f_t_sentence = plain_p[find_top:].split('。')[0]
                f_t_e = plain_p.find(plain_p[find_top:].split('。')[0])
                f_t_s = plain_p[:plain_p.find(plain_p[find_top:].split('。')[0])].split('。')[-1]
                f_sentence = f_t_s+f_t_sentence
                jieba_f_s = jieba.lcut(f_sentence)
                if f_sentence.find('、')!=-1:
                    flaggggg=1
                    temp_idx=[]
                    for idx,aa in enumerate(jieba_f_s):
                        if aa=='、':
                            temp_idx.append(idx)
                    real_ans = ''.join(jieba_f_s[temp_idx[0]-1:temp_idx[-1]+2])
                    try:
                        temp2=[]
                        temp3=[]
                        for iidx,temp_data in enumerate(jieba_f_s[temp_idx[-1]:]):
                            #print(temp_data)
                            if temp_data=='等':
                                temp2=iidx
                            if temp_data=='和':
                                temp3=iidx
                        if temp2!=[]:
                            real_ans=''.join(jieba_f_s[temp_idx[0]-1:temp_idx[-1]+temp2])
                        else:
                            if temp3!=[]:
                                real_ans=''.join(jieba_f_s[temp_idx[0]-1:temp_idx[-1]+temp3+2])
                        if jieba_f_s[temp_idx[-1]+1]=='以及':
                            real_ans = ''.join(jieba_f_s[temp_idx[0]-1:temp_idx[-1]])
                        ans_s = plain_p.find(real_ans)
                        ans_e = ans_s+len(real_ans)
                        if speical_flag==0:
                            for i in P_IE['NER']:
                                ie_ner = i['string']
                                ie_s = i['char_b']
                                ie_e = i['char_e']

                                if ans_s>ie_s and ans_e>=ie_e and ie_e>=ans_s:
                                    real_ans = plain_p[ie_s:ans_e]
                                    ans_s=ie_s
                                if ie_s>ans_s and ie_e>ans_e and ans_e>=ie_s:
                                    real_ans = plain_p[ans_s:ie_e]
                                    ans_e=ie_e
                            
                    except:
                        '''
                        '''
            if speical_flag==0:
                
                ans_s = plain_p.find(real_ans)
                ans_e = ans_s+len(real_ans)
                if ans_s == -1:
                    real_ans2 = real_ans.replace('與','、')
                    ans_s = plain_p.find(real_ans2)
                    ans_e = ans_s+len(real_ans2)
                if ans_s !=-1:
                    try:
                        if jieba.lcut(plain_q_last[plain_q_last.find('哪些')+2:])[0]=='国家' or plain_q_last.find('国家')!=-1:
                            real_ans2=''
                            for i in P_IE['NER']:
                                if i['type']=='GPE' or i['type']=='DEMONYM' or i['type']=='COUNTRY':
                                    ie_ner = i['string']
                                    ie_s = i['char_b']
                                    ie_e = i['char_e']
                                    #print(ie_ner,ie_s,ie_e)
                                    if ans_s>ie_s and ans_e>=ie_e and ie_e>=ans_s:
                                        real_ans2+=i['string']+'、'

                                    if ie_s>=ans_s and ans_e>=ie_e:
                                        real_ans2+=i['string']+'、'
                                    if ie_s>ans_s and ie_e>ans_e and ans_e>=ie_s:
                                        real_ans2+=i['string']+'、'
                            real_ans2=real_ans2[:-1]
                            if real_ans2=='':
                                real_ans=real_ans
                            else:
                                real_ans=real_ans2
                        else:
                            for i in P_IE['NER']:
                                ie_ner = i['string']
                                ie_s = i['char_b']
                                ie_e = i['char_e']
                                if ans_s>ie_s and ans_e>=ie_e and ie_e>=ans_s:
                                    real_ans = plain_p[ie_s:ans_e]
                                    ans_s=ie_s
                                if ie_s>ans_s and ie_e>ans_e and ans_e>=ie_s:
                                    real_ans = plain_p[ans_s:ie_e]
                                    ans_e=ie_e
                    except:
                        for i in P_IE['NER']:
                            ie_ner = i['string']
                            ie_s = i['char_b']
                            ie_e = i['char_e']
                            if ans_s>ie_s and ans_e>=ie_e and ie_e>=ans_s:
                                real_ans = plain_p[ie_s:ans_e]
                                ans_s=ie_s
                            if ie_s>ans_s and ie_e>ans_e and ans_e>=ie_s:
                                real_ans = plain_p[ans_s:ie_e]
                                ans_e=ie_e
                else:
                    real_ans=real_ans
            MSPE_json_temp = [
                {
                    "AMODULE": "Multi-Spans-Extraction",
                    "ATEXT": real_ans,
                    "score": 1.0, "start_score": 0.0, "end_score": 0.0
                }
            ]
            #print(MSPE_json)
            MSPE_json.append(MSPE_json_temp)
        return MSPE_json



if __name__ == "__main__":
    MSE_layer= Multi_Span_Extraction_Layer()


    fake_json = {"DTEXT": '苏轼(1037年1月8日-1101年8月24日),眉州眉山(今四川省眉山市)人,北宋时著名的文学家、政治家、艺术家、医学家。字子瞻,一字和仲,号东坡居士、铁冠道人。嘉佑二年进士,累官至端明殿学士兼翰林学士,礼部尚书。南宋理学方炽时,加赐谥号文忠,复追赠太师。有《东坡先生大全集》及《东坡乐府》词集传世,宋人王宗稷收其作品,编有《苏文忠公全集》。\n其散文、诗、词、赋均有成就,且善书法和绘画,是文学艺术史上的通才,也是公认韵文散文造诣皆比较杰出的大家。苏轼的散文为唐宋四家(韩愈、柳宗元、欧苏)之末,与唐代的古文运动发起者韩愈并称为「韩潮苏海」,也与欧阳修并称「欧苏」;更与父亲苏洵、弟苏辙合称「三苏」,父子三人,同列唐宋八大家。苏轼之诗与黄庭坚并称「苏黄」,又与陆游并称「苏陆」;其词「以诗入词」,首开词坛「豪放」一派,振作了晚唐、五代以来绮靡的西昆体余风。后世与南宋辛弃疾并称「苏辛」,惟苏轼故作豪放,其实清朗;其赋亦颇有名气,最知名者为贬谪期间借题发挥写的前后《赤壁赋》。宋代每逢科考常出现其文命题之考试,故当时学者曰:「苏文熟,吃羊肉、苏文生,嚼菜羹」。艺术方面,书法名列「苏、黄、米、蔡」北宋四大书法家(宋四家)之首;其画则开创了湖州画派;并在题画文学史上占有举足轻重的地位。',
                 "DIE":None,
                 "QUESTIONS":[{"QTEXT": "苏东坡曾担任过哪些职位?",
                               "QIE":None},
                              {"QTEXT": "苏东坡曾担任过哪些职位?",
                               "QIE":None}]}

    refined = MSE_layer.extract(fake_json)        
    print(refined)
-------------- next part --------------
A non-text attachment was scrubbed...
Name: 圖片1.png
Type: image/png
Size: 159331 bytes
Desc: not available
URL: <http://www.iis.sinica.edu.tw/pipermail/most-ai-contest/attachments/20200102/68b587e7/attachment-0001.png>


More information about the Most-ai-contest mailing list