c/c++实现DoubleArrayTrie

求c实现DoubleArrayTrie 函数包括构建 添加 存储进中间文件 加载中间文件 以及字符串匹配

c/c++实现DoubleArrayTrie


import  java.io.BufferedInputStream;
import  java.io.BufferedOutputStream;
import  java.io.DataInputStream;
import  java.io.DataOutputStream;
import  java.io.File;
import  java.io.FileInputStream;
import  java.io.FileOutputStream;
import  java.io.IOException;
import  java.util.ArrayList;
import  java.util.Collections;
import  java.util.List;
 
/**
  * DoubleArrayTrie在构建双数组的过程中也借助于一棵传统的Trie树,但这棵Trie树并没有被保存下来,
  * 如果要查找以prefix为前缀的所有词不适合用DoubleArrayTrie,应该用传统的Trie树。
  *
  * @author zhangchaoyang
  *
  */
public  class  DoubleArrayTrie {
     private  final  static  int  BUF_SIZE =  16384 ; // 2^14,java采用unicode编码表示所有字符,每个字符固定用两个字节表示。考虑到每个字节的符号位都是0,所以又可以节省两个bit
     private  final  static  int  UNIT_SIZE =  8 ;  // size of int + int
 
     private  static  class  Node {
         int  code; // 字符的unicode编码
         int  depth; // 在Trie树中的深度
         int  left; //
         int  right; //
     };
 
     private  int  check[];
     private  int  base[];
 
     private  boolean  used[];
     private  int  size;
     private  int  allocSize; // base数组当前的长度
     private  List<String> key; // 所有的词
     private  int  keySize;
     private  int  length[];
     private  int  value[];
     private  int  progress;
     private  int  nextCheckPos;
     int  error_;
 
     // 扩充base和check数组
     private  int  resize( int  newSize) {
         int [] base2 =  new  int [newSize];
         int [] check2 =  new  int [newSize];
         boolean  used2[] =  new  boolean [newSize];
         if  (allocSize >  0 ) {
             System.arraycopy(base,  0 , base2,  0 , allocSize); // 如果allocSize超过了base2的长度,会抛出异常
             System.arraycopy(check,  0 , check2,  0 , allocSize);
             System.arraycopy(used,  0 , used2,  0 , allocSize);
         }
 
         base = base2;
         check = check2;
         used = used2;
 
         return  allocSize = newSize;
     }
 
     private  int  fetch(Node parent, List<Node> siblings) {
         if  (error_ <  0 )
             return  0 ;
 
         int  prev =  0 ;
 
         for  ( int  i = parent.left; i < parent.right; i++) {
             if  ((length !=  null  ? length[i] : key.get(i).length()) < parent.depth)
                 continue ;
 
             String tmp = key.get(i);
 
             int  cur =  0 ;
             if  ((length !=  null  ? length[i] : tmp.length()) != parent.depth)
                 cur = ( int ) tmp.charAt(parent.depth) +  1 ;
 
             if  (prev > cur) {
                 error_ = - 3 ;
                 return  0 ;
             }
 
             if  (cur != prev || siblings.size() ==  0 ) {
                 Node tmp_node =  new  Node();
                 tmp_node.depth = parent.depth +  1 ;
                 tmp_node.code = cur;
                 tmp_node.left = i;
                 if  (siblings.size() !=  0 )
                     siblings.get(siblings.size() -  1 ).right = i;
 
                 siblings.add(tmp_node);
             }
 
             prev = cur;
         }
 
         if  (siblings.size() !=  0 )
             siblings.get(siblings.size() -  1 ).right = parent.right;
 
         return  siblings.size();
     }
 
