diff --git a/pyhealth/models/adacare.py b/pyhealth/models/adacare.py index ab8602b77..09d8f3da2 100644 --- a/pyhealth/models/adacare.py +++ b/pyhealth/models/adacare.py @@ -38,7 +38,10 @@ def forward(self, input): zs = torch.sort(input=input, dim=dim, descending=True)[0] range = torch.arange( - start=1, end=number_of_logits + 1, dtype=torch.float32 + start=1, + end=number_of_logits + 1, + dtype=input.dtype, + device=input.device, ).view(1, -1) range = range.expand_as(zs)