[Most-ai-contest] Multi-span簡易說明
闍怵羅
s2w81234於gmail.com
Thu 1月 2 16:09:18 CST 2020
各位好,
我是負責Multi-span Extraction的人-羅上堡。
本次的做法是全基於Rule-based去實作的,所以Step會有點多,可能有些也不是必要呈現的。
以下會附上大概的整體流程共分為16步驟,由於過久沒有寫類似這樣流程的東西,所以會附上一張極簡的流程圖,來補充說明他們之間的關係。
整體流程如下:
Step1:提取Passage與Question之文本和NER。
Step2:將P的特殊符號全部清除。
※《『?「》」』:~@#¥%……&*():]+...
Step3:將NER的Begin與End位置重算。
※因為Step2,Begin與End位置會有偏移錯誤。
Step4:創造BERT輸入矩陣:[CLS]*Q*[SEP]*P*[SEP]。
Step5:提取Question的最後一句。
Step6:依照Step5的結果,提取關鍵字眼以獲得應回答幾個答案,如果沒有則視為非指定數量題目。
Step7:將Step4的矩陣丟給BERT產生出結果。
Step8:依照Step7的結果產生top-k的Begin與End。
Step9:去top-k裡面尋找答案,同時檢查是否超過20的長度,如果超過則繼續取下一個top-k+1的結果,直到數量滿足或是沒有候選答案為止。
Step10:依照Step9所選出的所有答案,進行內含(Within)與交疊(Overlap)的答案處理。
※Within condition
Answer1: 今天是總統大選
Answer2:是總統大
Result:今天是總統大選
※Overlap condition
Answer1: 今天是總統大選
Answer2:總統大選的日子
Result:總統大選
※由於這部分的code有莫名不好處理的地方,所以在此琢磨的地方比較久。
Step11:尋找候選答案裡面,是否有『、』字眼,如果有執行Step12;如果沒有則執行Step13。
Step12-1:尋找擁有『、』字眼的答案,往後擴充到句號,並依照jieba的斷詞結果,來取得、後面的答案。
Step12-2:如果遇到『等』字眼時,需要再往後延伸找到等後面的字詞,來延伸擴充答案。
Step13:再度檢查Within與Overlap的情況判斷,如果有發生,則執行類Step10的結果判斷後,執行Step15;如果沒有,則進行
Step14。
Step14:由於沒有遇到特殊情況,會將每個篩選後的答案,進行最簡單的Rule串接。
Step15:將選出來的答案,透過NER的資訊,去將有包含到該NER的部分字元全部擴充回來,讓答案更加完整。
※Example condition
Answer: 統大選
NER:總統
Result:總統大選
Step16:輸出最終結果。
謝謝。
-------------- next part --------------
An HTML attachment was scrubbed...
URL: <http://www.iis.sinica.edu.tw/pipermail/most-ai-contest/attachments/20200102/68b587e7/attachment-0001.html>
-------------- next part --------------
import warnings
warnings.filterwarnings("ignore")
import os,re
import json,jieba
import numpy as np
import torch
from tqdm.autonotebook import tqdm
from itertools import combinations
from .bao_bert_lin import BertTokenizer,BertForQuestionAnswering
class Multi_Span_Extraction_Layer():
    """Rule-based multi-span answer extraction on top of a BERT QA model.

    Per question the pipeline is: clean the passage text, realign the
    passage NER character offsets to the cleaned text, run BERT to obtain
    top-k start/end positions, then apply a cascade of hand-written rules
    (length filtering, within/overlap span merging, enumeration handling
    around '、', and NER-based answer expansion) to assemble one answer
    string that may cover several spans.

    NOTE(review): this file was recovered from a mailing-list archive that
    stripped indentation; the nesting below is a faithful reconstruction
    but should be confirmed against the original source.
    """

    def __init__(self):
        # Model/tokenizer files are expected to live in the same directory
        # as this module.
        path = os.path.dirname(os.path.abspath(__file__))
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # .eval() disables dropout; the model is used for inference only.
        self.BERT_model = BertForQuestionAnswering.from_pretrained(path).to(self.device).eval()
        self.BERT_tokenizer = BertTokenizer.from_pretrained(path)

    def trunc_pair(self, A, B, max_length=512):
        """Truncate list ``B`` in place until len(A)+len(B)+3 <= max_length.

        The ``+3`` accounts for [CLS] and the two [SEP] special tokens.
        Only B (the passage) is shortened from the tail; A (the question)
        is never touched.
        """
        while True:
            total_length = len(A) + len(B)+3
            if total_length <= max_length:
                break
            B.pop()

    def create_data_PQ(self, P_token, Q_token, max_length=512):
        """Build the BERT input tensors for one (passage, question) pair.

        Layout: [CLS] Q [SEP] P [SEP]; segment ids are 0 for the question
        part (incl. [CLS]/first [SEP]) and 1 for the passage part.
        Returns (list of 1 x seq_len long tensors [ids, mask, segments]
        on self.device, token list).

        NOTE(review): ``padding`` is computed but never appended to the
        inputs, so sequences are NOT padded to max_length — presumably the
        model accepts variable lengths; confirm.
        NOTE(review): the project tokenizer's convert_tokens_to_ids
        apparently returns (ids, tokens), unlike stock HuggingFace.
        """
        Q_t = [x for x in Q_token[:]]
        P_t = [x for x in P_token[:]]
        # Hard-coded 512 here, not max_length — kept as-is.
        self.trunc_pair(Q_t, P_t, 512)
        token = ["[CLS]"] + [i for i in Q_t] + ["[SEP]"] + [i for i in P_t] + ["[SEP]"]
        seg_ids = [0] * (len(Q_t) + 2) + [1] * (len(P_t) + 1)
        inp_ids, token = self.BERT_tokenizer.convert_tokens_to_ids(token)
        inp_mask = [1] * len(inp_ids)
        padding = [0] * (max_length - len(inp_ids))  # unused — see note above
        data_feature = [inp_ids, inp_mask, seg_ids]
        data_f = [torch.tensor(i, dtype=torch.long).unsqueeze(0).to(self.device) for i in data_feature]
        return data_f, token

    def remove_blank(self, input_str):
        """Normalize passage text: strip whitespace, map clause-ending
        punctuation (both half- and full-width , ! ? ;) to '。', then strip
        remaining special symbols via regex.

        NOTE(review): this changes character positions, which is why NER
        offsets are recomputed afterwards in extract().
        """
        input_str = input_str.lstrip()
        input_str = input_str.rstrip()
        input_str = input_str.replace(',', '。')
        input_str = input_str.replace(',', '。')
        input_str = input_str.replace('!', '。')
        input_str = input_str.replace('!', '。')
        input_str = input_str.replace('?', '。')
        input_str = input_str.replace('?', '。')
        input_str = input_str.replace(';', '。')
        input_str = input_str.replace(';', '。')
        input_str = re.sub("[\s+\.\!\/_$%^*(+\"\']+|[+——《『?「》」』:~@#¥%……&*():]+", '', input_str)
        return input_str

    def remove_blank_Q(self, input_str):
        """Strip leading/trailing whitespace from a question string only."""
        input_str = input_str.lstrip()
        input_str = input_str.rstrip()
        return input_str

    def check_within_overrap(self, a1_s, a1_e, a2_s, a2_e):
        """Classify the relationship of two [start, end] spans.

        Returns (within, overlap) codes ("overrap" is a historical typo
        for "overlap" kept for interface compatibility):
          within:  1 = span1 contains span2, 2 = span2 contains span1,
                   3 = identical spans, 0 = neither.
          overlap: 1 = span1 starts inside span2 (span2 first),
                   2 = span2 starts inside span1 (span1 first), 0 = none.
        """
        Rule2_flag = [0, 0]
        # Check within
        if a1_s <= a2_s and a1_e >= a2_e:
            Rule2_flag[0] = 1
        if a2_s <= a1_s and a2_e >= a1_e:
            Rule2_flag[0] = 2
        if a1_s == a2_s and a1_e == a2_e:
            Rule2_flag[0] = 3
        # Check overlap (partial intersection, neither contains the other)
        if a1_s > a2_s and a1_e > a2_e and a2_e > a1_s:
            Rule2_flag[1] = 1
        if a2_s > a1_s and a2_e > a1_e and a1_e > a2_s:
            Rule2_flag[1] = 2
        return Rule2_flag[0], Rule2_flag[1]

    def extract(self, data):
        """Run the full rule-based multi-span extraction.

        ``data`` is a dict with keys:
          DTEXT     : passage text,
          DIE       : passage IE result with DIE['NER'] entries having
                      'string', 'char_b', 'char_e' (and 'type'), or None,
          QUESTIONS : list of {QTEXT, QIE}.

        Returns one single-element list per question, each containing a
        dict {AMODULE, ATEXT, score, start_score, end_score}.

        NOTE(review): bare ``except:`` clauses are used throughout as
        deliberate best-effort fallbacks; they also hide real errors.
        """
        Text_P = data["DTEXT"]
        P_IE = data["DIE"]
        MSPE_json = []
        for Data_Q in data["QUESTIONS"]:
            Text_Q = Data_Q["QTEXT"]
            Q_IE = Data_Q['QIE']
            plain_p = self.remove_blank(Text_P)
            # speical_flag (sic) = 1 means NER info is unusable; all
            # NER-based post-processing is then skipped.
            speical_flag = 0
            try:
                # Recompute each NER mention's offsets in the CLEANED
                # passage, scanning forward from the previous match so
                # repeated mentions resolve to successive positions.
                base_s = 0
                base_e = 0
                for i in P_IE['NER']:
                    ie_ner = i['string']
                    ie_s = i['char_b']
                    ie_e = i['char_e']
                    re_s = plain_p[base_e:].find(i['string']) + base_e
                    re_e = plain_p[base_e:].find(i['string']) + len(i['string']) + base_e
                    base_s = re_s
                    base_e = re_e
                    i['char_b'] = re_s
                    i['char_e'] = re_e
            except:
                speical_flag = 1
            # Drop the trailing question mark (full- then half-width).
            plain_q = Text_Q.split('?')[0]
            plain_q = plain_q.split('?')[0]
            data_f, token = self.create_data_PQ(plain_p, plain_q, 512)
            # Offset of the passage inside the BERT input ([CLS] + Q + [SEP]).
            # NOTE(review): uses CHARACTER length of the question as token
            # length — only valid if 1 char == 1 token; confirm tokenizer.
            q_len = len(plain_q) + 2
            # Last clause of the question carries the answer-count cue.
            plain_q_last = plain_q.split(',')[-1]
            counter_answer = 2
            if plain_q_last.find('哪些') != -1:
                # "which ones" — open-ended count; take a large top-k.
                counter_answer = 50
            elif plain_q_last.find('两') != -1:
                counter_answer = 2
            elif plain_q_last.find('三') != -1:
                counter_answer = 3
            elif plain_q_last.find('六') != -1:
                counter_answer = 6
            elif plain_q_last.find('五') != -1:
                counter_answer = 5
            elif plain_q_last.find('四') != -1:
                counter_answer = 4
            ## Generate Top-k result
            # Argument order is (ids, segment_ids, mask) per the project
            # model's signature; s/e are start/end scores over positions.
            s, e = self.BERT_model(data_f[0], data_f[2], data_f[1])
            s = s[:, q_len:]
            e = e[:, q_len:]
            # torch.topk returns (values, indices); indices at [1][0].
            arg_s = torch.topk(s, counter_answer)
            arg_e = torch.topk(e, counter_answer)
            ## refix end > start condition: swap inverted span boundaries
            for idx in range(counter_answer):
                if arg_s[1][0][idx] > arg_e[1][0][idx]:
                    temp_s = int(arg_s[1][0][idx])
                    temp_e = int(arg_e[1][0][idx])
                    arg_s[1][0][idx] = temp_e
                    arg_e[1][0][idx] = temp_s
            ## Stage1_preprocessing_span_answer
            # Walk the top-k spans, keep ones of length <= 20, compact the
            # kept spans to the front of arg_s/arg_e (index pre_num).
            ans_pool = []
            temp_ans = []   # NOTE(review): unused
            pre_num = 0
            err_num = 0     # NOTE(review): unused
            for idx, (ss, ee) in enumerate(zip(arg_s[1][0], arg_e[1][0])):
                if ss > ee:
                    continue
                if ee > ss + 20:
                    # Over-long span: for open-ended questions freeze the
                    # count at what we have; otherwise keep scanning until
                    # the requested count is met.
                    if plain_q_last.find('哪些') != -1:
                        counter_answer = pre_num
                        if pre_num < 2:
                            continue
                        else:
                            break
                    else:
                        if pre_num < counter_answer:
                            continue
                        else:
                            break
                else:
                    arg_s[1][0][pre_num] = int(arg_s[1][0][idx])
                    arg_e[1][0][pre_num] = int(arg_e[1][0][idx])
                    pre_num += 1
                    # NOTE(review): ss/ee are positions in the logits after
                    # the q_len slice, used directly as char offsets into
                    # plain_p — again assumes 1 char == 1 token.
                    ans = plain_p[ss:ee + 1]
                    ans_pool.append(ans)
            ## Rule_2_2_span_within_span
            # With >2 candidates: drop the contained span of the first
            # within-pair found, shifting the compacted arg indices left.
            flag = 0
            if len(ans_pool) > 2:
                list_pair = list(combinations(range(len(ans_pool)), 2))
                temp_within = []
                temp_overrap = []
                for i, j in list_pair:
                    fw, fo = self.check_within_overrap(arg_s[1][0][i], arg_e[1][0][i], arg_s[1][0][j], arg_e[1][0][j])
                    temp_within.append(fw)
                    temp_overrap.append(fo)
                for iidx, within_flag in enumerate(temp_within):
                    if within_flag == 1 or within_flag == 3:
                        # Span i contains span j (or identical): remove j.
                        new_ans_pool = []
                        flag = 0
                        for iiidx, ans_pooll in enumerate(ans_pool):
                            if iiidx == list_pair[iidx][1]:
                                flag = 1
                            else:
                                new_ans_pool.append(ans_pooll)
                            if flag == 1:
                                try:
                                    arg_s[1][0][iiidx] = int(arg_s[1][0][iiidx + 1])
                                    arg_e[1][0][iiidx] = int(arg_e[1][0][iiidx + 1])
                                except:
                                    # out of range at the tail — ignore
                                    '''
                                    '''
                        ans_pool = new_ans_pool
                        break
                    elif within_flag == 2:
                        # Span j contains span i: remove i.
                        new_ans_pool = []
                        flag = 0
                        for iiidx, ans_pooll in enumerate(ans_pool):
                            if iiidx == list_pair[iidx][0]:
                                flag = 1
                            else:
                                new_ans_pool.append(ans_pooll)
                            if flag == 1:
                                try:
                                    arg_s[1][0][iiidx] = int(arg_s[1][0][iiidx + 1])
                                    arg_e[1][0][iiidx] = int(arg_e[1][0][iiidx + 1])
                                except:
                                    # out of range at the tail — ignore
                                    '''
                                    '''
                        ans_pool = new_ans_pool
                        break
                    else:
                        # no containment for this pair — nothing to do
                        '''
                        '''
            # With exactly two candidates, classify their relationship.
            if len(ans_pool) == 2:
                Rule2_flag = self.check_within_overrap(arg_s[1][0][0], arg_e[1][0][0], arg_s[1][0][1], arg_e[1][0][1])
            else:
                Rule2_flag = [0, 0]
            ## Rule_1_only_one_span_with、Stopword
            # Exactly one candidate contains the enumeration mark '、':
            # expand that candidate to the sentence end and re-join its
            # parts with '與' before the last item.
            check_state = [x.find('、') for x in ans_pool]
            check_state = [1 if num != -1 else 0 for idx, num in enumerate(check_state)]
            if sum(check_state) == 1:
                find_location = check_state.index(1)
                chose_answer = ans_pool[find_location]
                ca_sentence_s = arg_s[1][0][find_location]
                ca_sentence_e = arg_e[1][0][find_location]
                if ca_sentence_s > ca_sentence_e:
                    ca_sentence_s, ca_sentence_e = ca_sentence_e, ca_sentence_s
                ca_sentence = plain_p[ca_sentence_s:].split('。')[0]
                try:
                    # Tail of the sentence from '等' ("etc.") onwards …
                    ca_next = ca_sentence[ca_sentence.index('等'):]
                except:
                    # … or the remainder after the chosen answer.
                    ca_next = ca_sentence[len(chose_answer):]
                ## Extra_rule_1
                # If the other candidate ends just before this one starts,
                # merge the two into a single contiguous span.
                if len(ans_pool) == 2:
                    other_s = arg_s[1][0][find_location ^ 1]
                    other_e = arg_e[1][0][find_location ^ 1]
                    if abs(other_e - ca_sentence_s) <= 5:
                        chose_answer = plain_p[other_s:ca_sentence_e]
                real_ans_pool = ''
                for idx, split_anwer in enumerate(chose_answer.split('、')):
                    if idx == len(chose_answer.split('、')) - 2:
                        real_ans_pool += split_anwer + '與'
                    elif idx != len(chose_answer.split('、')) - 1:
                        real_ans_pool += split_anwer + '、'
                    else:
                        real_ans_pool += split_anwer
                # Extend past '等' using jieba's first two tokens.
                ca_jieba = jieba.lcut(ca_next)
                if ca_jieba[0] == '等':
                    real_ans = real_ans_pool + ca_jieba[0] + ca_jieba[1]
                    if real_ans_pool.find('等') != -1:
                        real_ans = real_ans_pool[:real_ans_pool.index('等')] + ca_jieba[0] + ca_jieba[1]
                else:
                    real_ans = real_ans_pool
            elif Rule2_flag[0] != 0:
                # Containment between the two spans: keep the container.
                if Rule2_flag[0] == 1 or Rule2_flag[0] == 3:
                    with_s = arg_s[1][0][0]
                    with_e = arg_e[1][0][0]
                elif Rule2_flag[0] == 2:
                    with_s = arg_s[1][0][1]
                    with_e = arg_e[1][0][1]
                real_ans = plain_p[with_s:with_e + 1]
                # Two single characters: join as "X與Y".
                if len(real_ans) == 2:
                    real_ans = real_ans[0] + '與' + real_ans[1]
            elif Rule2_flag[1] != 0:
                # Partial overlap: take the union of the two spans.
                if Rule2_flag[1] == 1:
                    over_s = arg_s[1][0][0]
                    over_e = arg_e[1][0][1]
                else:
                    over_s = arg_s[1][0][1]
                    over_e = arg_e[1][0][0]
                real_ans = plain_p[over_s:over_e + 1]
            else:
                # Default: deduplicate and concatenate candidates with
                # 和 / 及 / 、 connectors.
                real_ans = ""
                if len(ans_pool) == 3:
                    if ans_pool[0] == ans_pool[1]:
                        ans_pool = [ans_pool[0], ans_pool[2]]
                    elif ans_pool[0] == ans_pool[2]:
                        ans_pool = [ans_pool[0], ans_pool[1]]
                    elif ans_pool[1] == ans_pool[2]:
                        ans_pool = [ans_pool[0], ans_pool[1]]
                if len(ans_pool) == 2:
                    if ans_pool[0] == ans_pool[1]:
                        ans_pool = [ans_pool[0]]
                for idx, ans_p in enumerate(ans_pool):
                    if len(ans_pool) == 2:
                        if idx == 0:
                            real_ans += ans_p + '和'
                        else:
                            real_ans += ans_p
                    else:
                        if idx == len(ans_pool) - 2:
                            real_ans += ans_p + '及'
                        elif idx != len(ans_pool) - 1:
                            real_ans += ans_p + '、'
                        else:
                            real_ans += ans_p
                # Does the top span start an enumeration ("X、…") in the
                # passage? If so, rebuild the answer from that sentence.
                find_top = plain_p.find(plain_p[arg_s[1][0][0]:arg_e[1][0][0] + 1] + '、')
                top_sentence = plain_p[arg_e[1][0][0] + 1:].split('。')[0]
                top_jieba = jieba.lcut(top_sentence)
                try:
                    if top_jieba[0] == '与' or top_jieba[0] == '及':
                        real_ans = plain_p[arg_s[1][0][0]:arg_e[1][0][0] + 1] + top_jieba[0] + top_jieba[1]
                        # NOTE(review): temp_within only exists when
                        # len(ans_pool)>2 ran above; a NameError here is
                        # swallowed by this except.
                        if sum(temp_within) > 4:
                            real_ans = plain_p[arg_s[1][0][0]:arg_e[1][0][0] + 1]
                except:
                    '''
                    '''
                flaggggg = 0
                if find_top != -1:
                    # Sentence containing the enumeration, plus the clause
                    # before it.
                    f_t_sentence = plain_p[find_top:].split('。')[0]
                    f_t_e = plain_p.find(plain_p[find_top:].split('。')[0])
                    f_t_s = plain_p[:plain_p.find(plain_p[find_top:].split('。')[0])].split('。')[-1]
                    f_sentence = f_t_s + f_t_sentence
                    jieba_f_s = jieba.lcut(f_sentence)
                    if f_sentence.find('、') != -1:
                        flaggggg = 1
                        # Collect all '、' token positions and take the
                        # words from just before the first to just after
                        # the last as the enumeration span.
                        temp_idx = []
                        for idx, aa in enumerate(jieba_f_s):
                            if aa == '、':
                                temp_idx.append(idx)
                        real_ans = ''.join(jieba_f_s[temp_idx[0] - 1:temp_idx[-1] + 2])
                        try:
                            # Trim/extend around '等' (etc.) / '和' / '以及'
                            # tokens following the enumeration.
                            temp2 = []
                            temp3 = []
                            for iidx, temp_data in enumerate(jieba_f_s[temp_idx[-1]:]):
                                #print(temp_data)
                                if temp_data == '等':
                                    temp2 = iidx
                                if temp_data == '和':
                                    temp3 = iidx
                            if temp2 != []:
                                real_ans = ''.join(jieba_f_s[temp_idx[0] - 1:temp_idx[-1] + temp2])
                            else:
                                if temp3 != []:
                                    real_ans = ''.join(jieba_f_s[temp_idx[0] - 1:temp_idx[-1] + temp3 + 2])
                                if jieba_f_s[temp_idx[-1] + 1] == '以及':
                                    real_ans = ''.join(jieba_f_s[temp_idx[0] - 1:temp_idx[-1]])
                            # Expand to NER boundaries that straddle the
                            # answer's edges (Step15 of the write-up).
                            ans_s = plain_p.find(real_ans)
                            ans_e = ans_s + len(real_ans)
                            if speical_flag == 0:
                                for i in P_IE['NER']:
                                    ie_ner = i['string']
                                    ie_s = i['char_b']
                                    ie_e = i['char_e']
                                    if ans_s > ie_s and ans_e >= ie_e and ie_e >= ans_s:
                                        real_ans = plain_p[ie_s:ans_e]
                                        ans_s = ie_s
                                    if ie_s > ans_s and ie_e > ans_e and ans_e >= ie_s:
                                        real_ans = plain_p[ans_s:ie_e]
                                        ans_e = ie_e
                        except:
                            '''
                            '''
            # Final NER-based expansion/filtering, applied to whichever
            # rule produced real_ans above.
            if speical_flag == 0:
                ans_s = plain_p.find(real_ans)
                ans_e = ans_s + len(real_ans)
                if ans_s == -1:
                    # The joined answer may not occur verbatim; retry with
                    # '與' mapped back to '、'.
                    real_ans2 = real_ans.replace('與', '、')
                    ans_s = plain_p.find(real_ans2)
                    ans_e = ans_s + len(real_ans2)
                if ans_s != -1:
                    try:
                        if jieba.lcut(plain_q_last[plain_q_last.find('哪些') + 2:])[0] == '国家' or plain_q_last.find('国家') != -1:
                            # "which countries": rebuild the answer purely
                            # from geo-political NER mentions inside/around
                            # the answer span.
                            real_ans2 = ''
                            for i in P_IE['NER']:
                                if i['type'] == 'GPE' or i['type'] == 'DEMONYM' or i['type'] == 'COUNTRY':
                                    ie_ner = i['string']
                                    ie_s = i['char_b']
                                    ie_e = i['char_e']
                                    #print(ie_ner,ie_s,ie_e)
                                    if ans_s > ie_s and ans_e >= ie_e and ie_e >= ans_s:
                                        real_ans2 += i['string'] + '、'
                                    if ie_s >= ans_s and ans_e >= ie_e:
                                        real_ans2 += i['string'] + '、'
                                    if ie_s > ans_s and ie_e > ans_e and ans_e >= ie_s:
                                        real_ans2 += i['string'] + '、'
                            real_ans2 = real_ans2[:-1]  # drop trailing '、'
                            if real_ans2 == '':
                                real_ans = real_ans
                            else:
                                real_ans = real_ans2
                        else:
                            # Generic case: widen the answer to cover any
                            # NER mention cut at either edge.
                            for i in P_IE['NER']:
                                ie_ner = i['string']
                                ie_s = i['char_b']
                                ie_e = i['char_e']
                                if ans_s > ie_s and ans_e >= ie_e and ie_e >= ans_s:
                                    real_ans = plain_p[ie_s:ans_e]
                                    ans_s = ie_s
                                if ie_s > ans_s and ie_e > ans_e and ans_e >= ie_s:
                                    real_ans = plain_p[ans_s:ie_e]
                                    ans_e = ie_e
                    except:
                        # Fallback: same edge-widening pass as above.
                        for i in P_IE['NER']:
                            ie_ner = i['string']
                            ie_s = i['char_b']
                            ie_e = i['char_e']
                            if ans_s > ie_s and ans_e >= ie_e and ie_e >= ans_s:
                                real_ans = plain_p[ie_s:ans_e]
                                ans_s = ie_s
                            if ie_s > ans_s and ie_e > ans_e and ans_e >= ie_s:
                                real_ans = plain_p[ans_s:ie_e]
                                ans_e = ie_e
            else:
                # No usable NER info: keep the rule-based answer as-is.
                real_ans = real_ans
            MSPE_json_temp = [
                {
                    "AMODULE": "Multi-Spans-Extraction",
                    "ATEXT": real_ans,
                    "score": 1.0, "start_score": 0.0, "end_score": 0.0
                }
            ]
            #print(MSPE_json)
            MSPE_json.append(MSPE_json_temp)
        return MSPE_json
