优化执行效率

This commit is contained in:
zhiyang7 2021-12-15 15:29:29 +08:00
parent 3b6ffb2753
commit 601dafb008
1 changed files with 3 additions and 3 deletions

View File

@ -448,7 +448,7 @@ class General_Model:
action = torch.randint(values.shape[0], (1,))[0] action = torch.randint(values.shape[0], (1,))[0]
else: else:
action = torch.argmax(values,dim=0)[0] action = torch.argmax(values,dim=0)[0]
return dict(action=action, max_value=torch.max(values)) return dict(action=action)
def share_memory(self): def share_memory(self):
self.models['landlord'].share_memory() self.models['landlord'].share_memory()
@ -516,7 +516,7 @@ class OldModel:
action = torch.randint(values.shape[0], (1,))[0] action = torch.randint(values.shape[0], (1,))[0]
else: else:
action = torch.argmax(values,dim=0)[0] action = torch.argmax(values,dim=0)[0]
return dict(action=action, max_value=torch.max(values)) return dict(action=action)
def share_memory(self): def share_memory(self):
self.models['landlord'].share_memory() self.models['landlord'].share_memory()
@ -587,7 +587,7 @@ class Model:
action = torch.randint(values.shape[0], (1,))[0] action = torch.randint(values.shape[0], (1,))[0]
else: else:
action = torch.argmax(values,dim=0)[0] action = torch.argmax(values,dim=0)[0]
return dict(action=action, max_value=torch.max(values)) return dict(action=action)
def share_memory(self): def share_memory(self):
self.models['landlord'].share_memory() self.models['landlord'].share_memory()