Skip to content

Commit 6975eaf

Browse files
authored
Merge pull request #35 from xkollar/reduce-sort
Remove unnecessary sorts
2 parents 5507e38 + ae46cb8 commit 6975eaf

File tree

3 files changed

+40
-27
lines changed

3 files changed

+40
-27
lines changed

src/Data/DecisionDiagram/BDD.hs

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,6 @@ import Control.Monad.ST
124124
import Control.Monad.ST.Unsafe
125125
import Data.Bits (Bits (shiftL))
126126
import qualified Data.Foldable as Foldable
127-
import Data.Function (on)
128127
import Data.Hashable
129128
import qualified Data.HashMap.Lazy as HashMap
130129
import qualified Data.HashTable.Class as H
@@ -133,7 +132,6 @@ import Data.IntMap (IntMap)
133132
import qualified Data.IntMap as IntMap
134133
import Data.IntSet (IntSet)
135134
import qualified Data.IntSet as IntSet
136-
import Data.List (sortBy)
137135
import Data.Map.Lazy (Map)
138136
import qualified Data.Map.Lazy as Map
139137
import Data.Proxy
@@ -384,7 +382,7 @@ pbAtLeast :: forall a w. (ItemOrder a, Real w) => IntMap w -> w -> BDD a
384382
pbAtLeast xs k0 = unfoldOrd f (0, k0)
385383
where
386384
xs' :: V.Vector (Int, w)
387-
xs' = V.fromList $ sortBy (compareItem (Proxy :: Proxy a) `on` fst) $ IntMap.toList xs
385+
xs' = V.fromList $ mapToList (Proxy :: Proxy a) xs
388386
ys :: V.Vector (w, w)
389387
ys = V.scanr (\(_, w) (lb,ub) -> if w >= 0 then (lb, ub+w) else (lb+w, ub)) (0,0) xs'
390388

@@ -402,7 +400,7 @@ pbAtMost :: forall a w. (ItemOrder a, Real w) => IntMap w -> w -> BDD a
402400
pbAtMost xs k0 = unfoldOrd f (0, k0)
403401
where
404402
xs' :: V.Vector (Int, w)
405-
xs' = V.fromList $ sortBy (compareItem (Proxy :: Proxy a) `on` fst) $ IntMap.toList xs
403+
xs' = V.fromList $ mapToList (Proxy :: Proxy a) xs
406404
ys :: V.Vector (w, w)
407405
ys = V.scanr (\(_, w) (lb,ub) -> if w >= 0 then (lb, ub+w) else (lb+w, ub)) (0,0) xs'
408406

@@ -422,7 +420,7 @@ pbExactly :: forall a w. (ItemOrder a, Real w) => IntMap w -> w -> BDD a
422420
pbExactly xs k0 = unfoldOrd f (0, k0)
423421
where
424422
xs' :: V.Vector (Int, w)
425-
xs' = V.fromList $ sortBy (compareItem (Proxy :: Proxy a) `on` fst) $ IntMap.toList xs
423+
xs' = V.fromList $ mapToList (Proxy :: Proxy a) xs
426424
ys :: V.Vector (w, w)
427425
ys = V.scanr (\(_, w) (lb,ub) -> if w >= 0 then (lb, ub+w) else (lb+w, ub)) (0,0) xs'
428426

@@ -440,7 +438,7 @@ pbExactlyIntegral :: forall a w. (ItemOrder a, Real w, Integral w) => IntMap w -
440438
pbExactlyIntegral xs k0 = unfoldOrd f (0, k0)
441439
where
442440
xs' :: V.Vector (Int, w)
443-
xs' = V.fromList $ sortBy (compareItem (Proxy :: Proxy a) `on` fst) $ IntMap.toList xs
441+
xs' = V.fromList $ mapToList (Proxy :: Proxy a) xs
444442
ys :: V.Vector (w, w)
445443
ys = V.scanr (\(_, w) (lb,ub) -> if w >= 0 then (lb, ub+w) else (lb+w, ub)) (0,0) xs'
446444
ds :: V.Vector w
@@ -534,7 +532,7 @@ forAllSet vars bdd = runST $ do
534532
H.insert h n ret
535533
return ret
536534
f _ a = return a
537-
f (sortBy (compareItem (Proxy :: Proxy a)) (IntSet.toList vars)) bdd
535+
f (setToList (Proxy :: Proxy a) vars) bdd
538536

