{-# LANGUAGE GADTs #-}
module Tile.Collective
(
Collective (..),
interpret,
runCollective,
)
where
import Data.Map.Strict qualified as Map
import Tile.Execution
( allReduceResult,
broadcastResult,
gatherResult,
reduceResult,
scatterResult,
)
import Tile.Execution.Concurrent
( runAllReduce,
runBroadcast,
runGather,
runReduce,
runScatter,
)
import Tile.Schedule (Schedule)
-- | A collective operation over a set of participants, indexed by the type of
-- its result.
--
-- * @m@ — participant identifier; it is the key type of every 'Map.Map' below
--   (both interpreters in this module require @Ord m@).
-- * @a@ — the per-participant value type.
-- * @b@ — the result type obtained from 'interpret' (purely) or
--   'runCollective' (in 'IO'); each constructor fixes it through its GADT
--   return type.
data Collective m a b where
  -- | A single payload; the result is a map keyed by participant.
  Broadcast :: a -> Collective m a (Map.Map m a)
  -- | Per-participant values; the result is again a map keyed by participant.
  Scatter :: Map.Map m a -> Collective m a (Map.Map m a)
  -- | Per-participant values; the result is an association list of
  -- participant/value pairs.
  Gather :: Map.Map m a -> Collective m a [(m, a)]
  -- | Per-participant values plus a binary combining function; the result is
  -- a single value. NOTE(review): the combine order is decided by the
  -- execution layer, so the function presumably needs to be associative —
  -- confirm.
  Reduce :: Map.Map m a -> (a -> a -> a) -> Collective m a a
  -- | Like 'Reduce', but the result is a map keyed by participant.
  AllReduce :: Map.Map m a -> (a -> a -> a) -> Collective m a (Map.Map m a)
-- | Evaluate a 'Collective' purely under the given schedule.
--
-- Every constructor is delegated to the matching pure function from
-- "Tile.Execution" ('broadcastResult', 'scatterResult', 'gatherResult',
-- 'reduceResult', 'allReduceResult'). Each is called with the schedule first
-- and the root second. 'Scatter' is the only case that reshapes its input:
-- 'scatterResult' takes an association list, so the map is flattened with
-- 'Map.toList' first.
--
-- Because 'Collective' is a GADT, matching a constructor refines @b@, so each
-- branch returns exactly the result type that constructor declares.
interpret ::
  (Ord m) =>
  -- | Schedule the operation is evaluated under.
  Schedule m ->
  -- | Root participant, forwarded unchanged to every helper.
  -- NOTE(review): presumably the source/destination for rooted collectives —
  -- confirm against "Tile.Execution".
  m ->
  -- | Operation to evaluate.
  Collective m a b ->
  b
interpret schedule root collective =
  case collective of
    Broadcast payload -> broadcastResult schedule root payload
    Scatter vals -> scatterResult schedule root (Map.toList vals)
    Gather vals -> gatherResult schedule root vals
    Reduce vals combine -> reduceResult schedule root vals combine
    AllReduce vals combine -> allReduceResult schedule root vals combine
-- | Execute a 'Collective' in 'IO' under the given schedule.
--
-- This is the effectful counterpart of 'interpret'. Each constructor is
-- delegated to the matching runner from "Tile.Execution.Concurrent"
-- ('runBroadcast', 'runScatter', 'runGather', 'runReduce', 'runAllReduce').
-- 'Scatter' flattens its map with 'Map.toList', because 'runScatter' takes an
-- association list.
--
-- Matching a constructor refines the GADT index @b@, so each branch yields
-- @IO@ of exactly the result type that constructor declares.
runCollective ::
  (Ord m) =>
  -- | Schedule the operation runs under.
  Schedule m ->
  -- | Root participant, forwarded unchanged to every runner.
  -- NOTE(review): presumably the source/destination for rooted collectives —
  -- confirm against "Tile.Execution.Concurrent".
  m ->
  -- | Operation to execute.
  Collective m a b ->
  IO b
runCollective schedule root collective =
  case collective of
    Broadcast payload -> runBroadcast schedule root payload
    -- runScatter, runGather and runReduce take the root as their LAST
    -- argument, while runBroadcast and runAllReduce take it right after the
    -- schedule; keep each call in its runner's own order.
    Scatter vals -> runScatter schedule (Map.toList vals) root
    Gather vals -> runGather schedule vals root
    Reduce vals combine -> runReduce schedule vals combine root
    AllReduce vals combine -> runAllReduce schedule root vals combine