不理解下面的 FP-growth 算法代码和调度算法代码,希望能解释一下它们的原理。

# Distance cut-off compared against geohash.get_distance_geohash() results
# later in this script (unit is whatever that helper returns).
dis_threshold = 400

# History tables built from the training rows.  Every value is a list of
# "transactions" (lists of geohash strings) appended via add_dict().
user_start_end_dict = {}  # { user: [ [origin, destination], ... ], ... }
user_start_dict = {}  # { user: [ [origin], ... ], ... }
user_end_dict = {}  # { user: [ [destination], ... ], ... }
start_end_dict = {}  # { origin: [ [destination], ... ], ... }

# Raw CSV rows held out for evaluation (rows whose day-of-month is 24).
dataset24 = []

train_num = 0
test_num = 0

print('开始读取数据')

# newline='' is what the csv module asks for, so that line endings inside
# quoted fields are handed to the reader untouched.
with open('train.csv', newline='') as f:
    reader = csv.reader(f)

    for line in reader:

        # Progress report; 3090000 is the hard-coded approximate row count.
        if reader.line_num % 80000 == 0:
            print("%.2f" % (reader.line_num / 3090000 * 100), "%")

        # Skip the header row.
        if reader.line_num == 1:
            continue

        # Column 4 holds the timestamp; columns 1, 5 and 6 are used below as
        # user, origin and destination respectively.
        t = time.strptime(line[4], "%Y-%m-%d %H:%M:%S")

        if t.tm_mday == 24:
            # Day 24 is kept aside as the test set.
            test_num += 1
            dataset24.append(line)
        else:
            train_num += 1
            add_dict(user_start_end_dict, line[1], [line[5], line[6]])
            add_dict(user_start_dict, line[1], [line[5]])
            add_dict(user_end_dict, line[1], [line[6]])
            add_dict(start_end_dict, line[5], [line[6]])

print('load finished', 'train set:', train_num, 'test set:', test_num)

print("开始测试...")

pre_num = 0  # number of test records a prediction was attempted for
invalid_num = 0  # records with neither a known user nor a known origin
right_num = 0  # records whose real destination was covered by the prediction
des_num = 0  # total number of candidate destinations produced

# Use the real test-set size for the progress report instead of a guess.
total_tests = len(dataset24)

for i, line in enumerate(dataset24):
    if i % 10000 == 0:
        print("%.2f" % (i / total_tests * 100), "%")

    username, ori, des = line[1], line[5], line[6]

    log("\n\n测试", i)
    log(username, ori, des)

    # Frequent itemsets mined from each history table; they stay empty when
    # the corresponding history does not exist.
    fp_user_start_end = []
    fp_user_start = []
    fp_user_end = []
    fp_start_end = []

    # Nothing to learn from: unknown user AND unknown origin.
    if username not in user_start_end_dict and ori not in start_end_dict:
        invalid_num += 1
        continue

    pre_num += 1

    if username in user_start_end_dict:
        log("user_start_end_dict[username]", user_start_end_dict[username])
        log("user_start_dict[username]", user_start_dict[username])
        log("user_end_dict[username]", user_end_dict[username])

        # NOTE(review): the two numeric arguments of fp_growth.generate are
        # presumably a minimum support and an itemset length -- confirm
        # against the fp_growth module.
        fp_user_start_end = fp_growth.generate(user_start_end_dict[username], 1, 2)

        # Replace every historical origin/destination by the cells returned
        # from geohash.expand(), one single-item transaction per cell.
        user_start_list = []
        user_end_list = []

        for val in user_start_dict[username]:
            user_start_list.extend([cell] for cell in geohash.expand(val[0]))

        for val in user_end_dict[username]:
            user_end_list.extend([cell] for cell in geohash.expand(val[0]))

        log("user_start_list", user_start_list)
        log("user_end_list", user_end_list)
        fp_user_start = fp_growth.generate(user_start_list, 2, 0)
        fp_user_end = fp_growth.generate(user_end_list, 2, 0)

    if ori in start_end_dict:
        log("start_end_dict[ori]", start_end_dict[ori])

        # Destinations recorded for this origin are mined as-is (they are not
        # passed through geohash.expand()).
        fp_start_end = fp_growth.generate(start_end_dict[ori], 3, 0)

    log("fp_user_start_end", fp_user_start_end)
    log("fp_user_start:", fp_user_start)
    log("fp_user_end:", fp_user_end)
    log("fp_start_end", fp_start_end)

    result_tmp = set()

    # Single-item itemsets ([[origin]] / [[destination]]) go straight into
    # the candidate set.
    for itemset in fp_user_start:
        result_tmp.add(list(itemset)[0])
    for itemset in fp_user_end:
        result_tmp.add(list(itemset)[0])
    for itemset in fp_start_end:
        result_tmp.add(list(itemset)[0])

    # For [[origin, destination]] itemsets: the member close to the query
    # origin is taken as the start, and the farther one becomes a prediction.
    # NOTE(review): assumes every itemset here has exactly two members.
    log("遍历[[初始地,目的地]]频繁项集")
    for itemset in fp_user_start_end:
        item_list = list(itemset)
        log(item_list[0], item_list[1])

        dis1 = geohash.get_distance_geohash(item_list[0], ori)
        dis2 = geohash.get_distance_geohash(item_list[1], ori)
        log("dis1:", dis1, "dis2", dis2)

        # Neither end is near the query origin: the pair is irrelevant.
        if dis1 > dis_threshold and dis2 > dis_threshold:
            continue

        log("将判断出的目的地加入结果")
        farther = item_list[0] if dis1 > dis2 else item_list[1]
        result_tmp.update(geohash.expand(farther))

    # A prediction counts as correct when the real destination lies in the
    # expansion of at least one candidate.
    log("Result: ", result_tmp)
    des_num += len(result_tmp)
    if any(des in geohash.expand(candidate) for candidate in result_tmp):
        right_num += 1
        log("接受")

    log("----------------------------------------------------------")