539537
-- | Existential quantification (∃) over a set of variables
540538
existsSet :: forall a. ItemOrder a => IntSet -> BDD a -> BDD a
@@ -556,7 +554,7 @@ existsSet vars bdd = runST $ do
556554
H.insert h n ret
557555
return ret
558556
f _ a = return a
559-
f (sortBy (compareItem (Proxy :: Proxy a)) (IntSet.toList vars)) bdd
557+
f (setToList (Proxy :: Proxy a) vars) bdd
560558

561559
-- | Unique existential quantification (∃!) over a set of variables
562560
existsUniqueSet :: forall a. ItemOrder a => IntSet -> BDD a -> BDD a
@@ -580,7 +578,7 @@ existsUniqueSet vars bdd = runST $ do
580578
return ret
581579
f (_ : _) _ = return F
582580
f [] a = return a
583-
f (sortBy (compareItem (Proxy :: Proxy a)) (IntSet.toList vars)) bdd
581+
f (setToList (Proxy :: Proxy a) vars) bdd
584582

585583
-- ------------------------------------------------------------------------
586584

@@ -714,7 +712,7 @@ restrictSet val bdd = runST $ do
714712
EQ -> if v then f xs hi else f xs lo
715713
H.insert h n ret
716714
return ret
717-
f (sortBy (compareItem (Proxy :: Proxy a) `on` fst) (IntMap.toList val)) bdd
715+
f (mapToList (Proxy :: Proxy a) val) bdd
718716

719717
-- | Compute generalized cofactor of F with respect to C.
720718
--
@@ -916,7 +914,7 @@ findSatCompleteM xs0 bdd = runST $ do
916914
p <- ps
917915
foldM (\m y -> msum [return (IntMap.insert y v m) | v <- [False, True]]) p ys
918916
_ -> error ("findSatCompleteM: " ++ show x ++ " should not occur")
919-
f (sortBy (compareItem (Proxy :: Proxy a)) (IntSet.toList xs0)) bdd
917+
f (setToList (Proxy :: Proxy a) xs0) bdd
920918

921919
-- | Find one satisfying (complete) assignment over a given set of variables
922920
--
@@ -962,7 +960,7 @@ countSat xs bdd = runST $ do
962960
return n
963961
return $! n `shiftL` length zs
964962
(_, _) -> error ("countSat: " ++ show x ++ " should not occur")
965-
f (sortBy (compareItem (Proxy :: Proxy a)) (IntSet.toList xs)) bdd
963+
f (setToList (Proxy :: Proxy a) xs) bdd
966964