     private  int  insert(List<Node> siblings) {
         if  (error_ <  0 )
             return  0 ;
 
         int  begin =  0 ;
         int  pos = ((siblings.get( 0 ).code +  1  > nextCheckPos) ? siblings.get( 0 ).code +  1
                 : nextCheckPos) -  1 ;
         int  nonzero_num =  0 ;
         int  first =  0 ;
 
         if  (allocSize <= pos)
             resize(pos +  1 );
 
         outer:  while  ( true ) {
             pos++;
 
             if  (allocSize <= pos)
                 resize(pos +  1 );
 
             if  (check[pos] !=  0 ) {
                 nonzero_num++;
                 continue ;
             }  else  if  (first ==  0 ) {
                 nextCheckPos = pos;
                 first =  1 ;
             }
 
             begin = pos - siblings.get( 0 ).code;
             if  (allocSize <= (begin + siblings.get(siblings.size() -  1 ).code)) {
                 // progress can be zero
                 double  l = ( 1.05  >  1.0  * keySize / (progress +  1 )) ?  1.05  :  1.0
                         * keySize / (progress +  1 );
                 resize(( int ) (allocSize * l));
             }
 
             if  (used[begin])
                 continue ;
 
             for  ( int  i =  1 ; i < siblings.size(); i++)
                 if  (check[begin + siblings.get(i).code] !=  0 )
                     continue  outer;
 
             break ;
         }
 
         // -- Simple heuristics --
         // if the percentage of non-empty contents in check between the
         // index
         // 'next_check_pos' and 'check' is greater than some constant value
         // (e.g. 0.9),
         // new 'next_check_pos' index is written by 'check'.
         if  ( 1.0  * nonzero_num / (pos - nextCheckPos +  1 ) >=  0.95 )
             nextCheckPos = pos;
 
         used[begin] =  true ;
         size = (size > begin + siblings.get(siblings.size() -  1 ).code +  1 ) ? size
                 : begin + siblings.get(siblings.size() -  1 ).code +  1 ;
 
         for  ( int  i =  0 ; i < siblings.size(); i++)
             check[begin + siblings.get(i).code] = begin;
 
         for  ( int  i =  0 ; i < siblings.size(); i++) {
             List<Node> new_siblings =  new  ArrayList<Node>();
 
             if  (fetch(siblings.get(i), new_siblings) ==  0 ) {
                 base[begin + siblings.get(i).code] = (value !=  null ) ? (-value[siblings
                         .get(i).left] -  1 ) : (-siblings.get(i).left -  1 );
 
                 if  (value !=  null  && (-value[siblings.get(i).left] -  1 ) >=  0 ) {
                     error_ = - 2 ;
                     return  0 ;
                 }
 
                 progress++;
                 // if (progress_func_) (*progress_func_) (progress,
                 // keySize);
             }  else  {
                 int  h = insert(new_siblings);
                 base[begin + siblings.get(i).code] = h;
             }
         }
         return  begin;
     }
 
     public  DoubleArrayTrie() {
         check =  null ;
         base =  null ;
         used =  null ;
         size =  0 ;
         allocSize =  0 ;
         // no_delete_ = false;
         error_ =  0 ;
     }
 
     // no deconstructor
 
     // set_result omitted
     // the search methods returns (the list of) the value(s) instead
     // of (the list of) the pair(s) of value(s) and length(s)
 
     // set_array omitted
     // array omitted
 
     void  clear() {
         // if (! no_delete_)
         check =  null ;
         base =  null ;
         used =  null ;
         allocSize =  0 ;
         size =  0 ;
         // no_delete_ = false;
     }
 
     public  int  getUnitSize() {
         return  UNIT_SIZE;
     }
 
     public  int  getSize() {
         return  size;
     }
 
     public  int  getTotalSize() {
         return  size * UNIT_SIZE;
     }
 
     public  int  getNonzeroSize() {
         int  result =  0 ;
         for  ( int  i =  0 ; i < size; i++)
             if  (check[i] !=  0 )
                 result++;
         return  result;
     }
 
     public  int  build(List<String> key) {
         return  build(key,  null ,  null , key.size());
     }
 
     public  int  build(List<String> _key,  int  _length[],  int  _value[],
             int  _keySize) {
         if  (_keySize > _key.size() || _key ==  null )
             return  0 ;
 
         // progress_func_ = progress_func;
         key = _key;
         length = _length;
         keySize = _keySize;
         value = _value;
         progress =  0 ;
 
         resize( 65536  *  32 );
 
         base[ 0 ] =  1 ;
         nextCheckPos =  0 ;
 
         Node root_node =  new  Node();
         root_node.left =  0 ;
         root_node.right = keySize;
         root_node.depth =  0 ;
 
         List<Node> siblings =  new  ArrayList<Node>();
         fetch(root_node, siblings);
         insert(siblings);
 
         // size += (1 << 8 * 2) + 1; // ???
         // if (size >= allocSize) resize (size);
 
         used =  null ;
         key =  null ;
 
         return  error_;
     }
 
     public  void  open(String fileName)  throws  IOException {
         File file =  new  File(fileName);
         size = ( int ) file.length() / UNIT_SIZE;
         check =  new  int [size];
         base =  new  int [size];
 
         DataInputStream is =  null ;
         try  {
             is =  new  DataInputStream( new  BufferedInputStream(
                     new  FileInputStream(file), BUF_SIZE));
             for  ( int  i =  0 ; i < size; i++) {
                 base[i] = is.readInt();
                 check[i] = is.readInt();
             }
         }  finally  {
             if  (is !=  null )
                 is.close();
         }
     }
 
