run_predict.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. # -*- coding:utf-8 -*-
  2. """
  3. Author: BigCat
  4. """
  5. import argparse
  6. import json
  7. import time
  8. import datetime
  9. import numpy as np
  10. import tensorflow as tf
  11. from config import *
  12. from get_data import get_current_number, spider
  13. from loguru import logger
  14. parser = argparse.ArgumentParser()
  15. parser.add_argument('--name', default="ssq", type=str, help="选择训练数据: 双色球/大乐透")
  16. args = parser.parse_args()
  17. # 关闭eager模式
  18. tf.compat.v1.disable_eager_execution()
  19. if args.name == "ssq":
  20. red_graph = tf.compat.v1.Graph()
  21. with red_graph.as_default():
  22. red_saver = tf.compat.v1.train.import_meta_graph(
  23. "{}red_ball_model.ckpt.meta".format(model_args[args.name]["path"]["red"])
  24. )
  25. red_sess = tf.compat.v1.Session(graph=red_graph)
  26. red_saver.restore(red_sess, "{}red_ball_model.ckpt".format(model_args[args.name]["path"]["red"]))
  27. logger.info("已加载红球模型!")
  28. blue_graph = tf.compat.v1.Graph()
  29. with blue_graph.as_default():
  30. blue_saver = tf.compat.v1.train.import_meta_graph(
  31. "{}blue_ball_model.ckpt.meta".format(model_args[args.name]["path"]["blue"])
  32. )
  33. blue_sess = tf.compat.v1.Session(graph=blue_graph)
  34. blue_saver.restore(blue_sess, "{}blue_ball_model.ckpt".format(model_args[args.name]["path"]["blue"]))
  35. logger.info("已加载蓝球模型!")
  36. # 加载关键节点名
  37. with open("{}/{}/{}".format(model_path, args.name, pred_key_name)) as f:
  38. pred_key_d = json.load(f)
  39. current_number = get_current_number(args.name)
  40. logger.info("【{}】最近一期:{}".format(name_path[args.name]["name"], current_number))
  41. else:
  42. red_graph = tf.compat.v1.Graph()
  43. with red_graph.as_default():
  44. red_saver = tf.compat.v1.train.import_meta_graph(
  45. "{}red_ball_model.ckpt.meta".format(model_args[args.name]["path"]["red"])
  46. )
  47. red_sess = tf.compat.v1.Session(graph=red_graph)
  48. red_saver.restore(red_sess, "{}red_ball_model.ckpt".format(model_args[args.name]["path"]["red"]))
  49. logger.info("已加载红球模型!")
  50. blue_graph = tf.compat.v1.Graph()
  51. with blue_graph.as_default():
  52. blue_saver = tf.compat.v1.train.import_meta_graph(
  53. "{}blue_ball_model.ckpt.meta".format(model_args[args.name]["path"]["blue"])
  54. )
  55. blue_sess = tf.compat.v1.Session(graph=blue_graph)
  56. blue_saver.restore(blue_sess, "{}blue_ball_model.ckpt".format(model_args[args.name]["path"]["blue"]))
  57. logger.info("已加载蓝球模型!")
  58. # 加载关键节点名
  59. with open("{}/{}/{}".format(model_path,args.name , pred_key_name)) as f:
  60. pred_key_d = json.load(f)
  61. current_number = get_current_number(args.name)
  62. logger.info("【{}】最近一期:{}".format(name_path[args.name]["name"], current_number))
  63. def get_year():
  64. """ 截取年份
  65. eg:2020-->20, 2021-->21
  66. :return:
  67. """
  68. return int(str(datetime.datetime.now().year)[-2:])
  69. def try_error(mode, name, predict_features, windows_size):
  70. """ 处理异常
  71. """
  72. if mode:
  73. return predict_features
  74. else:
  75. if len(predict_features) != windows_size:
  76. logger.warning("期号出现跳期,期号不连续!开始查找最近上一期期号!本期预测时间较久!")
  77. last_current_year = (get_year() - 1) * 1000
  78. max_times = 160
  79. while len(predict_features) != 3:
  80. predict_features = spider(name, last_current_year + max_times, get_current_number(name), "predict")[[x[0] for x in ball_name]]
  81. time.sleep(np.random.random(1).tolist()[0])
  82. max_times -= 1
  83. return predict_features
  84. return predict_features
  85. def get_red_ball_predict_result(predict_features, sequence_len, windows_size):
  86. """ 获取红球预测结果
  87. """
  88. name_list = [(ball_name[0], i + 1) for i in range(sequence_len)]
  89. data = predict_features[["{}_{}".format(name[0], i) for name, i in name_list]].values.astype(int) - 1
  90. with red_graph.as_default():
  91. reverse_sequence = tf.compat.v1.get_default_graph().get_tensor_by_name(pred_key_d[ball_name[0][0]])
  92. pred = red_sess.run(reverse_sequence, feed_dict={
  93. "inputs:0": data.reshape(1, windows_size, sequence_len),
  94. "sequence_length:0": np.array([sequence_len] * 1)
  95. })
  96. return pred, name_list
  97. def get_blue_ball_predict_result(name, predict_features, sequence_len, windows_size):
  98. """ 获取蓝球预测结果
  99. """
  100. if name == "ssq":
  101. data = predict_features[[ball_name[1][0]]].values.astype(int) - 1
  102. with blue_graph.as_default():
  103. softmax = tf.compat.v1.get_default_graph().get_tensor_by_name(pred_key_d[ball_name[1][0]])
  104. pred = blue_sess.run(softmax, feed_dict={
  105. "inputs:0": data.reshape(1, windows_size)
  106. })
  107. return pred
  108. else:
  109. name_list = [(ball_name[1], i + 1) for i in range(sequence_len)]
  110. data = predict_features[["{}_{}".format(name[0], i) for name, i in name_list]].values.astype(int) - 1
  111. with blue_graph.as_default():
  112. reverse_sequence = tf.compat.v1.get_default_graph().get_tensor_by_name(pred_key_d[ball_name[1][0]])
  113. pred = blue_sess.run(reverse_sequence, feed_dict={
  114. "inputs:0": data.reshape(1, windows_size, sequence_len),
  115. "sequence_length:0": np.array([sequence_len] * 1)
  116. })
  117. return pred, name_list
  118. def get_final_result(name, predict_features, mode=0):
  119. """" 最终预测函数
  120. """
  121. m_args = model_args[name]["model_args"]
  122. if name == "ssq":
  123. red_pred, red_name_list = get_red_ball_predict_result(predict_features, m_args["sequence_len"], m_args["windows_size"])
  124. blue_pred = get_blue_ball_predict_result(name, predict_features, 0, m_args["windows_size"])
  125. ball_name_list = ["{}_{}".format(name[mode], i) for name, i in red_name_list] + [ball_name[1][mode]]
  126. pred_result_list = red_pred[0].tolist() + blue_pred.tolist()
  127. return {
  128. b_name: int(res) + 1 for b_name, res in zip(ball_name_list, pred_result_list)
  129. }
  130. else:
  131. red_pred, red_name_list = get_red_ball_predict_result(predict_features, m_args["red_sequence_len"], m_args["windows_size"])
  132. blue_pred, blue_name_list = get_blue_ball_predict_result(name, predict_features, m_args["blue_sequence_len"], m_args["windows_size"])
  133. ball_name_list = ["{}_{}".format(name[mode], i) for name, i in red_name_list] + ["{}_{}".format(name[mode], i) for name, i in blue_name_list]
  134. pred_result_list = red_pred[0].tolist() + blue_pred[0].tolist()
  135. return {
  136. b_name: int(res) + 1 for b_name, res in zip(ball_name_list, pred_result_list)
  137. }
  138. def run(name):
  139. windows_size = model_args[name]["model_args"]["windows_size"]
  140. diff_number = windows_size - 1
  141. data = spider(name, 1, current_number, "predict")
  142. # print(data)
  143. logger.info("【{}】预测期号:{}".format(name_path[name]["name"], int(current_number) + 1))
  144. predict_features_ = try_error(1, name, data.iloc[:windows_size], windows_size)
  145. logger.info("预测结果:{}".format(get_final_result(name, predict_features_)))
  146. if __name__ == '__main__':
  147. if not args.name:
  148. raise Exception("玩法名称不能为空!")
  149. else:
  150. run(args.name)