Mohammed Naser | 3415a2a | 2025-03-06 21:16:12 -0500 | [diff] [blame^] | 1 | extern crate proc_macro; |
| 2 | |
| 3 | use proc_macro::TokenStream; |
| 4 | use quote::quote; |
| 5 | use syn::{parse_macro_input, parse_quote, Data, DeriveInput, Fields}; |
| 6 | |
| 7 | /// Attribute macro for OVSDB table structs |
| 8 | /// |
| 9 | /// This macro automatically adds `_uuid` and `_version` fields to your struct |
| 10 | /// and generates the necessary implementations for it to work with OVSDB. |
| 11 | /// |
| 12 | /// # Example |
| 13 | /// |
| 14 | /// ```rust |
| 15 | /// use ovsdb_derive::ovsdb_object; |
| 16 | /// use std::collections::HashMap; |
| 17 | /// |
| 18 | /// #[ovsdb_object] |
| 19 | /// pub struct NbGlobal { |
| 20 | /// pub name: Option<String>, |
| 21 | /// pub nb_cfg: Option<i64>, |
| 22 | /// pub external_ids: Option<HashMap<String, String>>, |
| 23 | /// } |
| 24 | /// ``` |
| 25 | #[proc_macro_attribute] |
| 26 | pub fn ovsdb_object(_attr: TokenStream, item: TokenStream) -> TokenStream { |
| 27 | // Parse the struct definition |
| 28 | let mut input = parse_macro_input!(item as DeriveInput); |
| 29 | |
| 30 | // Add _uuid and _version fields if they don't exist |
| 31 | if let Data::Struct(ref mut data_struct) = input.data { |
| 32 | if let Fields::Named(ref mut fields) = data_struct.fields { |
| 33 | // Check if _uuid and _version already exist |
| 34 | let has_uuid = fields |
| 35 | .named |
| 36 | .iter() |
| 37 | .any(|f| f.ident.as_ref().is_some_and(|i| i == "_uuid")); |
| 38 | let has_version = fields |
| 39 | .named |
| 40 | .iter() |
| 41 | .any(|f| f.ident.as_ref().is_some_and(|i| i == "_version")); |
| 42 | |
| 43 | // Add fields if they don't exist |
| 44 | if !has_uuid { |
| 45 | // Add _uuid field |
| 46 | fields.named.push(parse_quote! { |
| 47 | pub _uuid: Option<uuid::Uuid> |
| 48 | }); |
| 49 | } |
| 50 | if !has_version { |
| 51 | // Add _version field |
| 52 | fields.named.push(parse_quote! { |
| 53 | pub _version: Option<uuid::Uuid> |
| 54 | }); |
| 55 | } |
| 56 | } |
| 57 | } |
| 58 | |
| 59 | // Get the name of the struct |
| 60 | let struct_name = &input.ident; |
| 61 | |
| 62 | // Extract field names and types, excluding _uuid and _version |
| 63 | let mut field_names = Vec::new(); |
| 64 | let mut field_types = Vec::new(); |
| 65 | |
| 66 | if let Data::Struct(ref data_struct) = input.data { |
| 67 | if let Fields::Named(ref fields) = data_struct.fields { |
| 68 | for field in &fields.named { |
| 69 | if let Some(ident) = &field.ident { |
| 70 | if ident == "_uuid" || ident == "_version" { |
| 71 | continue; |
| 72 | } |
| 73 | field_names.push(ident); |
| 74 | field_types.push(&field.ty); |
| 75 | } |
| 76 | } |
| 77 | } |
| 78 | } |
| 79 | |
| 80 | // Generate implementations |
| 81 | let implementation = quote! { |
| 82 | // Re-export the input struct with the added fields |
| 83 | #input |
| 84 | |
| 85 | // Automatically import necessary items from ovsdb-schema |
| 86 | use ::ovsdb_schema::{extract_uuid, OvsdbSerializableExt}; |
| 87 | |
| 88 | impl #struct_name { |
| 89 | /// Create a new instance with default values |
| 90 | pub fn new() -> Self { |
| 91 | Self { |
| 92 | #( |
| 93 | #field_names: Default::default(), |
| 94 | )* |
| 95 | _uuid: None, |
| 96 | _version: None, |
| 97 | } |
| 98 | } |
| 99 | |
| 100 | /// Convert to a HashMap for OVSDB serialization |
| 101 | pub fn to_map(&self) -> std::collections::HashMap<String, serde_json::Value> { |
| 102 | let mut map = std::collections::HashMap::new(); |
| 103 | |
| 104 | #( |
| 105 | // Skip None values |
| 106 | let field_value = &self.#field_names; |
| 107 | if let Some(value) = field_value.to_ovsdb_json() { |
| 108 | map.insert(stringify!(#field_names).to_string(), value); |
| 109 | } |
| 110 | )* |
| 111 | |
| 112 | map |
| 113 | } |
| 114 | |
| 115 | /// Create from a HashMap received from OVSDB |
| 116 | pub fn from_map(map: &std::collections::HashMap<String, serde_json::Value>) -> Result<Self, String> { |
| 117 | let mut result = Self::new(); |
| 118 | |
| 119 | // Extract UUID if present |
| 120 | if let Some(uuid_val) = map.get("_uuid") { |
| 121 | if let Some(uuid) = extract_uuid(uuid_val) { |
| 122 | result._uuid = Some(uuid); |
| 123 | } |
| 124 | } |
| 125 | |
| 126 | // Extract version if present |
| 127 | if let Some(version_val) = map.get("_version") { |
| 128 | if let Some(version) = extract_uuid(version_val) { |
| 129 | result._version = Some(version); |
| 130 | } |
| 131 | } |
| 132 | |
| 133 | // Extract other fields |
| 134 | #( |
| 135 | if let Some(value) = map.get(stringify!(#field_names)) { |
| 136 | result.#field_names = <#field_types>::from_ovsdb_json(value) |
| 137 | .ok_or_else(|| format!("Failed to parse field {}", stringify!(#field_names)))?; |
| 138 | } |
| 139 | )* |
| 140 | |
| 141 | Ok(result) |
| 142 | } |
| 143 | } |
| 144 | |
| 145 | impl Default for #struct_name { |
| 146 | fn default() -> Self { |
| 147 | Self::new() |
| 148 | } |
| 149 | } |
| 150 | |
| 151 | impl serde::Serialize for #struct_name { |
| 152 | fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> |
| 153 | where |
| 154 | S: serde::Serializer |
| 155 | { |
| 156 | self.to_map().serialize(serializer) |
| 157 | } |
| 158 | } |
| 159 | |
| 160 | impl<'de> serde::Deserialize<'de> for #struct_name { |
| 161 | fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> |
| 162 | where |
| 163 | D: serde::Deserializer<'de> |
| 164 | { |
| 165 | let map = std::collections::HashMap::<String, serde_json::Value>::deserialize(deserializer)?; |
| 166 | Self::from_map(&map).map_err(serde::de::Error::custom) |
| 167 | } |
| 168 | } |
| 169 | }; |
| 170 | |
| 171 | // Return the modified struct and implementations |
| 172 | TokenStream::from(implementation) |
| 173 | } |
| 174 | |
| 175 | /// Derive macro for OVSDB table structs (requires manual _uuid and _version fields) |
| 176 | /// |
| 177 | /// This macro generates the necessary implementations for a struct to work with OVSDB. |
| 178 | /// The struct must have `_uuid` and `_version` fields of type `Option<uuid::Uuid>`. |
| 179 | /// |
| 180 | /// # Example |
| 181 | /// |
| 182 | /// ```rust |
| 183 | /// use ovsdb_derive::OVSDB; |
| 184 | /// use std::collections::HashMap; |
| 185 | /// use uuid::Uuid; |
| 186 | /// |
| 187 | /// #[derive(Debug, Clone, PartialEq, OVSDB)] |
| 188 | /// pub struct NbGlobal { |
| 189 | /// pub name: Option<String>, |
| 190 | /// pub nb_cfg: Option<i64>, |
| 191 | /// pub external_ids: Option<HashMap<String, String>>, |
| 192 | /// |
| 193 | /// // Required fields |
| 194 | /// pub _uuid: Option<Uuid>, |
| 195 | /// pub _version: Option<Uuid>, |
| 196 | /// } |
| 197 | /// ``` |
| 198 | #[proc_macro_derive(OVSDB)] |
| 199 | pub fn ovsdb_derive(input: TokenStream) -> TokenStream { |
| 200 | // Parse the input tokens into a syntax tree |
| 201 | let input = parse_macro_input!(input as DeriveInput); |
| 202 | |
| 203 | // Get the name of the struct |
| 204 | let struct_name = &input.ident; |
| 205 | |
| 206 | // Check if the input is a struct |
| 207 | let fields = match &input.data { |
| 208 | Data::Struct(data_struct) => match &data_struct.fields { |
| 209 | Fields::Named(fields_named) => &fields_named.named, |
| 210 | _ => panic!("OVSDB can only be derived for structs with named fields"), |
| 211 | }, |
| 212 | _ => panic!("OVSDB can only be derived for structs"), |
| 213 | }; |
| 214 | |
| 215 | // Extract field names and types, excluding _uuid and _version |
| 216 | let mut field_names = Vec::new(); |
| 217 | let mut field_types = Vec::new(); |
| 218 | |
| 219 | for field in fields { |
| 220 | if let Some(ident) = &field.ident { |
| 221 | if ident == "_uuid" || ident == "_version" { |
| 222 | continue; |
| 223 | } |
| 224 | field_names.push(ident); |
| 225 | field_types.push(&field.ty); |
| 226 | } |
| 227 | } |
| 228 | |
| 229 | // Generate code for the implementation |
| 230 | let expanded = quote! { |
| 231 | // Automatically import necessary items from ovsdb-schema |
| 232 | use ::ovsdb_schema::{extract_uuid, OvsdbSerializableExt}; |
| 233 | |
| 234 | impl #struct_name { |
| 235 | /// Create a new instance with default values |
| 236 | pub fn new() -> Self { |
| 237 | Self { |
| 238 | #( |
| 239 | #field_names: Default::default(), |
| 240 | )* |
| 241 | _uuid: None, |
| 242 | _version: None, |
| 243 | } |
| 244 | } |
| 245 | |
| 246 | /// Convert to a HashMap for OVSDB serialization |
| 247 | pub fn to_map(&self) -> std::collections::HashMap<String, serde_json::Value> { |
| 248 | let mut map = std::collections::HashMap::new(); |
| 249 | |
| 250 | #( |
| 251 | // Skip None values |
| 252 | let field_value = &self.#field_names; |
| 253 | if let Some(value) = field_value.to_ovsdb_json() { |
| 254 | map.insert(stringify!(#field_names).to_string(), value); |
| 255 | } |
| 256 | )* |
| 257 | |
| 258 | map |
| 259 | } |
| 260 | |
| 261 | /// Create from a HashMap received from OVSDB |
| 262 | pub fn from_map(map: &std::collections::HashMap<String, serde_json::Value>) -> Result<Self, String> { |
| 263 | let mut result = Self::new(); |
| 264 | |
| 265 | // Extract UUID if present |
| 266 | if let Some(uuid_val) = map.get("_uuid") { |
| 267 | if let Some(uuid) = extract_uuid(uuid_val) { |
| 268 | result._uuid = Some(uuid); |
| 269 | } |
| 270 | } |
| 271 | |
| 272 | // Extract version if present |
| 273 | if let Some(version_val) = map.get("_version") { |
| 274 | if let Some(version) = extract_uuid(version_val) { |
| 275 | result._version = Some(version); |
| 276 | } |
| 277 | } |
| 278 | |
| 279 | // Extract other fields |
| 280 | #( |
| 281 | if let Some(value) = map.get(stringify!(#field_names)) { |
| 282 | result.#field_names = <#field_types>::from_ovsdb_json(value) |
| 283 | .ok_or_else(|| format!("Failed to parse field {}", stringify!(#field_names)))?; |
| 284 | } |
| 285 | )* |
| 286 | |
| 287 | Ok(result) |
| 288 | } |
| 289 | } |
| 290 | |
| 291 | impl Default for #struct_name { |
| 292 | fn default() -> Self { |
| 293 | Self::new() |
| 294 | } |
| 295 | } |
| 296 | |
| 297 | impl serde::Serialize for #struct_name { |
| 298 | fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> |
| 299 | where |
| 300 | S: serde::Serializer |
| 301 | { |
| 302 | self.to_map().serialize(serializer) |
| 303 | } |
| 304 | } |
| 305 | |
| 306 | impl<'de> serde::Deserialize<'de> for #struct_name { |
| 307 | fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> |
| 308 | where |
| 309 | D: serde::Deserializer<'de> |
| 310 | { |
| 311 | let map = std::collections::HashMap::<String, serde_json::Value>::deserialize(deserializer)?; |
| 312 | Self::from_map(&map).map_err(serde::de::Error::custom) |
| 313 | } |
| 314 | } |
| 315 | }; |
| 316 | |
| 317 | // Return the generated code |
| 318 | TokenStream::from(expanded) |
| 319 | } |