     public  void  save(String fileName)  throws  IOException {
         DataOutputStream out =  null ;
         try  {
             out =  new  DataOutputStream( new  BufferedOutputStream(
                     new  FileOutputStream(fileName)));
             for  ( int  i =  0 ; i < size; i++) {
                 out.writeInt(base[i]);
                 out.writeInt(check[i]);
             }
             out.close();
         }  finally  {
             if  (out !=  null )
                 out.close();
         }
     }
 
     public  int  exactMatchSearch(String key) {
         return  exactMatchSearch(key,  0 ,  0 ,  0 );
     }
 
     public  int  exactMatchSearch(String key,  int  pos,  int  len,  int  nodePos) {
         if  (len <=  0 )
             len = key.length();
         if  (nodePos <=  0 )
             nodePos =  0 ;
 
         int  result = - 1 ;
 
         char [] keyChars = key.toCharArray();
 
         int  b = base[nodePos];
         int  p;
 
         for  ( int  i = pos; i < len; i++) {
             p = b + ( int ) (keyChars[i]) +  1 ;
             if  (b == check[p])
                 b = base[p];
             else
                 return  result;
         }
 
         p = b;
         int  n = base[p];
         if  (b == check[p] && n <  0 ) {
             result = -n -  1 ;
         }
         return  result;
     }
 
     public  List<Integer> commonPrefixSearch(String key) {
         return  commonPrefixSearch(key,  0 ,  0 ,  0 );
     }
 
     public  List<Integer> commonPrefixSearch(String key,  int  pos,  int  len,
             int  nodePos) {
         if  (len <=  0 )
             len = key.length();
         if  (nodePos <=  0 )
             nodePos =  0 ;
 
         List<Integer> result =  new  ArrayList<Integer>();
 
         char [] keyChars = key.toCharArray();
 
         int  b = base[nodePos];
         int  n;
         int  p;
 
         for  ( int  i = pos; i < len; i++) {
             p = b;
             n = base[p];
 
             if  (b == check[p] && n <  0 ) {
                 result.add(-n -  1 );
             }
 
             p = b + ( int ) (keyChars[i]) +  1 ;
             if  (b == check[p])
                 b = base[p];
             else
                 return  result;
         }
 
         p = b;
         n = base[p];
 
         if  (b == check[p] && n <  0 ) {
             result.add(-n -  1 );
         }
 
         return  result;
     }
 
     // debug
     public  void  dump() {
         for  ( int  i =  0 ; i < size; i++) {
             System.err.println( "i: "  + i +  " ["  + base[i] +  ", "  + check[i]
                     +  "]" );
         }
     }
}

public  class  TestDoubleArrayTrie {
 
     /**
      * 检索key的前缀命中了词典中的哪些词<br>
      * key的前缀有多个,所以有可能命中词典中的多个词
      */
     @Test
     public  void  testPrefixMatch() {
         DoubleArrayTrie adt =  new  DoubleArrayTrie();
         List<String> list =  new  ArrayList<String>();
         list.add( "阿胶" );
         list.add( "阿拉伯" );
         list.add( "阿拉伯人" );
         list.add( "埃及" );
         // 所有词必须先排序
         Collections.sort(list);
         // 构建DoubleArrayTrie
         adt.build(list);
         String key =  "阿拉伯人" ;
         // 检索key的前缀命中了词典中的哪些词
         List<Integer> rect = adt.commonPrefixSearch(key);
         for  ( int  index : rect) {
             System.out.println( "前缀  "  + list.get(index) +  " matched" );
         }
         System.out.println( "=================" );
     }
 
     /**
      * 检索key是否完全命中了词典中的某个词
      */
     @Test
     public  void  testFullMatch() {
         DoubleArrayTrie adt =  new  DoubleArrayTrie();
         List<String> list =  new  ArrayList<String>();
         list.add( "阿胶" );
         list.add( "阿拉伯" );
         list.add( "阿拉伯人" );
         list.add( "埃及" );
         // 所有词必须先排序
         Collections.sort(list);
         // 构建DoubleArrayTrie
         adt.build(list);
         String key =  "阿拉" ;
         // 检索key是否完全命中了词典中的某个词
         int  index = adt.exactMatchSearch(key);
         if  (index >=  0 ) {
             System.out.println(key +  " match "  + list.get(index));
         }  else  {
             System.out.println(key +  " not match any term" );
         }
         key =  "阿拉伯" ;
         index = adt.exactMatchSearch(key);
         if  (index >=  0 ) {
             System.out.println(key +  " match "  + list.get(index));
         }  else  {
             System.out.println(key +  " not match any term" );
         }
         key =  "阿拉伯人" ;
         index = adt.exactMatchSearch(key);
         if  (index >=  0 ) {
             System.out.println(key +  " match "  + list.get(index));
         }  else  {
             System.out.println(key +  " not match any term" );
         }
         System.out.println( "=================" );
     }
}

