algebra_core_derive/
lib.rs

1#![deny(missing_docs)]
2//! # algebra-core-derive — procedural macros for algebraic traits
3//!
4//! **Part of the [postbox workspace](../index.html)**
5//!
6//! This crate provides **derive macros** for the `algebra-core` library,
7//! enabling boilerplate-free implementations of algebraic traits for
8//! product types (structs with named or unnamed fields).
9//!
10//! ## Supported derives
11//!
12//! ### Semigroup hierarchy
13//! - **`#[derive(Semigroup)]`** — implements `combine` by combining each field
14//! - **`#[derive(Monoid)]`** — implements `empty()` by calling `empty()` on each field
15//! - **`#[derive(CommutativeMonoid)]`** — marker trait requiring `Monoid`
16//! - **`#[derive(Group)]`** — implements `inverse` by inverting each field
17//! - **`#[derive(AbelianGroup)]`** — marker trait requiring `Group + CommutativeMonoid`
18//!
19//! ### Join-semilattice hierarchy
20//! - **`#[derive(JoinSemilattice)]`** — implements `join` by joining each field
21//! - **`#[derive(BoundedJoinSemilattice)]`** — implements `bottom()` by calling `bottom()` on each field
22//!
23//! ### Meet-semilattice hierarchy
24//! - **`#[derive(MeetSemilattice)]`** — implements `meet` by meeting each field
25//! - **`#[derive(BoundedMeetSemilattice)]`** — implements `top()` by calling `top()` on each field
26//!
27//! ## Usage
28//!
29//! These macros are re-exported through `algebra-core` when the `derive` feature is enabled:
30//!
31//! ```ignore
32//! use algebra_core::{Semigroup, Monoid, JoinSemilattice, BoundedJoinSemilattice};
33//!
34//! #[derive(Clone, PartialEq, Eq, Debug)]
35//! #[derive(Semigroup, Monoid, JoinSemilattice, BoundedJoinSemilattice)]
36//! struct MyLattice {
37//!     counter: algebra_core::Max<i32>,
38//!     tags: std::collections::HashSet<String>,
39//! }
40//! ```
41//!
42//! Each derive macro generates efficient componentwise implementations
43//! following standard product algebra semantics.
44use proc_macro::TokenStream;
45use quote::quote;
46use syn::parse_macro_input;
47use syn::parse_quote;
48use syn::Data;
49use syn::DeriveInput;
50use syn::Fields;
51use syn::Index;
52
53/// Represents how to access a field (by name or by index).
54enum FieldAccessor {
55    Named(syn::Ident),
56    Unnamed(Index),
57}
58
59impl quote::ToTokens for FieldAccessor {
60    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
61        match self {
62            FieldAccessor::Named(ident) => ident.to_tokens(tokens),
63            FieldAccessor::Unnamed(index) => index.to_tokens(tokens),
64        }
65    }
66}
67
68/// Extract fields from a struct (named or unnamed) and return field accessors and types.
69fn get_fields(input: &DeriveInput) -> Result<(Vec<FieldAccessor>, Vec<&syn::Type>), TokenStream> {
70    let fields = match &input.data {
71        Data::Struct(s) => &s.fields,
72        _ => {
73            let msg = "derive macros are only supported on structs";
74            return Err(syn::Error::new_spanned(&input.ident, msg)
75                .to_compile_error()
76                .into());
77        }
78    };
79
80    let (accessors, types): (Vec<_>, Vec<_>) = match fields {
81        Fields::Named(named) => named
82            .named
83            .iter()
84            .map(|f| {
85                let accessor = FieldAccessor::Named(f.ident.clone().expect("named field"));
86                (accessor, &f.ty)
87            })
88            .unzip(),
89        Fields::Unnamed(unnamed) => unnamed
90            .unnamed
91            .iter()
92            .enumerate()
93            .map(|(i, f)| {
94                let accessor = FieldAccessor::Unnamed(Index::from(i));
95                (accessor, &f.ty)
96            })
97            .unzip(),
98        Fields::Unit => (Vec::new(), Vec::new()),
99    };
100
101    Ok((accessors, types))
102}
103
104/// Helper to generate struct construction syntax.
105/// For named fields: `Name { field1: val1, field2: val2 }`
106/// For tuple fields: `Name(val0, val1)`
107/// For unit: `Name`
108fn construct_struct(
109    name: &syn::Ident,
110    fields: &Fields,
111    values: &[proc_macro2::TokenStream],
112) -> proc_macro2::TokenStream {
113    match fields {
114        Fields::Named(named) => {
115            let field_names = named.named.iter().map(|f| &f.ident);
116            quote! { #name { #( #field_names: #values ),* } }
117        }
118        Fields::Unnamed(_) => {
119            quote! { #name( #( #values ),* ) }
120        }
121        Fields::Unit => {
122            quote! { #name }
123        }
124    }
125}
126
127/// Derive macro for [`Semigroup`](https://docs.rs/algebra-core/latest/algebra_core/trait.Semigroup.html).
128///
129/// Implements `Semigroup` for a struct by combining each field componentwise.
130///
131/// # Example
132///
133/// ```ignore
134/// #[derive(Semigroup)]
135/// struct Foo {
136///     a: i32,  // i32: Semigroup (addition)
137///     b: String,  // String: Semigroup (concatenation)
138/// }
139/// ```
140#[proc_macro_derive(Semigroup)]
141pub fn derive_semigroup(input: TokenStream) -> TokenStream {
142    let input = parse_macro_input!(input as DeriveInput);
143    let name = &input.ident;
144
145    let (field_accessors, field_types) = match get_fields(&input) {
146        Ok(f) => f,
147        Err(ts) => return ts,
148    };
149
150    let mut generics = input.generics.clone();
151    {
152        let where_clause = generics.make_where_clause();
153        for ty in &field_types {
154            where_clause
155                .predicates
156                .push(parse_quote!(#ty: ::algebra_core::Semigroup));
157        }
158    }
159
160    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
161
162    // Generate combine expressions for each field
163    let combine_exprs: Vec<_> = field_accessors
164        .iter()
165        .map(|accessor| {
166            quote! { ::algebra_core::Semigroup::combine(&self.#accessor, &other.#accessor) }
167        })
168        .collect();
169
170    let fields = match &input.data {
171        Data::Struct(s) => &s.fields,
172        _ => unreachable!(),
173    };
174    let construction = construct_struct(name, fields, &combine_exprs);
175
176    let expanded = quote! {
177        impl #impl_generics ::algebra_core::Semigroup for #name #ty_generics
178        #where_clause
179        {
180            fn combine(&self, other: &Self) -> Self {
181                #construction
182            }
183        }
184    };
185
186    TokenStream::from(expanded)
187}
188
189/// Derive macro for [`Monoid`](https://docs.rs/algebra-core/latest/algebra_core/trait.Monoid.html).
190///
191/// Implements `Monoid` for a struct by constructing `empty()` from each field's empty.
192///
193/// # Example
194///
195/// ```ignore
196/// #[derive(Semigroup, Monoid)]
197/// struct Foo {
198///     a: i32,  // i32: Monoid (empty = 0)
199///     b: Vec<String>,  // Vec: Monoid (empty = [])
200/// }
201/// ```
202#[proc_macro_derive(Monoid)]
203pub fn derive_monoid(input: TokenStream) -> TokenStream {
204    let input = parse_macro_input!(input as DeriveInput);
205    let name = &input.ident;
206
207    let (_field_accessors, field_types) = match get_fields(&input) {
208        Ok(f) => f,
209        Err(ts) => return ts,
210    };
211
212    let mut generics = input.generics.clone();
213    {
214        let where_clause = generics.make_where_clause();
215        for ty in &field_types {
216            where_clause
217                .predicates
218                .push(parse_quote!(#ty: ::algebra_core::Monoid));
219        }
220    }
221
222    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
223
224    // Generate empty expressions for each field
225    let empty_exprs: Vec<_> = field_types
226        .iter()
227        .map(|_ty| {
228            quote! { ::algebra_core::Monoid::empty() }
229        })
230        .collect();
231
232    let fields = match &input.data {
233        Data::Struct(s) => &s.fields,
234        _ => unreachable!(),
235    };
236    let construction = construct_struct(name, fields, &empty_exprs);
237
238    let expanded = quote! {
239        impl #impl_generics ::algebra_core::Monoid for #name #ty_generics
240        #where_clause
241        {
242            fn empty() -> Self {
243                #construction
244            }
245        }
246    };
247
248    TokenStream::from(expanded)
249}
250
251/// Derive macro for [`CommutativeMonoid`](https://docs.rs/algebra-core/latest/algebra_core/trait.CommutativeMonoid.html).
252///
253/// Marker trait indicating that `combine` is commutative.
254/// Requires each field to implement `CommutativeMonoid`.
255///
256/// # Example
257///
258/// ```ignore
259/// #[derive(Semigroup, Monoid, CommutativeMonoid)]
260/// struct Foo {
261///     a: i32,  // addition is commutative
262///     b: std::collections::HashSet<String>,  // union is commutative
263/// }
264/// ```
265#[proc_macro_derive(CommutativeMonoid)]
266pub fn derive_commutative_monoid(input: TokenStream) -> TokenStream {
267    let input = parse_macro_input!(input as DeriveInput);
268    let name = &input.ident;
269
270    let (_field_accessors, field_types) = match get_fields(&input) {
271        Ok(f) => f,
272        Err(ts) => return ts,
273    };
274
275    let mut generics = input.generics.clone();
276    {
277        let where_clause = generics.make_where_clause();
278        for ty in &field_types {
279            where_clause
280                .predicates
281                .push(parse_quote!(#ty: ::algebra_core::CommutativeMonoid));
282        }
283    }
284
285    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
286
287    let expanded = quote! {
288        impl #impl_generics ::algebra_core::CommutativeMonoid for #name #ty_generics
289        #where_clause
290        {}
291    };
292
293    TokenStream::from(expanded)
294}
295
296/// Derive macro for [`Group`](https://docs.rs/algebra-core/latest/algebra_core/trait.Group.html).
297///
298/// Implements `Group` for a struct by inverting each field componentwise.
299///
300/// # Example
301///
302/// ```ignore
303/// #[derive(Semigroup, Monoid, Group)]
304/// struct Foo {
305///     a: i32,  // inverse = negation
306///     b: MyGroup,  // custom group
307/// }
308/// ```
309#[proc_macro_derive(Group)]
310pub fn derive_group(input: TokenStream) -> TokenStream {
311    let input = parse_macro_input!(input as DeriveInput);
312    let name = &input.ident;
313
314    let (field_accessors, field_types) = match get_fields(&input) {
315        Ok(f) => f,
316        Err(ts) => return ts,
317    };
318
319    let mut generics = input.generics.clone();
320    {
321        let where_clause = generics.make_where_clause();
322        for ty in &field_types {
323            where_clause
324                .predicates
325                .push(parse_quote!(#ty: ::algebra_core::Group));
326        }
327    }
328
329    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
330
331    // Generate inverse expressions for each field
332    let inverse_exprs: Vec<_> = field_accessors
333        .iter()
334        .map(|accessor| {
335            quote! { ::algebra_core::Group::inverse(&self.#accessor) }
336        })
337        .collect();
338
339    let fields = match &input.data {
340        Data::Struct(s) => &s.fields,
341        _ => unreachable!(),
342    };
343    let construction = construct_struct(name, fields, &inverse_exprs);
344
345    let expanded = quote! {
346        impl #impl_generics ::algebra_core::Group for #name #ty_generics
347        #where_clause
348        {
349            fn inverse(&self) -> Self {
350                #construction
351            }
352        }
353    };
354
355    TokenStream::from(expanded)
356}
357
358/// Derive macro for [`AbelianGroup`](https://docs.rs/algebra-core/latest/algebra_core/trait.AbelianGroup.html).
359///
360/// Marker trait indicating a commutative group.
361/// Requires each field to implement `AbelianGroup`.
362///
363/// # Example
364///
365/// ```ignore
366/// #[derive(Semigroup, Monoid, CommutativeMonoid, Group, AbelianGroup)]
367/// struct Foo {
368///     a: i32,  // (Z, +) is abelian
369///     b: MyAbelianGroup,
370/// }
371/// ```
372#[proc_macro_derive(AbelianGroup)]
373pub fn derive_abelian_group(input: TokenStream) -> TokenStream {
374    let input = parse_macro_input!(input as DeriveInput);
375    let name = &input.ident;
376
377    let (_field_accessors, field_types) = match get_fields(&input) {
378        Ok(f) => f,
379        Err(ts) => return ts,
380    };
381
382    let mut generics = input.generics.clone();
383    {
384        let where_clause = generics.make_where_clause();
385        for ty in &field_types {
386            where_clause
387                .predicates
388                .push(parse_quote!(#ty: ::algebra_core::AbelianGroup));
389        }
390    }
391
392    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
393
394    let expanded = quote! {
395        impl #impl_generics ::algebra_core::AbelianGroup for #name #ty_generics
396        #where_clause
397        {}
398    };
399
400    TokenStream::from(expanded)
401}
402
403/// Derive macro for [`JoinSemilattice`](https://docs.rs/algebra-core/latest/algebra_core/trait.JoinSemilattice.html).
404///
405/// Implements `JoinSemilattice` for a struct by joining each field componentwise.
406///
407/// # Example
408///
409/// ```ignore
410/// #[derive(JoinSemilattice)]
411/// struct Foo {
412///     counter: Max<i32>,  // join = max
413///     tags: HashSet<String>,  // join = union
414/// }
415/// ```
416#[proc_macro_derive(JoinSemilattice)]
417pub fn derive_join_semilattice(input: TokenStream) -> TokenStream {
418    let input = parse_macro_input!(input as DeriveInput);
419    let name = &input.ident;
420
421    let (field_accessors, field_types) = match get_fields(&input) {
422        Ok(f) => f,
423        Err(ts) => return ts,
424    };
425
426    let mut generics = input.generics.clone();
427    {
428        let where_clause = generics.make_where_clause();
429        for ty in &field_types {
430            where_clause
431                .predicates
432                .push(parse_quote!(#ty: ::algebra_core::JoinSemilattice));
433        }
434    }
435
436    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
437
438    // Generate join expressions for each field
439    let join_exprs: Vec<_> = field_accessors
440        .iter()
441        .map(|accessor| {
442            quote! { ::algebra_core::JoinSemilattice::join(&self.#accessor, &other.#accessor) }
443        })
444        .collect();
445
446    let fields = match &input.data {
447        Data::Struct(s) => &s.fields,
448        _ => unreachable!(),
449    };
450    let construction = construct_struct(name, fields, &join_exprs);
451
452    let expanded = quote! {
453        impl #impl_generics ::algebra_core::JoinSemilattice for #name #ty_generics
454        #where_clause
455        {
456            fn join(&self, other: &Self) -> Self {
457                #construction
458            }
459        }
460    };
461
462    TokenStream::from(expanded)
463}
464
465/// Derive macro for [`BoundedJoinSemilattice`](https://docs.rs/algebra-core/latest/algebra_core/trait.BoundedJoinSemilattice.html).
466///
467/// Implements `BoundedJoinSemilattice` for a struct by constructing `bottom()`
468/// from each field's bottom element.
469///
470/// # Example
471///
472/// ```ignore
473/// #[derive(JoinSemilattice, BoundedJoinSemilattice)]
474/// struct Foo {
475///     counter: Max<i32>,  // bottom = i32::MIN
476///     tags: HashSet<String>,  // bottom = ∅
477/// }
478/// ```
479#[proc_macro_derive(BoundedJoinSemilattice)]
480pub fn derive_bounded_join_semilattice(input: TokenStream) -> TokenStream {
481    let input = parse_macro_input!(input as DeriveInput);
482    let name = &input.ident;
483
484    let (_field_accessors, field_types) = match get_fields(&input) {
485        Ok(f) => f,
486        Err(ts) => return ts,
487    };
488
489    let mut generics = input.generics.clone();
490    {
491        let where_clause = generics.make_where_clause();
492        for ty in &field_types {
493            where_clause
494                .predicates
495                .push(parse_quote!(#ty: ::algebra_core::BoundedJoinSemilattice));
496        }
497    }
498
499    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
500
501    // Generate bottom expressions for each field
502    let bottom_exprs: Vec<_> = field_types
503        .iter()
504        .map(|_ty| {
505            quote! { ::algebra_core::BoundedJoinSemilattice::bottom() }
506        })
507        .collect();
508
509    let fields = match &input.data {
510        Data::Struct(s) => &s.fields,
511        _ => unreachable!(),
512    };
513    let construction = construct_struct(name, fields, &bottom_exprs);
514
515    let expanded = quote! {
516        impl #impl_generics ::algebra_core::BoundedJoinSemilattice for #name #ty_generics
517        #where_clause
518        {
519            fn bottom() -> Self {
520                #construction
521            }
522        }
523    };
524
525    TokenStream::from(expanded)
526}
527
528/// Derive macro for [`MeetSemilattice`](https://docs.rs/algebra-core/latest/algebra_core/trait.MeetSemilattice.html).
529///
530/// Implements `MeetSemilattice` for a struct by meeting each field componentwise.
531///
532/// # Example
533///
534/// ```ignore
535/// #[derive(MeetSemilattice)]
536/// struct Foo {
537///     counter: Min<i32>,  // meet = min
538///     tags: HashSet<String>,  // meet = intersection
539/// }
540/// ```
541#[proc_macro_derive(MeetSemilattice)]
542pub fn derive_meet_semilattice(input: TokenStream) -> TokenStream {
543    let input = parse_macro_input!(input as DeriveInput);
544    let name = &input.ident;
545
546    let (field_accessors, field_types) = match get_fields(&input) {
547        Ok(f) => f,
548        Err(ts) => return ts,
549    };
550
551    let mut generics = input.generics.clone();
552    {
553        let where_clause = generics.make_where_clause();
554        for ty in &field_types {
555            where_clause
556                .predicates
557                .push(parse_quote!(#ty: ::algebra_core::MeetSemilattice));
558        }
559    }
560
561    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
562
563    // Generate meet expressions for each field
564    let meet_exprs: Vec<_> = field_accessors
565        .iter()
566        .map(|accessor| {
567            quote! { ::algebra_core::MeetSemilattice::meet(&self.#accessor, &other.#accessor) }
568        })
569        .collect();
570
571    let fields = match &input.data {
572        Data::Struct(s) => &s.fields,
573        _ => unreachable!(),
574    };
575    let construction = construct_struct(name, fields, &meet_exprs);
576
577    let expanded = quote! {
578        impl #impl_generics ::algebra_core::MeetSemilattice for #name #ty_generics
579        #where_clause
580        {
581            fn meet(&self, other: &Self) -> Self {
582                #construction
583            }
584        }
585    };
586
587    TokenStream::from(expanded)
588}
589
590/// Derive macro for [`BoundedMeetSemilattice`](https://docs.rs/algebra-core/latest/algebra_core/trait.BoundedMeetSemilattice.html).
591///
592/// Implements `BoundedMeetSemilattice` for a struct by constructing `top()`
593/// from each field's top element.
594///
595/// # Example
596///
597/// ```ignore
598/// #[derive(MeetSemilattice, BoundedMeetSemilattice)]
599/// struct Foo {
600///     counter: Min<i32>,  // top = i32::MAX
601///     flag: bool,  // top = true
602/// }
603/// ```
604#[proc_macro_derive(BoundedMeetSemilattice)]
605pub fn derive_bounded_meet_semilattice(input: TokenStream) -> TokenStream {
606    let input = parse_macro_input!(input as DeriveInput);
607    let name = &input.ident;
608
609    let (_field_accessors, field_types) = match get_fields(&input) {
610        Ok(f) => f,
611        Err(ts) => return ts,
612    };
613
614    let mut generics = input.generics.clone();
615    {
616        let where_clause = generics.make_where_clause();
617        for ty in &field_types {
618            where_clause
619                .predicates
620                .push(parse_quote!(#ty: ::algebra_core::BoundedMeetSemilattice));
621        }
622    }
623
624    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
625
626    // Generate top expressions for each field
627    let top_exprs: Vec<_> = field_types
628        .iter()
629        .map(|_ty| {
630            quote! { ::algebra_core::BoundedMeetSemilattice::top() }
631        })
632        .collect();
633
634    let fields = match &input.data {
635        Data::Struct(s) => &s.fields,
636        _ => unreachable!(),
637    };
638    let construction = construct_struct(name, fields, &top_exprs);
639
640    let expanded = quote! {
641        impl #impl_generics ::algebra_core::BoundedMeetSemilattice for #name #ty_generics
642        #where_clause
643        {
644            fn top() -> Self {
645                #construction
646            }
647        }
648    };
649
650    TokenStream::from(expanded)
651}