【题解】Luogu P4407【[JSOI2009]电子字典】

背景

输入一组单词后,输入一组待查询字符串。对于每一个待查询字符串,若为单词输出$-1$,否则输出经过一次删除,或替换,或添加字符后,能得到的单词种数。

分析

由于这道题要快速检索字符串,因此想到用$hash$或者$trie$树,编者选用后者。
结合字典树,很容易想到一种方法,即暴力枚举所有变换能得到的字符串,然后在字典树中搜索,若存在$cnt++$.
但很明显存在重复经过某个前缀的操作。
对此我们可以用类似$dfs$的方法,自顶向下搜索待查询字符串,并略微修改一下查找函数$:$增加开始节点下标$P$,直接从$p$节点开始查找。
而对于搜到的每一个节点有三种分支:$1.$删除此处的字符,则从该节点开始查找是否存在与$s[i+1, n]$相同的字符串。$2.$替换该处字符,则从下一个节点$($下个节点由替换的字符有关$)$开始查找$s[i+1,n]$。$3.$增添一个字符,从下一个节点查找$s[i,n]$。

注意$:$

由于只能变换一次,节点分支后直接查询,不再分支。即只有原字符串经过的节点分支。

易错$:$

注意分支后查到的字符串可能相同,即经过不同的变换可能得到同一个单词,为了避免重复计算,应开一个额外的数组$vis$避免重复。

代码

#include<cstdio>
#include<cstring> 
using namespace std;
const int maxn = 200005;
int trie[maxn][27] , sum , cnt;
bool end[maxn]  , vis[maxn]; //防止 经过不同的变换 得到 同一个 单词 

inline void insert(char *s){
    int e = strlen(s) - 1,p = 0, i = 0;
    while(i <= e){
        int q = trie[p][s[i] - 'a'];
        if( q == 0)     q = ++sum , trie[p][s[i] - 'a'] = q;
        p = q , i++;
    }
    end[p] = 1;
}

inline bool search(int k ,char * s){ //k 开始节点 
    int e = strlen(s) - 1, p = k, i = 0;
    while( i <= e){
        int q = trie[p][s[i] - 'a'];
        if( q == 0)     return 0;
        else    p = q, i++;
    }
    if(end[p] && !vis[p]){
        vis[p] = 1;
        return 1;
    }
    else return 0;
}

inline int oysj(char * s){
    int e = strlen(s) - 1, p = 0 , i = 0;       cnt = 0; // -1
    while(i <= e){
        cnt += search( p , s + i + 1);// 删除字符s【i】 
        for(int j = 0;j <= 25;j++)// 在s【i】之前增添一个字符 
            if( trie[p][j] != 0)    cnt += search( trie[p][j] , s + i); 
        for(int j = 0;j <= 25;j++)//替换掉s【i】 
            if( trie[p][j] != 0)    cnt += search( trie[p][j] , s + i + 1);
        int q = trie[p][ s[i] - 'a'] ;
        if( q != 0)     p = q , i++;
        else break;
    }
    if( i == e+1)
    for(int j = 0;j <= 25;j++){ //分支时漏掉了末尾添加字符的情况 
        int next = trie[p][j];
        if( end[next] && (!vis[next]) )     cnt++ , vis[next] = 1;
    }
    return cnt;
}

int main(){
    int n , m;char s[25];
    scanf("%d%d",&n,&m);
    for(int i = 1;i <= n;i++)
        scanf("%s",s) , insert(s);
    for(int i = 1;i <= m;i++){
        scanf("%s",s);
        memset( vis , 0 ,sizeof(vis));
        int f = search(0 , s);
        if(f == 1)  printf("-1\n");
        else oysj(s) , printf("%d\n",cnt);
    }
    return 0;
}

加入对话

1条评论

电子邮件地址不会被公开。 必填项已用*标注