diff --git a/douzero/dmc/models.py b/douzero/dmc/models.py index a92727e..1aab661 100644 --- a/douzero/dmc/models.py +++ b/douzero/dmc/models.py @@ -128,7 +128,7 @@ class LandlordLstmModelLegacy(nn.Module): return { 'args': ( torch.randn(1, 5, 432, requires_grad=True, device=device), - torch.randn(1, 887, requires_grad=True, device=device), + torch.randn(1, 860, requires_grad=True, device=device), ), 'input_names': ['z_batch','x_batch'], 'output_names': ['values'], @@ -177,7 +177,7 @@ class FarmerLstmModelLegacy(nn.Module): return { 'args': ( torch.randn(1, 5, 432, requires_grad=True, device=device), - torch.randn(1, 1219, requires_grad=True, device=device), + torch.randn(1, 1192, requires_grad=True, device=device), ), 'input_names': ['z_batch','x_batch'], 'output_names': ['values'],