-- |
-- Module      : Tile.Execution
-- Description : Pure execution semantics for schedules.
--
-- These functions interpret divergence schedules as pure results. A
-- divergence schedule has edges directed from root toward leaves; it is
-- the form produced by 'Tile.Routing.buildSchedule'. Each collective
-- derives convergence internally where the operation requires it.
--
-- These functions provide a denotational reference for the concurrent
-- interpreters in "Tile.Execution.Concurrent".
module Tile.Execution
  ( -- * Pure execution semantics
    broadcastResult,
    reduceResult,
    gatherResult,
    scatterResult,
    allReduceResult,
  )
where

import Data.Map.Strict qualified as Map
import Tile.Schedule
import Tile.Tree

-- | Deliver a payload from the root to every reachable member.
--
-- Takes a divergence schedule (edges directed from the root toward the
-- leaves). Every member of the tree rooted at @root@ is mapped to
-- @payload@. That includes the root itself, which holds the payload from
-- the start. The key set is the set of reachable members, not just the
-- receivers of schedule steps.
broadcastResult :: (Ord m) => Schedule m -> m -> p -> Map.Map m p
broadcastResult schedule root payload =
  Map.fromSet (const payload) members
  where
    RoutedTree tree = scheduleTree root schedule
    -- Every label of the routed tree; each one receives the same payload.
    members = treeLabels tree

-- | Combine all member values into a single result at the root.
--
-- Takes a divergence schedule and folds over the tree rooted at @root@.
-- At each member, the member's own value seeds a strict left fold over
-- the reduced results of its subtrees, taken in the order the children
-- appear in the tree:
--
-- @
-- combine (combine local child1) child2 ...
-- @
--
-- Precondition: the value map contains every member reachable from
-- the root. A missing member makes 'Map.!' throw.
reduceResult ::
  forall m v.
  (Ord m) =>
  Schedule m ->
  m ->
  Map.Map m v ->
  (v -> v -> v) ->
  v
reduceResult schedule root values combine = subtotal tree
  where
    RoutedTree tree = scheduleTree root schedule

    -- Reduce one subtree. Seed with the member's own value, then fold in
    -- each child's reduced subtree from left to right.
    subtotal :: Tree m -> v
    subtotal (Tree member kids) =
      foldl' combine (values Map.! member) [subtotal kid | kid <- kids]

-- | Collect all member values at the root as a list of pairs.
--
-- Takes a divergence schedule. Values are returned in preorder over
-- the divergence tree: a member comes first, followed by the entries of
-- each of its subtrees in child order. The list is produced lazily and
-- in time linear in the number of reachable members.
--
-- Precondition: the value map contains every member reachable from
-- the root. A missing member makes 'Map.!' throw when its entry is
-- forced.
gatherResult ::
  forall m v.
  (Ord m) =>
  Schedule m ->
  m ->
  Map.Map m v ->
  [(m, v)]
gatherResult schedule root values =
  let RoutedTree tree = scheduleTree root schedule
   in go tree []
  where
    -- Prepend the preorder listing of a subtree onto @rest@. Threading the
    -- tail through 'foldr' avoids re-copying every subtree's list at each
    -- ancestor, which @concatMap@ does. That copying costs O(n * depth)
    -- and is quadratic for chain-shaped schedules.
    go :: Tree m -> [(m, v)] -> [(m, v)]
    go (Tree member kids) rest =
      (member, values Map.! member) : foldr go rest kids

-- | Deliver destination-specific payloads from the root.
--
-- Takes a divergence schedule. Only payloads whose destinations are
-- reachable from the root are delivered. Payloads addressed to
-- unreachable members are dropped. If a destination is listed more than
-- once, the last payload for it wins, as with 'Map.fromList'.
scatterResult :: (Ord m) => Schedule m -> m -> [(m, p)] -> Map.Map m p
scatterResult schedule root payloads =
  Map.restrictKeys addressed (treeLabels tree)
  where
    RoutedTree tree = scheduleTree root schedule
    -- Payloads keyed by destination; later entries overwrite earlier ones.
    addressed = Map.fromList payloads

-- | Combine all member values and deliver the result to every member.
--
-- Takes a divergence schedule. The member values are first reduced to a
-- single value with 'reduceResult'. That value is then broadcast with
-- 'broadcastResult', so every reachable member ends up mapped to the
-- value obtained by combining all member values with @combine@.
--
-- Precondition: the value map contains every member reachable from
-- the root.
allReduceResult :: (Ord m) => Schedule m -> m -> Map.Map m v -> (v -> v -> v) -> Map.Map m v
allReduceResult schedule root values combine =
  broadcastResult schedule root (reduceResult schedule root values combine)