# ---------------------------------------------------------------------------
# Load train.csv. Rows whose start time falls on day-of-month 24 are held out
# as the test set; every other row is indexed into the lookup tables below,
# which feed the frequent-pattern models in the evaluation loop.
# NOTE(review): column meanings are inferred from how this script uses them —
# line[1] user, line[4] start time "%Y-%m-%d %H:%M:%S", line[5] start geohash,
# line[6] end geohash. Confirm against the CSV header.
# ---------------------------------------------------------------------------
# Max distance from the query origin at which a frequent (start, end) location
# is accepted (units of geohash.get_distance_geohash — presumably metres, TODO confirm).
dis_threshold = 400
user_start_end_dict = {}  # {user: [[start, end], [], [], ...], ...}
user_start_dict = {}  # {user: [[start], [], [], ...], ...}
user_end_dict = {}  # {user: [[end], [], [], ...], ...}
start_end_dict = {}  # {start: [[end], [], [], ...], ...}
dataset24 = []  # raw CSV rows whose start time is on day-of-month 24 (test set)
train_num = 0
test_num = 0
print('开始读取数据')
# The csv module expects files opened with newline='' so that quoted fields
# containing line breaks are parsed correctly.
with open('train.csv', newline='') as f:
    reader = csv.reader(f)
    for line in reader:
        if reader.line_num % 80000 == 0:
            # 3090000 is a hard-coded estimate of the file's line count; it is
            # only used for the progress percentage.
            print("%.2f" % (reader.line_num / 3090000 * 100), "%")
        if reader.line_num == 1:  # header row
            continue
        t = time.strptime(line[4], "%Y-%m-%d %H:%M:%S")
        if t.tm_mday == 24:
            test_num += 1
            dataset24.append(line)
        else:
            train_num += 1
            # add_dict presumably appends the value under the key, producing the
            # layouts documented above — verify against its definition.
            add_dict(user_start_end_dict, line[1], [line[5], line[6]])
            add_dict(user_start_dict, line[1], [line[5]])
            add_dict(user_end_dict, line[1], [line[6]])
            add_dict(start_end_dict, line[5], [line[6]])
print('load finished', 'train set:', train_num, 'test set:', test_num)
# ---------------------------------------------------------------------------
# Evaluate on the held-out rows. For each row a set of candidate destination
# geohashes is built from frequent patterns. The row counts as a hit when the
# true destination lies inside geohash.expand(...) of any candidate.
# ---------------------------------------------------------------------------
print("开始测试...")
pre_num = 0  # rows for which a prediction was attempted
invalid_num = 0  # rows with neither a known user nor a known start location
right_num = 0  # rows whose candidate set covered the true destination
des_num = 0  # total number of candidates over all predicted rows (counted before the final expand)
test_total = len(dataset24)  # progress denominator (was a hard-coded 128000)
for i, line in enumerate(dataset24):
    if i % 10000 == 0:
        print("%.2f" % (i / test_total * 100), "%")
    username, ori, des = line[1], line[5], line[6]
    log("\n\n测试", i)
    log(username, ori, des)
    fp_user_start_end = []
    fp_user_start = []
    fp_user_end = []
    fp_start_end = []
    if username not in user_start_end_dict and ori not in start_end_dict:
        # Nothing in the training data to predict from.
        invalid_num += 1
        continue
    pre_num += 1
    if username in user_start_end_dict:
        log("user_start_end_dict[username]", user_start_end_dict[username])
        log("user_start_dict[username]", user_start_dict[username])
        log("user_end_dict[username]", user_end_dict[username])
        # NOTE(review): the meaning of generate()'s two numeric arguments is not
        # visible here (support / itemset-size limits?) — confirm in fp_growth.
        fp_user_start_end = fp_growth.generate(user_start_end_dict[username], 1, 2)
        # Each historical start/end geohash is replaced by every geohash that
        # geohash.expand returns for it (presumably the cell plus its
        # neighbours), one single-item transaction per returned geohash.
        user_start_list = [[cell] for val in user_start_dict[username]
                           for cell in geohash.expand(val[0])]
        user_end_list = [[cell] for val in user_end_dict[username]
                         for cell in geohash.expand(val[0])]
        log("user_start_list", user_start_list)
        log("user_end_list", user_end_list)
        fp_user_start = fp_growth.generate(user_start_list, 2, 0)
        fp_user_end = fp_growth.generate(user_end_list, 2, 0)
    if ori in start_end_dict:
        log("start_end_dict[ori]", start_end_dict[ori])
        fp_start_end = fp_growth.generate(start_end_dict[ori], 3, 0)
    log("fp_user_start_end", fp_user_start_end)
    log("fp_user_start:", fp_user_start)
    log("fp_user_end:", fp_user_end)
    log("fp_start_end", fp_start_end)
    result_tmp = set()
    # Frequent single-location itemsets ([[start]], [[end]]) contribute their
    # first element to the candidates directly.
    for itemsets in (fp_user_start, fp_user_end, fp_start_end):
        for item in itemsets:
            result_tmp.add(list(item)[0])
    # For each frequent (start, end) itemset with at least one location within
    # dis_threshold of the current origin, the location farther from the origin
    # is taken as the destination (the nearer one is assumed to be the start),
    # and its expansion is added to the candidates.
    log("遍历[[初始地,目的地]]频繁项集")
    for item in fp_user_start_end:
        item_list = list(item)
        if len(item_list) < 2:
            # A single-location itemset cannot be split into (start, end);
            # skip it instead of raising IndexError on item_list[1].
            continue
        log(item_list[0], item_list[1])
        dis1 = geohash.get_distance_geohash(item_list[0], ori)
        dis2 = geohash.get_distance_geohash(item_list[1], ori)
        log("dis1:", dis1, "dis2", dis2)
        if dis1 > dis_threshold and dis2 > dis_threshold:
            continue
        log("将判断出的目的地加入结果")
        destination = item_list[0] if dis1 > dis2 else item_list[1]
        result_tmp.update(geohash.expand(destination))
    log("Result: ", result_tmp)
    des_num += len(result_tmp)
    # Hit if the true destination is inside the expansion of any candidate;
    # any() stops at the first match.
    if any(des in geohash.expand(r) for r in result_tmp):
        right_num += 1
        log("接受")
    log("----------------------------------------------------------")
