1#![deny(missing_docs)]
2use proc_macro::TokenStream;
45use quote::quote;
46use syn::parse_macro_input;
47use syn::parse_quote;
48use syn::Data;
49use syn::DeriveInput;
50use syn::Fields;
51use syn::Index;
52
53enum FieldAccessor {
55 Named(syn::Ident),
56 Unnamed(Index),
57}
58
59impl quote::ToTokens for FieldAccessor {
60 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
61 match self {
62 FieldAccessor::Named(ident) => ident.to_tokens(tokens),
63 FieldAccessor::Unnamed(index) => index.to_tokens(tokens),
64 }
65 }
66}
67
68fn get_fields(input: &DeriveInput) -> Result<(Vec<FieldAccessor>, Vec<&syn::Type>), TokenStream> {
70 let fields = match &input.data {
71 Data::Struct(s) => &s.fields,
72 _ => {
73 let msg = "derive macros are only supported on structs";
74 return Err(syn::Error::new_spanned(&input.ident, msg)
75 .to_compile_error()
76 .into());
77 }
78 };
79
80 let (accessors, types): (Vec<_>, Vec<_>) = match fields {
81 Fields::Named(named) => named
82 .named
83 .iter()
84 .map(|f| {
85 let accessor = FieldAccessor::Named(f.ident.clone().expect("named field"));
86 (accessor, &f.ty)
87 })
88 .unzip(),
89 Fields::Unnamed(unnamed) => unnamed
90 .unnamed
91 .iter()
92 .enumerate()
93 .map(|(i, f)| {
94 let accessor = FieldAccessor::Unnamed(Index::from(i));
95 (accessor, &f.ty)
96 })
97 .unzip(),
98 Fields::Unit => (Vec::new(), Vec::new()),
99 };
100
101 Ok((accessors, types))
102}
103
104fn construct_struct(
109 name: &syn::Ident,
110 fields: &Fields,
111 values: &[proc_macro2::TokenStream],
112) -> proc_macro2::TokenStream {
113 match fields {
114 Fields::Named(named) => {
115 let field_names = named.named.iter().map(|f| &f.ident);
116 quote! { #name { #( #field_names: #values ),* } }
117 }
118 Fields::Unnamed(_) => {
119 quote! { #name( #( #values ),* ) }
120 }
121 Fields::Unit => {
122 quote! { #name }
123 }
124 }
125}
126
127#[proc_macro_derive(Semigroup)]
141pub fn derive_semigroup(input: TokenStream) -> TokenStream {
142 let input = parse_macro_input!(input as DeriveInput);
143 let name = &input.ident;
144
145 let (field_accessors, field_types) = match get_fields(&input) {
146 Ok(f) => f,
147 Err(ts) => return ts,
148 };
149
150 let mut generics = input.generics.clone();
151 {
152 let where_clause = generics.make_where_clause();
153 for ty in &field_types {
154 where_clause
155 .predicates
156 .push(parse_quote!(#ty: ::algebra_core::Semigroup));
157 }
158 }
159
160 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
161
162 let combine_exprs: Vec<_> = field_accessors
164 .iter()
165 .map(|accessor| {
166 quote! { ::algebra_core::Semigroup::combine(&self.#accessor, &other.#accessor) }
167 })
168 .collect();
169
170 let fields = match &input.data {
171 Data::Struct(s) => &s.fields,
172 _ => unreachable!(),
173 };
174 let construction = construct_struct(name, fields, &combine_exprs);
175
176 let expanded = quote! {
177 impl #impl_generics ::algebra_core::Semigroup for #name #ty_generics
178 #where_clause
179 {
180 fn combine(&self, other: &Self) -> Self {
181 #construction
182 }
183 }
184 };
185
186 TokenStream::from(expanded)
187}
188
189#[proc_macro_derive(Monoid)]
203pub fn derive_monoid(input: TokenStream) -> TokenStream {
204 let input = parse_macro_input!(input as DeriveInput);
205 let name = &input.ident;
206
207 let (_field_accessors, field_types) = match get_fields(&input) {
208 Ok(f) => f,
209 Err(ts) => return ts,
210 };
211
212 let mut generics = input.generics.clone();
213 {
214 let where_clause = generics.make_where_clause();
215 for ty in &field_types {
216 where_clause
217 .predicates
218 .push(parse_quote!(#ty: ::algebra_core::Monoid));
219 }
220 }
221
222 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
223
224 let empty_exprs: Vec<_> = field_types
226 .iter()
227 .map(|_ty| {
228 quote! { ::algebra_core::Monoid::empty() }
229 })
230 .collect();
231
232 let fields = match &input.data {
233 Data::Struct(s) => &s.fields,
234 _ => unreachable!(),
235 };
236 let construction = construct_struct(name, fields, &empty_exprs);
237
238 let expanded = quote! {
239 impl #impl_generics ::algebra_core::Monoid for #name #ty_generics
240 #where_clause
241 {
242 fn empty() -> Self {
243 #construction
244 }
245 }
246 };
247
248 TokenStream::from(expanded)
249}
250
251#[proc_macro_derive(CommutativeMonoid)]
266pub fn derive_commutative_monoid(input: TokenStream) -> TokenStream {
267 let input = parse_macro_input!(input as DeriveInput);
268 let name = &input.ident;
269
270 let (_field_accessors, field_types) = match get_fields(&input) {
271 Ok(f) => f,
272 Err(ts) => return ts,
273 };
274
275 let mut generics = input.generics.clone();
276 {
277 let where_clause = generics.make_where_clause();
278 for ty in &field_types {
279 where_clause
280 .predicates
281 .push(parse_quote!(#ty: ::algebra_core::CommutativeMonoid));
282 }
283 }
284
285 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
286
287 let expanded = quote! {
288 impl #impl_generics ::algebra_core::CommutativeMonoid for #name #ty_generics
289 #where_clause
290 {}
291 };
292
293 TokenStream::from(expanded)
294}
295
296#[proc_macro_derive(Group)]
310pub fn derive_group(input: TokenStream) -> TokenStream {
311 let input = parse_macro_input!(input as DeriveInput);
312 let name = &input.ident;
313
314 let (field_accessors, field_types) = match get_fields(&input) {
315 Ok(f) => f,
316 Err(ts) => return ts,
317 };
318
319 let mut generics = input.generics.clone();
320 {
321 let where_clause = generics.make_where_clause();
322 for ty in &field_types {
323 where_clause
324 .predicates
325 .push(parse_quote!(#ty: ::algebra_core::Group));
326 }
327 }
328
329 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
330
331 let inverse_exprs: Vec<_> = field_accessors
333 .iter()
334 .map(|accessor| {
335 quote! { ::algebra_core::Group::inverse(&self.#accessor) }
336 })
337 .collect();
338
339 let fields = match &input.data {
340 Data::Struct(s) => &s.fields,
341 _ => unreachable!(),
342 };
343 let construction = construct_struct(name, fields, &inverse_exprs);
344
345 let expanded = quote! {
346 impl #impl_generics ::algebra_core::Group for #name #ty_generics
347 #where_clause
348 {
349 fn inverse(&self) -> Self {
350 #construction
351 }
352 }
353 };
354
355 TokenStream::from(expanded)
356}
357
358#[proc_macro_derive(AbelianGroup)]
373pub fn derive_abelian_group(input: TokenStream) -> TokenStream {
374 let input = parse_macro_input!(input as DeriveInput);
375 let name = &input.ident;
376
377 let (_field_accessors, field_types) = match get_fields(&input) {
378 Ok(f) => f,
379 Err(ts) => return ts,
380 };
381
382 let mut generics = input.generics.clone();
383 {
384 let where_clause = generics.make_where_clause();
385 for ty in &field_types {
386 where_clause
387 .predicates
388 .push(parse_quote!(#ty: ::algebra_core::AbelianGroup));
389 }
390 }
391
392 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
393
394 let expanded = quote! {
395 impl #impl_generics ::algebra_core::AbelianGroup for #name #ty_generics
396 #where_clause
397 {}
398 };
399
400 TokenStream::from(expanded)
401}
402
403#[proc_macro_derive(JoinSemilattice)]
417pub fn derive_join_semilattice(input: TokenStream) -> TokenStream {
418 let input = parse_macro_input!(input as DeriveInput);
419 let name = &input.ident;
420
421 let (field_accessors, field_types) = match get_fields(&input) {
422 Ok(f) => f,
423 Err(ts) => return ts,
424 };
425
426 let mut generics = input.generics.clone();
427 {
428 let where_clause = generics.make_where_clause();
429 for ty in &field_types {
430 where_clause
431 .predicates
432 .push(parse_quote!(#ty: ::algebra_core::JoinSemilattice));
433 }
434 }
435
436 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
437
438 let join_exprs: Vec<_> = field_accessors
440 .iter()
441 .map(|accessor| {
442 quote! { ::algebra_core::JoinSemilattice::join(&self.#accessor, &other.#accessor) }
443 })
444 .collect();
445
446 let fields = match &input.data {
447 Data::Struct(s) => &s.fields,
448 _ => unreachable!(),
449 };
450 let construction = construct_struct(name, fields, &join_exprs);
451
452 let expanded = quote! {
453 impl #impl_generics ::algebra_core::JoinSemilattice for #name #ty_generics
454 #where_clause
455 {
456 fn join(&self, other: &Self) -> Self {
457 #construction
458 }
459 }
460 };
461
462 TokenStream::from(expanded)
463}
464
465#[proc_macro_derive(BoundedJoinSemilattice)]
480pub fn derive_bounded_join_semilattice(input: TokenStream) -> TokenStream {
481 let input = parse_macro_input!(input as DeriveInput);
482 let name = &input.ident;
483
484 let (_field_accessors, field_types) = match get_fields(&input) {
485 Ok(f) => f,
486 Err(ts) => return ts,
487 };
488
489 let mut generics = input.generics.clone();
490 {
491 let where_clause = generics.make_where_clause();
492 for ty in &field_types {
493 where_clause
494 .predicates
495 .push(parse_quote!(#ty: ::algebra_core::BoundedJoinSemilattice));
496 }
497 }
498
499 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
500
501 let bottom_exprs: Vec<_> = field_types
503 .iter()
504 .map(|_ty| {
505 quote! { ::algebra_core::BoundedJoinSemilattice::bottom() }
506 })
507 .collect();
508
509 let fields = match &input.data {
510 Data::Struct(s) => &s.fields,
511 _ => unreachable!(),
512 };
513 let construction = construct_struct(name, fields, &bottom_exprs);
514
515 let expanded = quote! {
516 impl #impl_generics ::algebra_core::BoundedJoinSemilattice for #name #ty_generics
517 #where_clause
518 {
519 fn bottom() -> Self {
520 #construction
521 }
522 }
523 };
524
525 TokenStream::from(expanded)
526}
527
528#[proc_macro_derive(MeetSemilattice)]
542pub fn derive_meet_semilattice(input: TokenStream) -> TokenStream {
543 let input = parse_macro_input!(input as DeriveInput);
544 let name = &input.ident;
545
546 let (field_accessors, field_types) = match get_fields(&input) {
547 Ok(f) => f,
548 Err(ts) => return ts,
549 };
550
551 let mut generics = input.generics.clone();
552 {
553 let where_clause = generics.make_where_clause();
554 for ty in &field_types {
555 where_clause
556 .predicates
557 .push(parse_quote!(#ty: ::algebra_core::MeetSemilattice));
558 }
559 }
560
561 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
562
563 let meet_exprs: Vec<_> = field_accessors
565 .iter()
566 .map(|accessor| {
567 quote! { ::algebra_core::MeetSemilattice::meet(&self.#accessor, &other.#accessor) }
568 })
569 .collect();
570
571 let fields = match &input.data {
572 Data::Struct(s) => &s.fields,
573 _ => unreachable!(),
574 };
575 let construction = construct_struct(name, fields, &meet_exprs);
576
577 let expanded = quote! {
578 impl #impl_generics ::algebra_core::MeetSemilattice for #name #ty_generics
579 #where_clause
580 {
581 fn meet(&self, other: &Self) -> Self {
582 #construction
583 }
584 }
585 };
586
587 TokenStream::from(expanded)
588}
589
590#[proc_macro_derive(BoundedMeetSemilattice)]
605pub fn derive_bounded_meet_semilattice(input: TokenStream) -> TokenStream {
606 let input = parse_macro_input!(input as DeriveInput);
607 let name = &input.ident;
608
609 let (_field_accessors, field_types) = match get_fields(&input) {
610 Ok(f) => f,
611 Err(ts) => return ts,
612 };
613
614 let mut generics = input.generics.clone();
615 {
616 let where_clause = generics.make_where_clause();
617 for ty in &field_types {
618 where_clause
619 .predicates
620 .push(parse_quote!(#ty: ::algebra_core::BoundedMeetSemilattice));
621 }
622 }
623
624 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
625
626 let top_exprs: Vec<_> = field_types
628 .iter()
629 .map(|_ty| {
630 quote! { ::algebra_core::BoundedMeetSemilattice::top() }
631 })
632 .collect();
633
634 let fields = match &input.data {
635 Data::Struct(s) => &s.fields,
636 _ => unreachable!(),
637 };
638 let construction = construct_struct(name, fields, &top_exprs);
639
640 let expanded = quote! {
641 impl #impl_generics ::algebra_core::BoundedMeetSemilattice for #name #ty_generics
642 #where_clause
643 {
644 fn top() -> Self {
645 #construction
646 }
647 }
648 };
649
650 TokenStream::from(expanded)
651}