if __name__ == "__main__":
    # Smoke test: run the extractor on a fixed passage about Su Shi with
    # two identical "which positions" questions. DIE/QIE are None, so the
    # NER realignment raises inside extract() and the NER-based expansion
    # steps are skipped (speical_flag path).
    MSE_layer = Multi_Span_Extraction_Layer()
    fake_json = {"DTEXT": '苏轼(1037年1月8日-1101年8月24日),眉州眉山(今四川省眉山市)人,北宋时著名的文学家、政治家、艺术家、医学家。字子瞻,一字和仲,号东坡居士、铁冠道人。嘉佑二年进士,累官至端明殿学士兼翰林学士,礼部尚书。南宋理学方炽时,加赐谥号文忠,复追赠太师。有《东坡先生大全集》及《东坡乐府》词集传世,宋人王宗稷收其作品,编有《苏文忠公全集》。\n其散文、诗、词、赋均有成就,且善书法和绘画,是文学艺术史上的通才,也是公认韵文散文造诣皆比较杰出的大家。苏轼的散文为唐宋四家(韩愈、柳宗元、欧苏)之末,与唐代的古文运动发起者韩愈并称为「韩潮苏海」,也与欧阳修并称「欧苏」;更与父亲苏洵、弟苏辙合称「三苏」,父子三人,同列唐宋八大家。苏轼之诗与黄庭坚并称「苏黄」,又与陆游并称「苏陆」;其词「以诗入词」,首开词坛「豪放」一派,振作了晚唐、五代以来绮靡的西昆体余风。后世与南宋辛弃疾并称「苏辛」,惟苏轼故作豪放,其实清朗;其赋亦颇有名气,最知名者为贬谪期间借题发挥写的前后《赤壁赋》。宋代每逢科考常出现其文命题之考试,故当时学者曰:「苏文熟,吃羊肉、苏文生,嚼菜羹」。艺术方面,书法名列「苏、黄、米、蔡」北宋四大书法家(宋四家)之首;其画则开创了湖州画派;并在题画文学史上占有举足轻重的地位。',
                 "DIE": None,
                 "QUESTIONS": [{"QTEXT": "苏东坡曾担任过哪些职位?",
                                "QIE": None},
                               {"QTEXT": "苏东坡曾担任过哪些职位?",
                                "QIE": None}]}
    refined = MSE_layer.extract(fake_json)
    print(refined)
-------------- next part --------------
A non-text attachment was scrubbed...
Name: 圖片1.png
Type: image/png
Size: 159331 bytes
Desc: not available
URL: <http://www.iis.sinica.edu.tw/pipermail/most-ai-contest/attachments/20200102/68b587e7/attachment-0001.png>
More information about the Most-ai-contest
mailing list