print("预测总数", pre_num)
print("无效个数", invalid_num)
if pre_num:
    print("召回率", right_num / pre_num)
    print("平均目的地个数", des_num / pre_num)
else:
    # No row could be predicted (e.g. empty test set): avoid ZeroDivisionError.
    print("召回率", "n/a")
    print("平均目的地个数", "n/a")
下面是调度算法的部分代码，请问其中使用了什么原理？
class Dispatch:
    """Greedy, distance-bounded rebalancing of a grid of per-region imbalances.

    ``self.state`` is a ``(height, width)`` int16 grid that ``update``
    accumulates into. Cells of opposite sign are paired so that they cancel
    each other out. NOTE(review): the code does not say which sign means
    surplus and which means deficit — confirm with the producer of
    ``next_pred``.

    Principle used by ``dispatch``:

    1. Collect every cell with ``abs(value) >= dispatch_thresh``. Handle them
       in descending order of ``abs(value)`` (worst imbalance first), skipping
       any that earlier transports already brought below the threshold.
    2. For each such start cell, run a best-first flood fill over the
       4-connected grid. The fill uses a breadth-first search whose queue is a
       min-priority queue keyed on the distance from the start cell. Cells
       farther than ``distance_thresh`` are pruned. There is no heuristic
       term, so this is not A*; cells are simply visited nearest-first.
    3. A visited cell whose sign is opposite to the start cell and whose
       ``abs(value) >= trans_thresh`` becomes a transport partner. One trip
       moves ``min(abs(value), max_trans)`` units between the two cells. The
       amount is not capped by what the start cell still needs, so
       over-transport is allowed. The search stops as soon as the start cell
       reaches zero or changes sign.

    Every trip increments ``dispatch_num`` and appends a line to ``logs``.
    """

    def __init__(self, height, width,
                 dispatch_thresh,
                 trans_thresh, max_trans,
                 distance_thresh,
                 distance_fn=None):
        """
        :param height: grid height (number of rows of ``state``)
        :param width: grid width (number of columns of ``state``)
        :param dispatch_thresh: a cell is dispatched only when
            ``abs(value) >= dispatch_thresh`` (keeps cost down)
        :param trans_thresh: a partner cell is used only when
            ``abs(value) >= trans_thresh``, so every trip moves at least
            ``min(trans_thresh, max_trans)`` units (keeps cost down)
        :param max_trans: maximum number of units moved by one trip
        :param distance_thresh: only cells whose distance from the start cell is
            ``<= distance_thresh`` are considered (keeps cost down)
        :param distance_fn: optional callable ``(dy, dx) -> float`` giving the
            distance for a grid offset. When None, the module-level ``distance``
            function is used (expected to be defined elsewhere in the module).
        """
        self.height = height
        self.width = width
        self.dispatch_thresh = dispatch_thresh
        self.trans_thresh = trans_thresh
        self.max_trans = max_trans
        self.distance_thresh = distance_thresh
        self.distance_fn = distance_fn
        # ---------------------
        self.time = None  # label of the current dispatch round, used in log lines
        self.dispatch_num = 0  # number of trips performed so far
        self.state = np.zeros((height, width), dtype=np.int16)
        self.logs = []  # one formatted string per trip

    def update(self, next_pred):
        """Add ``next_pred`` element-wise to ``self.state`` in place."""
        self.state += next_pred

    def show(self, vlim=300, save_fname=None):
        """Plot ``self.state`` as a heat map with colour range ``[-vlim, vlim]``.

        :param vlim: symmetric colour limit
        :param save_fname: if given, the figure is also saved to this path
        """
        plt.figure(figsize=(10, 7))
        # A tick every second cell, derived from the grid size (previously the
        # range was hard-coded to 33, which only suited a 33x33 grid).
        plt.xticks(np.arange(0, self.width, 2), fontsize=11)
        plt.yticks(np.arange(0, self.height, 2), fontsize=11)
        plt.imshow(self.state, origin="lower", vmin=-vlim, vmax=vlim, cmap="RdBu_r")
        plt.colorbar(shrink=0.9)
        if save_fname:
            plt.savefig(save_fname, dpi=200, bbox_inches='tight')
        plt.show()

    def dispatch(self, time=None):
        """Rebalance ``self.state`` for one round.

        :param time: str label recorded in the log lines of this round
        """
        self.time = time
        # 1. All cells whose imbalance reaches the dispatch threshold, as (y, x) rows.
        pos_arr = np.stack(np.nonzero(np.abs(self.state) >= self.dispatch_thresh), axis=-1)  # shape (NUM, 2)
        # Worst imbalance first (stable sort keeps row-major order for ties).
        pos_arr = sorted(pos_arr, key=lambda pos: -abs(self.state[tuple(pos)]))
        # 2. Dispatch each of them.
        for pos in pos_arr:
            pos = tuple(pos)
            if abs(self.state[pos]) < self.dispatch_thresh:  # fixed by an earlier transport
                continue
            self._dispatch_pos(pos)

    def print_logs(self, clear=False):
        """Print all log lines (cell coordinates are written as (x, y)).

        :param clear: if True, empty ``self.logs`` afterwards
        """
        for entry in self.logs:
            print(entry)
        if clear:
            self.clear_logs()

    def clear_logs(self):
        """Discard all stored log lines."""
        self.logs = []

    def _distance(self, dy, dx):
        """Distance for grid offset ``(dy, dx)``: ``distance_fn`` if set, else module ``distance``."""
        distance_fn = getattr(self, "distance_fn", None)
        if distance_fn is not None:
            return distance_fn(dy, dx)
        return distance(dy, dx)

    def _dispatch_pos(self, start_pos):
        """Rebalance one cell against its nearest opposite-signed cells.

        The search is a best-first flood fill over the 4-connected grid,
        ordered by distance from ``start_pos`` and limited by
        ``distance_thresh``.

        :param start_pos: ``(y, x)`` index of the cell to rebalance. Entries may
            be NumPy integers; they are converted to ``int`` so that log lines
            read ``(x, y)`` rather than ``(np.int64(x), ...)`` under NumPy 2.
        """
        import heapq
        import itertools

        state = self.state
        start_pos = (int(start_pos[0]), int(start_pos[1]))
        n_rows, n_cols = state.shape
        directions = ((-1, 0), (0, -1), (0, 1), (1, 0))  # 4-neighbourhood as (dy, dx)
        # Cells are marked when enqueued, so each is enqueued at most once.
        # dtype=bool: the np.bool alias was removed in NumPy 1.24.
        is_visited = np.zeros_like(state, dtype=bool)
        is_visited[start_pos] = True
        # Min-heap of (distance from start, insertion order, pos). The counter
        # breaks distance ties first-in-first-out and keeps pos out of comparisons.
        order = itertools.count()
        frontier = [(0, next(order), start_pos)]
        while frontier:
            dist, _, curr_pos = heapq.heappop(frontier)  # nearest cell first
            if curr_pos != start_pos:
                # Python ints: the product of two int16 values can overflow and
                # report the wrong sign.
                start_val = int(state[start_pos])
                curr_val = int(state[curr_pos])
                if abs(curr_val) >= self.trans_thresh and start_val * curr_val < 0:
                    # Regions that gain tend to keep gaining, so moving more than
                    # the start cell currently needs is allowed; only max_trans caps a trip.
                    trans_num = max(-self.max_trans, min(self.max_trans, curr_val))
                    state[curr_pos] -= trans_num
                    state[start_pos] += trans_num
                    self.dispatch_num += 1
                    entry = "id: %06d, 时间: %s, %s -> %s, 运输车辆数: %d, 距离: %.4fkm." % \
                        (self.dispatch_num, self.time, (start_pos[1], start_pos[0]),
                         (curr_pos[1], curr_pos[0]), int(trans_num), dist)
                    self.logs.append(entry)
                    if int(state[start_pos]) * start_val <= 0:  # reached zero or flipped sign: done
                        break
            # Expand the 4 neighbours of the current cell.
            for dy, dx in directions:
                ny, nx = curr_pos[0] + dy, curr_pos[1] + dx
                if not (0 <= ny < n_rows and 0 <= nx < n_cols):  # outside the grid
                    continue
                if is_visited[ny, nx]:
                    continue
                is_visited[ny, nx] = True
                dist_start = self._distance(start_pos[0] - ny, start_pos[1] - nx)
                if dist_start > self.distance_thresh:  # too far from the start cell
                    continue
                heapq.heappush(frontier, (dist_start, next(order), (ny, nx)))
你好，我是有问必答小助手。为了让技术专家团更好地为您解答问题，烦请您补充以下信息：(1) 问题的背景详情；(2) 您想解决的具体问题；(3) 与问题相关的代码、截图或报错信息。这样便于技术专家团更好地理解问题，并给出解决方案。
您可以点击问题下方的【编辑】，补充或修改问题。