commit be71d341c03246fd37cc8033078aece7668f18e9
parent 520eb0476a70dfd35f419437671e854419abb8b6
Author: miksa234 <milutin@popovic.xyz>
Date: Mon, 28 Apr 2025 00:51:43 +0100
add model
Diffstat:
3 files changed, 24 insertions(+), 32 deletions(-)
diff --git a/model_100.pt b/model_100.pt
Binary files differ.
diff --git a/rl_arb/rl_arb/brute_force.py b/rl_arb/rl_arb/brute_force.py
@@ -39,9 +39,6 @@ def brute_force_search_trail(mdp, source, cap):
        - list of values of solutions performance (profit)
    """
    def dfs(state):
-# if len(state) > cap:
-# return
-
        if mdp.check_win(state):
            trails.append(state)
            return
@@ -54,11 +51,12 @@ def brute_force_search_trail(mdp, source, cap):
    trails = []
    state = [(source, source, 0)]
    dfs(state)
-    profits = [0]
+    profits = [0.0]
    for trail in trails:
        profits.append(mdp.calculate_profit(trail, mdp.current_block))
-    return trails, np.max(profits)
+    profits = np.array(profits)
+    return trails, np.max(profits[np.where(profits < 3.0)[0]])
@torch.no_grad()
@@ -75,6 +73,9 @@ def test_model():
    problem.model.eval()
    mcts_parallel = MCTSParallel(problem.mdp, ARGS_TRAINING)
+    mcts_parallel.mdp.data.to(DEVICE)
+    mcts_parallel.mdp.device = DEVICE
+
    if not os.path.exists('./test'):
        os.mkdir('./test')
@@ -82,7 +83,6 @@ def test_model():
    loss = {}
    values = {}
    brute_values = {}
-    np.random.seed(0)
    selected_blocks = np.random.choice(problem.mdp.num_blocks, 100, replace=False)
    start = problem.mdp.start_node
@@ -90,27 +90,20 @@ def test_model():
    for b in selected_blocks:
        problem.mdp.current_block = b
        problem.mcts.mdp.current_block = b
-        st = time.time()
        _, brute_profit = brute_force_search_trail(problem.mdp, start, 10)
-        et = time.time()
-        print(et-st)
-        times.append(et-st)
-        brute_values[b] = np.log(brute_profit)
-    print(np.mean(times))
-    exit()
-
-    for i in range(100):
-# problem.model.load_state_dict(
-# torch.load(
-# f"./model/model_{i}.pt",
-# weights_only = True,
-# map_location=DEVICE
-# ),
-# strict = False
-# )
+        brute_values[b] = brute_profit
+
+    for i in range(20):
+        problem.model.load_state_dict(
+            torch.load(
+                f"../model_100.pt",
+                weights_only = True,
+                map_location=DEVICE
+            ),
+            strict = False
+        )
        logger.info(f"Iterations {i}/100")
-# send_telegram_message(f"Iterations {i}/100")
        loss[i] = {}
        values[i] = {}
@@ -121,7 +114,6 @@ def test_model():
                PMemory(mcts_parallel.mdp, b) for b in selected_blocks[block-step: block]
            ]
-            st = time.time()
            while len(p_memory) > 0:
                states = [mem.state for mem in p_memory]
@@ -154,17 +146,17 @@ def test_model():
                    if is_terminal:
                        cb = mem.current_block
                        values[i][cb] = value
-                        loss[i][cb] = 1-value/brute_values[cb]
-
+                        loss[i][cb] = 1-np.exp(value)/brute_values[cb]
                        del p_memory[m]
-            et = time.time()
-            print((et-st)/25)
-        logger.info(f"AVERAGE loss {np.mean([loss[i][k] for k in loss[i].keys()])}")
+        l = []
+        for it in loss.keys():
+            for b in loss[it].keys():
+                l.append(loss[it][b])
+        logger.info(f"AVERAGE loss {np.mean(l)}")
# send_telegram_message(f"AVERAGE loss {np.mean([loss[i][k] for k in loss[i].keys()])}")
        with open(f'./test/test.pickle', "wb") as f:
            pickle.dump([loss, values, brute_values], f)
-
diff --git a/rl_arb/rl_arb/rlearn.py b/rl_arb/rl_arb/rlearn.py
@@ -153,7 +153,7 @@ class AgentRLearn():
            if block not in self.baseline_tracker:
                self.baseline_tracker[block] = list(vs[idxs])
            else:
-                self.baseline_tracker[block].append(vs[idxs])
+                self.baseline_tracker[block] += list(vs[idxs])
            baseline[idxs] = np.mean(self.baseline_tracker[block])