From fd7fc4b702da099038098c1badd22413dae776a4 Mon Sep 17 00:00:00 2001 From: Yongda Fan Date: Sat, 25 Apr 2026 23:34:58 -0500 Subject: [PATCH] Fix sparsemax in AdaCare --- pyhealth/models/adacare.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pyhealth/models/adacare.py b/pyhealth/models/adacare.py index ab8602b77..09d8f3da2 100644 --- a/pyhealth/models/adacare.py +++ b/pyhealth/models/adacare.py @@ -38,7 +38,10 @@ def forward(self, input): zs = torch.sort(input=input, dim=dim, descending=True)[0] range = torch.arange( - start=1, end=number_of_logits + 1, dtype=torch.float32 + start=1, + end=number_of_logits + 1, + dtype=input.dtype, + device=input.device, ).view(1, -1) range = range.expand_as(zs)