参考代码和说明:
构建:

void build(char **keys, int *values, int size) {
    /* 初始化数组 */
    base = (int *) calloc(size * 2, sizeof(int));
    check = (int *) calloc(size * 2, sizeof(int));
    /* 添加根节点 */
    base[0] = 1;
    check[0] = -1;
    /* 添加其他节点 */
    for (int i = 0; i < size; i++) {
        add(keys[i], values[i]);
    }
}


添加:

void add(char *key, int value) {
    int p = 0;
    for (int i = 0; i < strlen(key); i++) {
        int c = key[i] - 'a';
        if (base[p] + c == 0) {
            /* 创建新的节点 */
            int np = find_free_pos();
            base[p] = np - c;
            check[np] = p;
        }
        p = base[p] + c;
    }
    /* 添加结束标志 */
    base[p] = -value - 1;
}


存储中间文件

void save(char *file) {
    FILE *fp = fopen(file, "wb");
    /* 写入数组大小 */
    int size = find_last_pos();
    fwrite(&size, sizeof(int), 1, fp);
    /* 写入 base 数组 */
    fwrite(base, sizeof(int), size, fp);
    /* 写入 check 数组 */
    fwrite(check, sizeof(int), size, fp);
    fclose(fp);
}


加载中间文件

void load(char *file) {
    FILE *fp = fopen(file, "rb");
    /* 读取数组大小 */
    int size;
    fread(&size, sizeof(int), 1, fp);
    /* 初始化数组 */
    base = (int *) calloc(size, sizeof(int));
   check = (int ) calloc(size, sizeof(int));
/ 读取 base 数组 /
fread(base, sizeof(int), size, fp);
/ 读取 check 数组 */
fread(check, sizeof(int), size, fp);
fclose(fp);
}


字符串匹配


