blob: 149b437db4a8c991075e010841801920c69a2046 [file] [log] [blame]
Mohammed Naser3415a2a2025-03-06 21:16:12 -05001extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{parse_macro_input, parse_quote, Data, DeriveInput, Fields};
6
7/// Attribute macro for OVSDB table structs
8///
9/// This macro automatically adds `_uuid` and `_version` fields to your struct
10/// and generates the necessary implementations for it to work with OVSDB.
11///
12/// # Example
13///
14/// ```rust
15/// use ovsdb_derive::ovsdb_object;
16/// use std::collections::HashMap;
17///
18/// #[ovsdb_object]
19/// pub struct NbGlobal {
20/// pub name: Option<String>,
21/// pub nb_cfg: Option<i64>,
22/// pub external_ids: Option<HashMap<String, String>>,
23/// }
24/// ```
25#[proc_macro_attribute]
26pub fn ovsdb_object(_attr: TokenStream, item: TokenStream) -> TokenStream {
27 // Parse the struct definition
28 let mut input = parse_macro_input!(item as DeriveInput);
29
30 // Add _uuid and _version fields if they don't exist
31 if let Data::Struct(ref mut data_struct) = input.data {
32 if let Fields::Named(ref mut fields) = data_struct.fields {
33 // Check if _uuid and _version already exist
34 let has_uuid = fields
35 .named
36 .iter()
37 .any(|f| f.ident.as_ref().is_some_and(|i| i == "_uuid"));
38 let has_version = fields
39 .named
40 .iter()
41 .any(|f| f.ident.as_ref().is_some_and(|i| i == "_version"));
42
43 // Add fields if they don't exist
44 if !has_uuid {
45 // Add _uuid field
46 fields.named.push(parse_quote! {
47 pub _uuid: Option<uuid::Uuid>
48 });
49 }
50 if !has_version {
51 // Add _version field
52 fields.named.push(parse_quote! {
53 pub _version: Option<uuid::Uuid>
54 });
55 }
56 }
57 }
58
59 // Get the name of the struct
60 let struct_name = &input.ident;
61
62 // Extract field names and types, excluding _uuid and _version
63 let mut field_names = Vec::new();
64 let mut field_types = Vec::new();
65
66 if let Data::Struct(ref data_struct) = input.data {
67 if let Fields::Named(ref fields) = data_struct.fields {
68 for field in &fields.named {
69 if let Some(ident) = &field.ident {
70 if ident == "_uuid" || ident == "_version" {
71 continue;
72 }
73 field_names.push(ident);
74 field_types.push(&field.ty);
75 }
76 }
77 }
78 }
79
80 // Generate implementations
81 let implementation = quote! {
82 // Re-export the input struct with the added fields
83 #input
84
85 // Automatically import necessary items from ovsdb-schema
86 use ::ovsdb_schema::{extract_uuid, OvsdbSerializableExt};
87
88 impl #struct_name {
89 /// Create a new instance with default values
90 pub fn new() -> Self {
91 Self {
92 #(
93 #field_names: Default::default(),
94 )*
95 _uuid: None,
96 _version: None,
97 }
98 }
99
100 /// Convert to a HashMap for OVSDB serialization
101 pub fn to_map(&self) -> std::collections::HashMap<String, serde_json::Value> {
102 let mut map = std::collections::HashMap::new();
103
104 #(
105 // Skip None values
106 let field_value = &self.#field_names;
107 if let Some(value) = field_value.to_ovsdb_json() {
108 map.insert(stringify!(#field_names).to_string(), value);
109 }
110 )*
111
112 map
113 }
114
115 /// Create from a HashMap received from OVSDB
116 pub fn from_map(map: &std::collections::HashMap<String, serde_json::Value>) -> Result<Self, String> {
117 let mut result = Self::new();
118
119 // Extract UUID if present
120 if let Some(uuid_val) = map.get("_uuid") {
121 if let Some(uuid) = extract_uuid(uuid_val) {
122 result._uuid = Some(uuid);
123 }
124 }
125
126 // Extract version if present
127 if let Some(version_val) = map.get("_version") {
128 if let Some(version) = extract_uuid(version_val) {
129 result._version = Some(version);
130 }
131 }
132
133 // Extract other fields
134 #(
135 if let Some(value) = map.get(stringify!(#field_names)) {
136 result.#field_names = <#field_types>::from_ovsdb_json(value)
137 .ok_or_else(|| format!("Failed to parse field {}", stringify!(#field_names)))?;
138 }
139 )*
140
141 Ok(result)
142 }
143 }
144
145 impl Default for #struct_name {
146 fn default() -> Self {
147 Self::new()
148 }
149 }
150
151 impl serde::Serialize for #struct_name {
152 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
153 where
154 S: serde::Serializer
155 {
156 self.to_map().serialize(serializer)
157 }
158 }
159
160 impl<'de> serde::Deserialize<'de> for #struct_name {
161 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
162 where
163 D: serde::Deserializer<'de>
164 {
165 let map = std::collections::HashMap::<String, serde_json::Value>::deserialize(deserializer)?;
166 Self::from_map(&map).map_err(serde::de::Error::custom)
167 }
168 }
169 };
170
171 // Return the modified struct and implementations
172 TokenStream::from(implementation)
173}
174
175/// Derive macro for OVSDB table structs (requires manual _uuid and _version fields)
176///
177/// This macro generates the necessary implementations for a struct to work with OVSDB.
178/// The struct must have `_uuid` and `_version` fields of type `Option<uuid::Uuid>`.
179///
180/// # Example
181///
182/// ```rust
183/// use ovsdb_derive::OVSDB;
184/// use std::collections::HashMap;
185/// use uuid::Uuid;
186///
187/// #[derive(Debug, Clone, PartialEq, OVSDB)]
188/// pub struct NbGlobal {
189/// pub name: Option<String>,
190/// pub nb_cfg: Option<i64>,
191/// pub external_ids: Option<HashMap<String, String>>,
192///
193/// // Required fields
194/// pub _uuid: Option<Uuid>,
195/// pub _version: Option<Uuid>,
196/// }
197/// ```
198#[proc_macro_derive(OVSDB)]
199pub fn ovsdb_derive(input: TokenStream) -> TokenStream {
200 // Parse the input tokens into a syntax tree
201 let input = parse_macro_input!(input as DeriveInput);
202
203 // Get the name of the struct
204 let struct_name = &input.ident;
205
206 // Check if the input is a struct
207 let fields = match &input.data {
208 Data::Struct(data_struct) => match &data_struct.fields {
209 Fields::Named(fields_named) => &fields_named.named,
210 _ => panic!("OVSDB can only be derived for structs with named fields"),
211 },
212 _ => panic!("OVSDB can only be derived for structs"),
213 };
214
215 // Extract field names and types, excluding _uuid and _version
216 let mut field_names = Vec::new();
217 let mut field_types = Vec::new();
218
219 for field in fields {
220 if let Some(ident) = &field.ident {
221 if ident == "_uuid" || ident == "_version" {
222 continue;
223 }
224 field_names.push(ident);
225 field_types.push(&field.ty);
226 }
227 }
228
229 // Generate code for the implementation
230 let expanded = quote! {
231 // Automatically import necessary items from ovsdb-schema
232 use ::ovsdb_schema::{extract_uuid, OvsdbSerializableExt};
233
234 impl #struct_name {
235 /// Create a new instance with default values
236 pub fn new() -> Self {
237 Self {
238 #(
239 #field_names: Default::default(),
240 )*
241 _uuid: None,
242 _version: None,
243 }
244 }
245
246 /// Convert to a HashMap for OVSDB serialization
247 pub fn to_map(&self) -> std::collections::HashMap<String, serde_json::Value> {
248 let mut map = std::collections::HashMap::new();
249
250 #(
251 // Skip None values
252 let field_value = &self.#field_names;
253 if let Some(value) = field_value.to_ovsdb_json() {
254 map.insert(stringify!(#field_names).to_string(), value);
255 }
256 )*
257
258 map
259 }
260
261 /// Create from a HashMap received from OVSDB
262 pub fn from_map(map: &std::collections::HashMap<String, serde_json::Value>) -> Result<Self, String> {
263 let mut result = Self::new();
264
265 // Extract UUID if present
266 if let Some(uuid_val) = map.get("_uuid") {
267 if let Some(uuid) = extract_uuid(uuid_val) {
268 result._uuid = Some(uuid);
269 }
270 }
271
272 // Extract version if present
273 if let Some(version_val) = map.get("_version") {
274 if let Some(version) = extract_uuid(version_val) {
275 result._version = Some(version);
276 }
277 }
278
279 // Extract other fields
280 #(
281 if let Some(value) = map.get(stringify!(#field_names)) {
282 result.#field_names = <#field_types>::from_ovsdb_json(value)
283 .ok_or_else(|| format!("Failed to parse field {}", stringify!(#field_names)))?;
284 }
285 )*
286
287 Ok(result)
288 }
289 }
290
291 impl Default for #struct_name {
292 fn default() -> Self {
293 Self::new()
294 }
295 }
296
297 impl serde::Serialize for #struct_name {
298 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
299 where
300 S: serde::Serializer
301 {
302 self.to_map().serialize(serializer)
303 }
304 }
305
306 impl<'de> serde::Deserialize<'de> for #struct_name {
307 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
308 where
309 D: serde::Deserializer<'de>
310 {
311 let map = std::collections::HashMap::<String, serde_json::Value>::deserialize(deserializer)?;
312 Self::from_map(&map).map_err(serde::de::Error::custom)
313 }
314 }
315 };
316
317 // Return the generated code
318 TokenStream::from(expanded)
319}