module Tile.Execution
(
broadcastResult,
reduceResult,
gatherResult,
scatterResult,
allReduceResult,
)
where
import Data.List (foldl')

import Data.Map.Strict qualified as Map
import Data.Set (Set)

import Tile.Schedule
import Tile.Tree
-- | Deliver the same @payload@ to every member of the tree that
-- 'scheduleTree' builds from @schedule@ for @root@.
--
-- The result has exactly one entry per label reported by 'treeLabels' for
-- that tree; members the tree does not contain get no entry.
broadcastResult :: forall m p. (Ord m) => Schedule m -> m -> p -> Map.Map m p
broadcastResult schedule root payload =
  let RoutedTree tree = scheduleTree root schedule
   in Map.fromSet (const payload) (treeLabels tree)

-- | Fold the per-member @values@ up the tree that 'scheduleTree' builds
-- from @schedule@ for @root@, returning the single combined value.
--
-- At each node the node's own value is the initial accumulator, and the
-- reduced values of its child subtrees are folded in left to right:
-- @foldl' combine ownValue (map reduceChild children)@. @combine@ sees
-- its arguments in that order, so a non-commutative @combine@ gives a
-- result that depends on the tree shape.
--
-- Precondition: every member in the tree has an entry in @values@;
-- otherwise 'Map.!' throws.
reduceResult ::
  forall m v.
  (Ord m) =>
  Schedule m ->
  m ->
  Map.Map m v ->
  (v -> v -> v) ->
  v
reduceResult schedule root values combine =
  let RoutedTree tree = scheduleTree root schedule
   in go tree
  where
    -- The explicit 'forall' above brings @m@ and @v@ into scope here;
    -- without it they would be fresh type variables and 'go' would not
    -- type-check against @values@ and @combine@.
    go :: Tree m -> v
    go (Tree member kids) =
      foldl' combine (values Map.! member) (map go kids)

-- | Collect @(member, value)@ pairs for every member of the tree that
-- 'scheduleTree' builds from @schedule@ for @root@, in pre-order: a node
-- comes before its children, and children are visited left to right.
--
-- The list is produced lazily. Precondition: every member in the tree has
-- an entry in @values@; otherwise 'Map.!' throws when that pair's value is
-- forced.
gatherResult :: forall m v. (Ord m) => Schedule m -> m -> Map.Map m v -> [(m, v)]
gatherResult schedule root values =
  let RoutedTree tree = scheduleTree root schedule
   in go tree []
  where
    -- Accumulator-passing pre-order walk: each subtree is consed directly
    -- onto the output of the subtrees that follow it, so no element is
    -- re-copied by (++) once per ancestor level as with 'concatMap'.
    -- Uses the scoped @m@/@v@ from the outer 'forall'.
    go :: Tree m -> [(m, v)] -> [(m, v)]
    go (Tree member kids) rest =
      (member, values Map.! member) : foldr go rest kids

-- | Route each addressed payload in @payloads@ to its member, keeping only
-- members that appear in the tree 'scheduleTree' builds from @schedule@
-- for @root@.
--
-- Payloads addressed to members outside that tree are dropped, and tree
-- members with no payload get no entry. If a member is addressed more than
-- once, the last occurrence wins ('Map.fromList' semantics).
scatterResult :: forall m p. (Ord m) => Schedule m -> m -> [(m, p)] -> Map.Map m p
scatterResult schedule root payloads =
  Map.restrictKeys (Map.fromList payloads) reachable
  where
    RoutedTree tree = scheduleTree root schedule

    -- @m@ is the outer, scoped type variable from the 'forall' above.
    reachable :: Set m
    reachable = treeLabels tree

-- | Reduce @values@ to a single value with 'reduceResult', then give that
-- value to every tree member with 'broadcastResult'.
--
-- Both phases are driven by the same @schedule@ and @root@, so every member
-- in the resulting map holds the same combined value. Inherits the
-- precondition of 'reduceResult': each tree member must have an entry in
-- @values@.
allReduceResult ::
  forall m v.
  (Ord m) =>
  Schedule m ->
  m ->
  Map.Map m v ->
  (v -> v -> v) ->
  Map.Map m v
allReduceResult schedule root values combine =
  broadcastResult schedule root combined
  where
    -- @v@ is the outer, scoped type variable; without the explicit 'forall'
    -- this signature would mean @forall v. v@ and fail to type-check.
    combined :: v
    combined = reduceResult schedule root values combine