5. 字符串匹配
```c
int match(char *key) {
    int p = 0;
    for (int i = 0; i < strlen(key); i++) {
        int c = key[i] - 'a';
        if (base[p] + c == 0 || check[base[p] + c] != p) {
            return 0;
        }
        p = base[p] + c;
    }
    if (base[p] < 0) {
        return -base[p] - 1;
    }
    return 0;
}


只是一个示例。它们可能需要添加额外的函数和调整以适合你的特定用例。上面的代码实现了 Double Array Trie 的基本功能,包括构建、添加、存储中间文件、加载中间文件和字符串匹配。

Double Array Trie (DAT) is a data structure used for efficient text matching. Here is an example implementation of a DAT in C:

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#define MAX_LEN 1000000
#define BASE 0
#define CHECK(x) (x - base)
#define BASE_BACK(x) (x + base)

int base, size;
int *chk, *val, *nxt;

int newnode() {
    int p = size++;
    chk[p] = -1;
    return p;
}

void insert(char *str) {
    int p = 0, q, c;
    for (int i = 0; str[i]; i++) {
        c = str[i] - 'a';
        if (chk[CHK(nxt[p]) + c] == p) {
            q = newnode();
            val[q] = c;
            chk[q] = p;
            nxt[q] = nxt[p];
            nxt[p] = q;
        }
        p = CHK(nxt[p]) + c;
    }
    val[p] = 1;
}

void build() {
    int *q, *p, *r, c, u;
    base = 1;
    size = 0;
    chk = (int*)malloc((MAX_LEN + 1) * sizeof(int));
    val = (int*)malloc((MAX_LEN + 1) * sizeof(int));
    nxt = (int*)malloc((MAX_LEN + 1) * sizeof(int));
    q = (int*)malloc((MAX_LEN + 1) * sizeof(int));
    p = (int*)malloc((MAX_LEN + 1) * sizeof(int));
    r = (int*)malloc((MAX_LEN + 1) * sizeof(int));
    memset(chk, -1, (MAX_LEN + 1) * sizeof(int));
    chk[0] = 1;
    q[0] = 1;
    p[0] = 0;
    r[0] = 0;
    for (int i = 0; i <= r[0]; i++) {
        p[i] = BASE_BACK(p[i]);
        for (c = 0; c < 26; c++) {
            if (CHK(nxt[p[i]]) + c == -1) continue;
            if (chk[CHK(nxt[p[i]]) + c] != p[i]) {
                chk[CHK(nxt[p[i]]) + c] = chk[CHK(nxt[p[i]])];
                nxt[CHK(nxt[p[i]]) + c] = nxt[CHK(nxt[p[i]])];
            }
            u = CHK(nxt[p[i]]) + c;
            if (val[u]) continue;
            val[u] = c;
            if (chk[u] == -1) {
                chk[u] = chk[0];
r[++r[0]] = u;
}
q[++q[0]] = u;
p[q[0]] = u;
}
}
free(q);
free(p);
free(r);
}

void store(char *file) {
FILE *fp = fopen(file, "wb");
fwrite(&size, sizeof(size), 1, fp);
fwrite(&base, sizeof(base), 1, fp);
fwrite(chk, sizeof(chk[0]), size, fp);
fwrite(val, sizeof(val[0]), size, fp);
fwrite(nxt, sizeof(nxt[0]), size, fp);
fclose(fp);
}

void load(char file) {
FILE fp = fopen(file, "rb");
fread(&size, sizeof(size), 1, fp);
fread(&base, sizeof(base), 1, fp);
chk = (int)malloc((size + 1) * sizeof(int));
val = (int)malloc((size + 1) * sizeof(int));
nxt = (int*)malloc((size + 1) * sizeof(int));
fread(chk, sizeof(chk[0]), size, fp);
fread(val, sizeof(val[0]), size, fp);
fread(nxt, sizeof(nxt[0]), size, fp);
fclose(fp);
}

int match(char *str) {
int p = 0;
for (int i = 0; str[i]; i++) {
p = CHK(nxt[p]) + str[i] - 'a';
if (val[p] != 1) return 0;
}
return 1;
}

This example includes functions for building the DAT, inserting strings into the DAT, storing the DAT in a file, loading the DAT from a file, and matching strings against the DAT.

Note that this is a simplified version of DAT, it's not the complete version, you may need to make some modification to fit your requirement. Also, this example uses only lowercase letters, you may need to modify it to work with other characters.

其中,

void build(int *keys, int *next, int *check, int *value, int *base, int *size)

是构建函数,需要传入一个整型指针数组keys,表示要插入的字符串,每个字符串以0结尾。其中next,check,value,base,size是用来存储Double Array Trie树的五个数组。

void insert(char *str, int id)

是添加函数,需要传入一个字符串和一个整数id,表示要插入的字符串和该字符串对应的id。

void store(char *file)

是存储函数,需要传入一个字符串file,表示要存储的文件的路径。

void load(char *file)

是加载函数,需要传入一个字符串file,表示要加载的文件的路径。

int match(char *str)

是匹配函数,需要传入一个字符串str,表示要匹配的字符串,如果字符串str在DAT中,返回1,否则返回0。

这个例子中只使用了小写字母,如果要支持其他字符,需要修改部分代码。

下面是一个简单的使用例子:

#include "DoubleArrayTrie.h"

int main()
{
    // 创建DAT对象
    DoubleArrayTrie dat;
    // 插入字符串
    dat.insert("hello", 1);
    dat.insert("world", 2);
    dat.insert("hi", 3);
    // 构建DAT树
    dat.build();
    // 匹配字符串
    int id = dat.match("hello");
    printf("%d\n", id); // 输出1
    id = dat.match("hi");
    printf("%d\n", id); // 输出3
    id = dat.match("test");
    printf("%d\n", id); // 输出0
    // 存储DAT树
    dat.store("dat.bin");
    // 释放DAT树
    dat.clear();
    // 加载DAT树
    dat.load("dat.bin");
    // 匹配字符串
    id = dat.match("hello");
    printf("%d\n", id); // 输出1
    return 0;
}


在这个例子中,我们首先创建了一个DAT对象,然后插入了三个字符串,接着构建DAT树,并匹配了几个字符串。最后存储了DAT树并释放了DAT树,最后加载了DAT树并匹配了几个字符串。

需要注意的是,在这个例子中,我们使用了一个类似于二进制文件的方式来存储和加载DAT树,这样可以使得存储和加载更快速和简单。

希望这个例子能帮助你理解如何使用DAT树。

望采纳。

参考一下https://blog.csdn.net/Koroti/article/details/108585954

DoubleArrayTrie 是一种基于双数组结构实现的 Trie 树,可以用来快速查找字符串是否存在或获取某个字符串的值。

构建 DoubleArrayTrie 的过程需要以下步骤:

1.初始化一个空的双数组和一个空的队列。
2.将根节点加入队列中。
3.取出队列中的第一个节点,并将其子节点依次加入队列中。
4.对于每个节点,在双数组中找到一个空闲位置,并将其父节点的 base 值更新为该位置。
5.当队列为空时,构建过程完成。
添加字符串的过程需要以下步骤:

1.初始化当前节点为根节点。
2.遍历字符串中的每个字符,并在双数组中找到对应的子节点。
3.如果子节点不存在,则在双数组中新建一个节点。
4.更新当前节点为子节点。
5..当遍历完所有字符时,添加过程完成。
存储进中间文件的过程需要以下步骤:

1.打开一个文件。
2.将双数组的大小和各个元素的值写入文件。
3.关闭文件。
加载中间文件的过程需要以下步骤:

1.打开一个文件。
2.读取文件中双数组的大小和各个元素的值。
3.初始化双数组。
4.关闭文件。
字符串匹配的过程需要以下步骤:

1.初始化当前节点为根节点。
2.遍历字符串中的每个字符,并在双数组中找到对应的子节点。
3.如果子节点不存在,则返回匹配失败。
4.更新当前节点为子节点。
5.当遍历完所有字符时,如果当前节点为结束节点,则返回匹配成功,否则返回匹配失败。

这是一个大致的实现流程,具体实现细节还需根据具体需求来实现。

DoubleArrayTrie是一种字典树结构,它可以用来存储和查询字符串。下面是一个简单的C语言实现。

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#define MAX_NODE_NUM 10000
#define MAX_CHILD_NUM 26

struct DATNode {
    int flag;  // 标识该节点是否是单词的结尾
    int base;  // 子节点的基址
    int child[MAX_CHILD_NUM];  // 子节点
};

struct DAT {
    int size;  // 节点数
    struct DATNode node[MAX_NODE_NUM];  // 节点数组
};

// 构建DoubleArrayTrie
void buildDAT(struct DAT* dat, char** words, int wordNum) {
    int i, j, k, p, q, w;
    int check[MAX_NODE_NUM];
    dat->size = 1;
    memset(dat->node, 0, sizeof(dat->node));
    memset(check, 0, sizeof(check));
    for (i = 0; i < wordNum; i++) {
        p = 0;  // 从根节点开始
        for (j = 0; words[i][j]; j++) {
            k = words[i][j] - 'a';
            if (!dat->node[p].child[k]) {
                dat->node[p].child[k] = dat->size++;  // 新建节点
            }
            p = dat->node[p].child[k];
        }
        dat->node[p].flag = 1;  // 标识该节点是单词的结尾
    }
    for (i = 0; i < MAX_CHILD_NUM; i++) {
        dat->node[0].child[i] = 1;  // 根节点的子节点都指向第1个节点
    }
    for (i = 1; i < dat->size; i++) {
        for (j = 0; j < MAX_CHILD_NUM; j++) {
            if (dat->node[i].child[j]) {
                check[dat->node[i].child[j]] = i;
            }
        }
    }
    w = 1;
    for (i = 1; i < dat->size; i++) {
        if (dat->node[i].flag) {
            continue;
        }
        p = check[i];
        k = 0;
        for (j = 0

DoubleArrayTrie (DAT) 是一种高效的字典树实现,用于存储和匹配字符串。下面是一个 C 语言实现的 DAT 的示例,其中包括构建、添加、存储中间文件、加载中间文件以及字符串匹配功能:

1.构建:

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#define BASE_SIZE (1<<20)
#define CHECK_SIZE (1<<20)
#define MAX_SIZE (1<<21)

typedef struct DoubleArrayTrie {
    int base[BASE_SIZE];
    int check[CHECK_SIZE];
    int size;
    int used[MAX_SIZE];
} DoubleArrayTrie;

void dat_init(DoubleArrayTrie *dat) {
    memset(dat->base, 0, sizeof(dat->base));
    memset(dat->check, 0, sizeof(dat->check));
    dat->size = 1;
}

// 添加一个字符串
void dat_add(DoubleArrayTrie *dat, char *str) {
    int len = strlen(str);
    int p = 0;
    for (int i = 0; i < len; i++) {
        int c = str[i];
        if (!dat->base[p]) {
            dat->base[p] = dat->size++;
        }
        p = dat->base[p] + c;
    }
    dat->used[p] = 1;
}


2.添加:


void dat_add(DoubleArrayTrie *dat, char *str) {
    int len = strlen(str);
    int p = 0;
    for (int i = 0; i < len; i++) {
        int c = str[i];
        if (!dat->base[p]) {
            dat->base[p] = dat->size++;
        }
        p = dat->base[p] + c;
    }
    dat->used[p] = 1;
}

3.存储中间文件

void dat_save(DoubleArrayTrie *dat, char *file_name) {
    FILE *fp = fopen(file_name, "wb");
    fwrite(dat->base, sizeof(dat->base), 1, fp);
    fwrite(dat->check, sizeof(dat->check), 1, fp);
    fwrite(&dat->size, sizeof(dat->size), 1, fp);
    fwrite(dat->used, sizeof(dat->used), 1, fp);
    fclose(fp);


查询

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#define MAX_CHAR 256
typedef struct Node {
    int flag;
    struct Node *next[MAX_CHAR];
} Node;
typedef struct DoubleArrayTrie {
    int base[MAX_CHAR];
    int check[MAX_CHAR];
    Node *root;
} DoubleArrayTrie;
// 构建DoubleArrayTrie
DoubleArrayTrie *build_DoubleArrayTrie(char **words, int n)
{
    DoubleArrayTrie *dat = (DoubleArrayTrie *)malloc(sizeof(DoubleArrayTrie));
    memset(dat->base, 0, sizeof(dat->base));
    memset(dat->check, 0, sizeof(dat->check));
    dat->root = (Node *)malloc(sizeof(Node));
    memset(dat->root, 0, sizeof(Node));
    for (int i = 0; i < n; i++) {
        Node *p = dat->root;
        int len = strlen(words[i]);
        for (int j = 0; j < len; j++) {
            int index = words[i][j];
            if (p->next[index] == NULL) {
                p->next[index] = (Node *)malloc(sizeof(Node));
                memset(p->next[index], 0, sizeof(Node));
            }
            p = p->next[index];
        }
        p->flag = 1;
    }
    return dat;
}
// 添加字符串
void add_string(DoubleArrayTrie *dat, char *str)
{
    Node *p = dat->root;
    int len = strlen(str);
    for (int i = 0; i < len; i++) {
        int index = str[i];
        if (p->next[index] == NULL) {
            p->next[index] = (Node *)malloc(sizeof(Node));
            memset(p->next[index], 0, sizeof(Node));
        }
        p = p->next[index];
    }
    p->flag = 1;
}
// 存储进中间文件
void save_DoubleArrayTrie(DoubleArrayTrie *dat, char *filename)
{
    FILE *fp = fopen(filename, "wb");
    if (fp == NULL) {
        printf("open file failed\n");
        return;
    }
    fwrite(dat->base, sizeof(dat->base), 1, fp);
    fwrite(dat->check, sizeof(dat->check), 1, fp);
    fclose(fp);
}
// 加载中间文件
DoubleArrayTrie *load_DoubleArrayTrie(char *filename)
{
    FILE *fp = fopen(filename, "rb");
    if (fp == NULL) {
        printf("open file failed\n");
        return NULL;
    }
    DoubleArrayTrie *dat = (DoubleArrayTrie *)malloc(sizeof(DoubleArrayTrie));
    fread(dat->base, sizeof(dat->base), 1, fp);
    fread(dat->check, sizeof(dat->check), 1, fp);
    fclose(fp);
    return dat;
}
// 字符串匹配查询
int search_string(DoubleArrayTrie *dat, char *str)
{
    Node *p = dat->root;
    int len = strlen(str);
    for (int i = 0; i < len; i++) {
        int index = str[i];
        if (p->next[index] == NULL) {
            return 0;
        }
        p = p->next[index];
    }
    return p->flag;
}
int main()
{
    char *words[] = {"hello", "world", "apple", "banana"};
    int n = 4;
    DoubleArrayTrie *dat = build_DoubleArrayTrie(words, n);
    add_string(dat, "orange");
    save_DoubleArrayTrie(dat, "dat.bin");
    DoubleArrayTrie *dat2 = load_DoubleArrayTrie("dat.bin");
    int ret = search_string(dat2, "apple");
    printf("ret = %!d(MISSING)\n", ret);
    return 0;
}

代码如下并附有清晰的注释:

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdbool.h>

#define BASE 256
#define CHECK(x) ((x) == -1 ? -1 : base[(x)])

/* base数组和check数组 */
int base[1000000];
int check[1000000];

/* 存储的单词 */
char word[100000][50];
int wp = 0;  // word的计数器

/* 构建Double Array Trie */
void build(int p)
{
    int b = 0;
    /* 寻找空闲空间 */
    while (check[b] != -1)
        b++;
    base[p] = b;

    int t = base[p];
    /* 将新空间初始化为-1 */
    for (int i = 0; i < BASE; i++)
    {
        check[t + i] = -1;
        base[t + i] = -1;
    }

    /* 遍历所有单词 */
    for (int i = 0; i < wp; i++)
    {
        int p = 0;
        /* 遍历单词的每一个字符 */
        for (int j = 0; word[i][j]; j++)
        {
            int c = word[i][j];
            int t = base[p] + c;
            if (check[t] == -1)
            {
                b = 0;
                /* 寻找空闲空间 */
                while (check[b] != -1)
                    b++;
                base[t] = b;
                check[t] = p;
            }
            p = t;
        }
    }
}

/* 添加单词 */
void add(char *str)
{
    strcpy(word[wp++], str);
}

/* 将Trie存储到文件中 */
void store(char *fname)
{
    FILE *fp = fopen(fname, "wb");
    fwrite(base, sizeof(base), 1, fp);
    fwrite(check, sizeof(check), 1, fp);
    fclose(fp);
}

/* 从文件中加载Trie */
void load(char *fname)
{
    FILE *fp = fopen(fname, "rb");
    fread(base, sizeof(base), 1, fp);
    fread(check, sizeof(check), 1, fp);
    fclose(fp);
}

/* 查找单词 */
bool find(char *str)
{
    int p = 0;
    /* 遍历单词的每一单词*/
for (int i = 0; str[i]; i++)
    {
        int c = str[i];
        int t = base[p] + c;
        if (check[t] == p)
            p = t;
        else
            return false;
    }
    return true;
}

int main()
{
    /* 初始化check数组为-1 */
    for (int i = 0; i < 1000000; i++)
        check[i] = -1;

    /* 添加单词 */
    add("hello");
    add("world");
    add("apple");

    /* 构建Trie */
    build(0);

    /* 查询单词 */
    printf("%d\n", find("hello"));  // 1
    printf("%d\n", find("world"));  // 1
    printf("%d\n", find("apple"));  // 1
    printf("%d\n", find("hell"));   // 0

    /* 存储Trie到文件中 */
    store("trie.dat");

    /* 重置Trie */
    for (int i = 0; i < 1000000; i++)
    {
        base[i] = -1;
        check[i] = -1;
    }
    wp = 0;

    /* 从文件中加载Trie */
    load("trie.dat");

    /* 查询单词 */
    printf("%d\n", find("hello"));  // 1
    printf("%d\n", find("world"));  // 1
    printf("%d\n", find("apple"));  // 1
    printf("%d\n", find("hell"));   // 0

    return 0;
}
  1. 构建:首先要实现一个词典数据结构,用于存储所有的词汇,并计算各个词汇的长度。
  2. 添加:将新的词汇添加到词典中,并计算与当前词汇的相对路径,并将该路径记录在双数组中。
  3. 存储:将双数组存储到中间文件中,以便重新加载使用。
  4. 加载:从中间文件中加载双数组,用于构建DAT,重新建立其内部结构。
  5. 匹配查询:根据给定的字符串,以及双数组中记录的路径,从根节点开始搜索,直至匹配成功或搜索失败。
    具体方法如下
//构建
void build_DAT(char *words[],int count){
    int i;
    for(i=0;i<count;i++){
        DAT_add(words[i],i);
    }
}

//添加
void DAT_add(char *word,int index){
    int len = strlen(word);
    int i;
    int current = 0;

    for(i=0;i<len;i++){
        int key = word[i];
        int next = DAT_get(current,key);
        if(next==-1){
            next = DAT_create(current,key);
        }
        current = next;
    }
    DAT_set_value(current,index);
}

//存储
void save_DAT(char *filename){
    FILE *fp = fopen(filename,"wb");
    fwrite(DAT,sizeof(DAT),1,fp);
    fclose(fp);
}

//加载
void load_DAT(char *filename){
    FILE *fp = fopen(filename,"rb");
    fread(DAT,sizeof(DAT),1,fp);
    fclose(fp);
}

//匹配查询
int search_DAT(char *word){
    int len = strlen(word);
    int i;
    int current = 0;

    for(i=0;i<len;i++){
        int key = word[i];
        int next = DAT_get(current,key);
        if(next==-1){
            return -1;
        }
        current = next;
    }
    return DAT_get_value(current);
}