967965
-- | Sample an assignment from uniform distribution over complete satisfiable assignments ('allSatComplete') of the BDD.
968966
--
@@ -1030,7 +1028,7 @@ uniformSatM xs0 bdd0 = func IntMap.empty
10301028
func0 (IntMap.insert x False a) gen
10311029
H.insert h bdd (s, func')
10321030
return (s, func')
1033-
snd <$> f (sortBy (compareItem (Proxy :: Proxy a)) (IntSet.toList xs0)) bdd0
1031+
snd <$> f (setToList (Proxy :: Proxy a) xs0) bdd0
10341032

10351033
-- ------------------------------------------------------------------------
10361034

src/Data/DecisionDiagram/BDD/Internal/ItemOrder.hs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,29 +33,45 @@ module Data.DecisionDiagram.BDD.Internal.ItemOrder
3333
, Level (..)
3434
) where
3535

36+
import Data.Function (on)
3637
import Data.Kind (Type)
38+
import Data.List (sortBy)
3739
import Data.Proxy
3840
import Data.Reflection
3941

42+
import Data.IntMap (IntMap)
43+
import qualified Data.IntMap as IntMap
44+
import Data.IntSet (IntSet)
45+
import qualified Data.IntSet as IntSet
46+
47+
4048
-- ------------------------------------------------------------------------
4149

4250
class ItemOrder (a :: Type) where
4351
compareItem :: proxy a -> Int -> Int -> Ordering
52+
mapToList :: proxy a -> IntMap b -> [(Int,b)]
53+
setToList :: proxy a -> IntSet -> [Int]
4454

4555
data AscOrder
4656

4757
data DescOrder
4858

4959
instance ItemOrder AscOrder where
5060
compareItem _ = compare
61+
mapToList _ = IntMap.toAscList
62+
setToList _ = IntSet.toAscList
5163

5264
instance ItemOrder DescOrder where
5365
compareItem _ = flip compare
66+
mapToList _ = IntMap.toDescList
67+
setToList _ = IntSet.toDescList
5468

5569
data CustomOrder a
5670

5771
instance Reifies s (Int -> Int -> Ordering) => ItemOrder (CustomOrder s) where
5872
compareItem _ = reflect (Proxy :: Proxy s)
73+
mapToList o m = sortBy (compareItem o `on` fst) (IntMap.toList m)
74+
setToList o s = sortBy (compareItem o) (IntSet.toList s)
5975

6076
withAscOrder :: forall r. (Proxy AscOrder -> r) -> r
6177
withAscOrder k = k Proxy

src/Data/DecisionDiagram/ZDD.hs

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -142,17 +142,15 @@ import Control.Monad.Primitive
142142
#endif
143143
import Control.Monad.ST
144144
import qualified Data.Foldable as Foldable
145-
import Data.Function (on)
146145
import Data.Hashable
147146
import Data.HashMap.Lazy (HashMap)
148147
import qualified Data.HashMap.Lazy as HashMap
149148
import qualified Data.HashTable.Class as H
150149
import qualified Data.HashTable.ST.Cuckoo as C
151150
import Data.IntMap (IntMap)
152-
import qualified Data.IntMap as IntMap
153151
import Data.IntSet (IntSet)
154152
import qualified Data.IntSet as IntSet
155-
import Data.List (foldl', sortBy)
153+
import Data.List (foldl')
156154
import Data.Map.Lazy (Map)
157155
import qualified Data.Map.Lazy as Map
158156
import Data.Maybe
@@ -235,7 +233,7 @@ instance ItemOrder a => Exts.IsList (ZDD a) where
235233
fromList = fromListOfSortedList . map f
236234
where
237235
f :: IntSet -> [Int]
238-
f = sortBy (compareItem (Proxy :: Proxy a)) . IntSet.toList
236+
f = setToList (Proxy :: Proxy a)
239237

240238
toList = toListOfIntSets
241239

@@ -283,7 +281,7 @@ singleton xs = insert xs empty
283281

284282
-- | Set of all subsets, i.e. powerset
285283
subsets :: forall a. ItemOrder a => IntSet -> ZDD a
286-
subsets = foldl' f Base . sortBy (flip (compareItem (Proxy :: Proxy a))) . IntSet.toList
284+
subsets = foldl' f Base . reverse . setToList (Proxy :: Proxy a)
287285
where
288286
f zdd x = Branch x zdd zdd
289287

@@ -293,7 +291,7 @@ combinations xs k
293291
| k < 0 = error "Data.DecisionDiagram.ZDD.combinations: negative size"
294292
| otherwise = unfoldOrd f (0, k)
295293
where
296-
table = V.fromList $ sortBy (compareItem (Proxy :: Proxy a)) $ IntSet.toList xs
294+
table = V.fromList $ setToList (Proxy :: Proxy a) xs
297295
n = V.length table
298296

299297
f :: (Int, Int) -> Sig (Int, Int)
@@ -307,7 +305,7 @@ subsetsAtLeast :: forall a w. (ItemOrder a, Real w) => IntMap w -> w -> ZDD a
307305
subsetsAtLeast xs k0 = unfoldOrd f (0, k0)
308306
where
309307
xs' :: V.Vector (Int, w)
310-
xs' = V.fromList $ sortBy (compareItem (Proxy :: Proxy a) `on` fst) $ IntMap.toList xs
308+
xs' = V.fromList $ mapToList (Proxy :: Proxy a) xs
311309
ys :: V.Vector (w, w)
312310
ys = V.scanr (\(_, w) (lb,ub) -> if w >= 0 then (lb, ub+w) else (lb+w, ub)) (0,0) xs'
313311

@@ -326,7 +324,7 @@ subsetsAtMost :: forall a w. (ItemOrder a, Real w) => IntMap w -> w -> ZDD a
326324
subsetsAtMost xs k0 = unfoldOrd f (0, k0)
327325
where
328326
xs' :: V.Vector (Int, w)
329-
xs' = V.fromList $ sortBy (compareItem (Proxy :: Proxy a) `on` fst) $ IntMap.toList xs
327+
xs' = V.fromList $ mapToList (Proxy :: Proxy a) xs
330328
ys :: V.Vector (w, w)
331329
ys = V.scanr (\(_, w) (lb,ub) -> if w >= 0 then (lb, ub+w) else (lb+w, ub)) (0,0) xs'
332330

@@ -349,7 +347,7 @@ subsetsExactly :: forall a w. (ItemOrder a, Real w) => IntMap w -> w -> ZDD a
349347
subsetsExactly xs k0 = unfoldOrd f (0, k0)
350348
where
351349
xs' :: V.Vector (Int, w)
352-
xs' = V.fromList $ sortBy (compareItem (Proxy :: Proxy a) `on` fst) $ IntMap.toList xs
350+
xs' = V.fromList $ mapToList (Proxy :: Proxy a) xs
353351
ys :: V.Vector (w, w)
354352
ys = V.scanr (\(_, w) (lb,ub) -> if w >= 0 then (lb, ub+w) else (lb+w, ub)) (0,0) xs'
355353

@@ -367,7 +365,7 @@ subsetsExactlyIntegral :: forall a w. (ItemOrder a, Real w, Integral w) => IntMa
367365
subsetsExactlyIntegral xs k0 = unfoldOrd f (0, k0)
368366
where
369367
xs' :: V.Vector (Int, w)
370-
xs' = V.fromList $ sortBy (compareItem (Proxy :: Proxy a) `on` fst) $ IntMap.toList xs
368+
xs' = V.fromList $ mapToList (Proxy :: Proxy a) xs
371369
ys :: V.Vector (w, w)
372370
ys = V.scanr (\(_, w) (lb,ub) -> if w >= 0 then (lb, ub+w) else (lb+w, ub)) (0,0) xs'
373371
ds :: V.Vector w
@@ -433,7 +431,7 @@ subset0 var zdd = runST $ do
433431
-- >>> toSetOfIntSets (insert (IntSet.fromList [1,2,3]) (fromListOfIntSets (map IntSet.fromList [[1,3], [2,4]])) :: ZDD AscOrder)
434432
-- fromList [fromList [1,2,3],fromList [1,3],fromList [2,4]]
435433
insert :: forall a. ItemOrder a => IntSet -> ZDD a -> ZDD a
436-
insert xs = f (sortBy (compareItem (Proxy :: Proxy a)) (IntSet.toList xs))
434+
insert xs = f (setToList (Proxy :: Proxy a) xs)
437435
where
438436
f [] (Leaf _) = Base
439437
f [] (Branch top p0 p1) = Branch top (f [] p0) p1
@@ -450,7 +448,7 @@ insert xs = f (sortBy (compareItem (Proxy :: Proxy a)) (IntSet.toList xs))
450448
-- >>> toSetOfIntSets (delete (IntSet.fromList [1,3]) (fromListOfIntSets (map IntSet.fromList [[1,2,3], [1,3], [2,4]])) :: ZDD AscOrder)
451449
-- fromList [fromList [1,2,3],fromList [2,4]]
452450
delete :: forall a. ItemOrder a => IntSet -> ZDD a -> ZDD a
453-
delete xs = f (sortBy (compareItem (Proxy :: Proxy a)) (IntSet.toList xs))
451+
delete xs = f (setToList (Proxy :: Proxy a) xs)
454452
where
455453
f [] (Leaf _) = Empty
456454
f [] (Branch top p0 p1) = Branch top (f [] p0) p1
@@ -731,7 +729,7 @@ minimalHittingSets = minimalHittingSetsToda
731729
member :: forall a. (ItemOrder a) => IntSet -> ZDD a -> Bool
732730
member xs = member' xs'
733731
where
734-
xs' = sortBy (compareItem (Proxy :: Proxy a)) $ IntSet.toList xs
732+
xs' = setToList (Proxy :: Proxy a) xs
735733

736734
member' :: forall a. (ItemOrder a) => [Int] -> ZDD a -> Bool
737735
member' [] Base = True
@@ -808,7 +806,7 @@ fromListOfIntSets :: forall a. ItemOrder a => [IntSet] -> ZDD a
808806
fromListOfIntSets = fromListOfSortedList . map f
809807
where
810808
f :: IntSet -> [Int]
811-
f = sortBy (compareItem (Proxy :: Proxy a)) . IntSet.toList
809+
f = setToList (Proxy :: Proxy a)
812810

813811
-- | Convert the family to a list of 'IntSet'.
814812
toListOfIntSets :: ZDD a -> [IntSet]
@@ -962,6 +960,7 @@ findMinSum weight =
962960
-- \max_{X\in S} \sum_{x\in X} w(x)
963961
-- \]
964962
--
963+
-- >>> import qualified Data.IntMap as IntMap
965964
-- >>> findMaxSum (IntMap.fromList [(1,2),(2,4),(3,-3)] IntMap.!) (fromListOfIntSets (map IntSet.fromList [[1], [2], [3], [1,2,3]]) :: ZDD AscOrder)
966965
-- (4,fromList [2])
967966
findMaxSum :: forall a w. (ItemOrder a, Num w, Ord w, HasCallStack) => (Int -> w) -> ZDD a -> (w, IntSet)

0 commit comments

Comments
 (0)