print("预测总数", pre_num)
print("无效个数", invalid_num)
# Guard against an empty / fully invalid test set instead of crashing with
# ZeroDivisionError.
print("召回率", right_num / pre_num if pre_num else 0.0)
print("平均目的地个数", des_num / pre_num if pre_num else 0.0)

下面是调度算法的部分代码,可以说一下它使用了什么原理吗?

class Dispatch:
    """Greedy rebalancing of a 2-D grid of signed per-cell bike counts.

    ``self.state[y, x]`` is a signed int16 imbalance for each cell.  Cells
    whose absolute imbalance reaches ``dispatch_thresh`` are paired with the
    nearest cells of opposite sign and bikes are moved between them.
    """

    def __init__(self, height, width,
                 dispatch_thresh,
                 trans_thresh, max_trans,
                 distance_thresh):
        """Create an all-zero grid and store the dispatch parameters.

        :param height: grid height (number of rows of ``state``)
        :param width: grid width (number of columns of ``state``)
        :param dispatch_thresh: a cell is dispatched only when
            ``abs(state) >= dispatch_thresh`` (keeps cost down)
        :param trans_thresh: minimum imbalance a partner cell must have,
            ``abs(state) >= trans_thresh``, i.e. the fewest bikes worth
            sending a truck for (keeps cost down)
        :param max_trans: maximum number of bikes moved by one truck
        :param distance_thresh: only cells within this distance of the
            dispatched cell are considered (keeps cost down)
        """
        self.height = height
        self.width = width
        self.dispatch_thresh = dispatch_thresh
        self.trans_thresh = trans_thresh
        self.max_trans = max_trans
        self.distance_thresh = distance_thresh
        # ---------------------
        self.time = None
        self.dispatch_num = 0  # number of truck trips performed so far
        self.state = np.zeros((height, width), dtype=np.int16)
        self.logs = []

    @staticmethod
    def _sign(value):
        """Return -1, 0 or 1 for ``value`` as a plain Python int."""
        value = int(value)
        return (value > 0) - (value < 0)

    def update(self, next_pred):
        """Add the predicted change ``next_pred`` to ``self.state`` in place."""
        self.state += next_pred

    def show(self, vlim=300, save_fname=None):
        """Plot ``self.state`` as a heat map, optionally saving it to a file.

        :param vlim: colour scale is clipped to ``[-vlim, vlim]``
        :param save_fname: when truthy, the figure is also saved there
        """
        z = self.state
        plt.figure(figsize=(10, 7))

        # Tick positions are hard-coded to 0, 2, ..., 32.
        plt.xticks(np.arange(0, 33, 2), fontsize=11)
        plt.yticks(np.arange(0, 33, 2), fontsize=11)

        plt.imshow(z, origin="lower", vmin=-vlim, vmax=vlim, cmap="RdBu_r")
        plt.colorbar(shrink=0.9)
        if save_fname:
            plt.savefig(save_fname, dpi=200, bbox_inches='tight')
        plt.show()

    def dispatch(self, time=None):
        """Rebalance every cell of ``self.state`` that exceeds the threshold.

        :param time: label (str) stored and written into the log lines
        """
        self.time = time
        # 1. Collect all abnormal cells at this moment.
        pos_arr = np.stack(np.nonzero(np.abs(self.state) >= self.dispatch_thresh), axis=-1)  # shape(NUM, 2)
        # Worst offenders first (negated key => descending order).
        pos_arr = sorted(pos_arr, key=lambda pos: -abs(self.state[tuple(pos)]))
        # 2. Dispatch them one by one.
        for pos in pos_arr:
            pos = tuple(pos)
            # An earlier dispatch may already have brought this cell back
            # under the threshold.
            if abs(self.state[pos]) < self.dispatch_thresh:
                continue
            self._dispatch_pos(pos)

    def print_logs(self, clear=False):
        """Print all log lines; coordinates in them are written as (x, y).

        :param clear: when true, the log list is emptied afterwards
        """
        for entry in self.logs:
            print(entry)
        if clear:
            self.clear_logs()

    def clear_logs(self):
        """Discard all accumulated log lines."""
        self.logs = []

    def _dispatch_pos(self, start_pos):
        """Rebalance one cell by searching outwards for opposite-sign cells.

        Cells are expanded through their 4-neighbours and visited in order of
        increasing ``distance(dy, dx)`` from ``start_pos`` (a best-first
        search on a priority queue).  Each visited cell with the opposite
        sign and ``abs(state) >= trans_thresh`` exchanges up to ``max_trans``
        bikes with ``start_pos``.  The search stops when ``start_pos``
        reaches zero or changes sign, or when no cell within
        ``distance_thresh`` is left.

        :param start_pos: tuple(y, x), the cell to rebalance
        """
        # --------------------------------------------- data structures
        state = self.state
        # Priority queue of [distance from start, position]; smallest first.
        pri_queue = SortedList(key=lambda item: item[0])
        dires = [(-1, 0), (0, -1), (0, 1), (1, 0)]  # 4 directions (dy, dx)
        # Builtin bool: np.bool is unavailable in NumPy 1.24-1.26.
        is_visited = np.zeros_like(state, dtype=bool)
        # The start cell is handled first; never queue it a second time.
        is_visited[start_pos] = True
        # --------------------------------------------- search
        pri_queue.add([0, start_pos])
        while len(pri_queue) != 0:
            # ----------- handle the current (nearest pending) cell
            dist, curr_pos = pri_queue.pop(0)
            # Candidate partner: not the start cell, large enough to be worth
            # a truck, and of opposite sign.  Signs are compared as Python
            # ints because an int16 * int16 product can overflow and flip
            # sign (e.g. 200 * -200).
            if curr_pos != start_pos and abs(state[curr_pos]) >= self.trans_thresh and \
                    self._sign(state[start_pos]) * self._sign(state[curr_pos]) < 0:
                # Cells that show a growing imbalance usually keep growing,
                # so moving more than strictly needed is allowed.
                trans_num = np.clip(state[curr_pos], -self.max_trans, self.max_trans)  # per-truck cap
                state[curr_pos] -= trans_num
                state[start_pos] += trans_num
                self.dispatch_num += 1
                log = "id: %06d, 时间: %s, %s -> %s, 运输车辆数: %d, 距离: %.4fkm." % \
                      (self.dispatch_num, self.time, (start_pos[1], start_pos[0]), (curr_pos[1], curr_pos[0]),
                       int(trans_num), dist)
                self.logs.append(log)
                # Start cell reached zero or changed sign: done.  Computed
                # with Python ints to avoid int16 overflow in the product.
                new_start = int(state[start_pos])
                old_start = new_start - int(trans_num)
                if new_start * old_start <= 0:
                    break
            # ---------- expand the 4 neighbours of the current cell
            for dy, dx in dires:
                pos = curr_pos[0] + dy, curr_pos[1] + dx
                # 1. Outside the grid.
                if pos[0] < 0 or pos[1] < 0 or pos[0] >= state.shape[0] or pos[1] >= state.shape[1]:
                    continue
                # 2. Already seen.
                if is_visited[pos]:
                    continue
                is_visited[pos] = True
                # 3. Drop cells farther from the start than the distance limit.
                dist_start = distance(start_pos[0] - pos[0], start_pos[1] - pos[1])
                if dist_start > self.distance_thresh:
                    continue
                # 4. Queue the neighbour.
                pri_queue.add([dist_start, pos])

 

你好,我是有问必答小助手。为了技术专家团更好地为您解答问题,烦请您补充下(1)问题背景详情,(2)您想解决的具体问题,(3)问题相关代码图片或者报错信息。便于技术专家团更好地理解问题,并给出解决方案。

您可以点击问题下方的【编辑】,进行补充修改问题。