Source code for encord.orm.base_orm

# Copyright (c) 2023 Cord Technologies Limited
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import datetime
import json
import logging
from collections import OrderedDict, abc

logger = logging.getLogger(__name__)

class BaseORM(dict):
    """Base ORM for all database objects.

    Subclasses declare their schema in ``DB_FIELDS`` (field name -> type or tuple
    of types) and may list fields that must never be written back in
    ``NON_UPDATABLE_FIELDS``. Field values are stored in the underlying ``dict``
    and are also reachable with attribute syntax (``obj.field``).
    """

    # Field name -> accepted type (or tuple of types); overridden by subclasses.
    DB_FIELDS: OrderedDict = OrderedDict()
    # Field names skipped by ``updatable_fields``; overridden by subclasses.
    NON_UPDATABLE_FIELDS: set = set()

    def __init__(self, dic):
        """Construct a client ORM object from a plain ``dict``.

        Only keys listed in ``DB_FIELDS`` are kept. A value is accepted when it is
        ``None``, already an instance of one of the field's declared types, ``0``/``1``
        for a field declaring ``bool``, or any value for a field declaring ``dict``.
        A field declaring ``datetime.datetime`` that receives a non-datetime value
        has it parsed with the format ``%Y-%m-%d %H:%M:%S``. Unknown keys, and values
        matching none of these rules, are silently dropped.

        :param dic: source mapping; must be a ``dict`` instance.
        :raises Exception: ``"Convert failed ..."`` when ``dic`` is not a dict or a
            datetime value cannot be parsed; the original error is chained as
            ``__cause__``.
        """
        try:
            if not isinstance(dic, dict):
                raise TypeError("Need dict object")

            value = {}
            for k, v in dic.items():
                if k not in self.DB_FIELDS:
                    # Unknown keys are dropped rather than rejected.
                    continue

                types = self.DB_FIELDS[k]
                # Normalise a single type to a tuple so isinstance/`in` work uniformly.
                if not isinstance(types, tuple):
                    types = (types,)

                if v is None:
                    # None is accepted for every declared field.
                    value[k] = v
                elif isinstance(v, types):
                    value[k] = v
                elif v in (0, 1) and bool in types:
                    # Booleans may arrive as 0/1 from the database.
                    value[k] = v
                elif datetime.datetime in types:
                    # A datetime field received something else (expected: a formatted str).
                    value[k] = datetime.datetime.strptime(v, "%Y-%m-%d %H:%M:%S")
                elif dict in types:
                    value[k] = v
                # Values matching none of the rules above are dropped.

            super().__init__(**value)
        except Exception as e:
            logger.error("Error init", exc_info=True)
            raise Exception("Convert failed {}".format(str(e))) from e

    def __getattr__(self, name):
        """Expose stored field values as attributes (``obj.field`` instead of ``obj["field"]``).

        Python only calls this when normal attribute lookup fails, so methods and
        class attributes take precedence over stored keys.

        :param name: attribute name being looked up.
        :return: the stored value for ``name``; ``None`` if the item lookup itself
            raises ``KeyError`` after the membership check passed.
        :raises AttributeError: if ``name`` is not a key of this object.
        """
        if name in self:
            try:
                return self[name]
            except KeyError:
                return None
        raise AttributeError("Attribute does not exist: {}".format(name))

    def __setattr__(self, name, value):
        """Set a declared field via attribute syntax, with a strict type check.

        Unlike ``__init__``, no 0/1 -> bool or str -> datetime coercion is done here.

        :param name: field name; must be a key of ``DB_FIELDS``.
        :param value: ``None`` or an instance of the field's declared type(s).
        :raises AttributeError: if ``name`` is not declared or ``value`` has the wrong type.
        """
        if name in self.DB_FIELDS and (value is None or isinstance(value, self.DB_FIELDS[name])):
            self[name] = value
        else:
            raise AttributeError("Attribute name or type not match: {}".format(name))

    def __delattr__(self, name):
        """Delete a stored field via attribute syntax.

        Names that are not stored declared fields fall back to normal attribute
        deletion, which raises ``AttributeError`` for attributes that do not exist.
        """
        if name in self and name in self.DB_FIELDS:
            del self[name]
        else:
            # super() is already bound to self; passing self again raised TypeError.
            super().__delattr__(name)

    @staticmethod
    def from_db_row(row, db_field):
        """Zip a positional database row with its column names into a dict.

        :param row: indexable sequence of column values.
        :param db_field: iterable of column names, in the same order as ``row``.
        :return: ``{column_name: value}``; raises ``IndexError`` if ``row`` is
            shorter than ``db_field``.
        """
        return {attribute: row[i] for i, attribute in enumerate(db_field)}

    def to_dic(self, time_str: bool = True):
        """Return the stored items as a plain dict suitable for serialisation.

        Only the dict items are considered; other object attributes are ignored.
        ``dict`` values are always JSON-encoded with ``json.dumps``.

        Args:
            time_str: if exactly ``True``, datetime values are converted to strings
                with format ``%Y-%m-%d %H:%M:%S``; otherwise they are kept as
                datetime objects. Defaults to ``True``.
        """
        res = {}
        for k, v in self.items():
            if isinstance(v, datetime.datetime) and time_str is True:
                v = v.strftime("%Y-%m-%d %H:%M:%S")
            elif isinstance(v, dict):
                v = json.dumps(v)
            res[k] = v
        return res

    def updatable_fields(self):
        """Yield ``(field, value)`` pairs that may be written back.

        Skips fields listed in ``NON_UPDATABLE_FIELDS`` and fields whose value is ``None``.
        """
        for k, v in self.items():
            if k not in self.NON_UPDATABLE_FIELDS and v is not None:
                yield k, v
class BaseListORM(list):
    """A wrapper for a list of objects of a specific ORM.

    Every element of the input iterable is converted with ``BASE_ORM_TYPE``;
    subclasses override that attribute to choose the element ORM class.
    """

    # ORM class used to convert each element; overridden by subclasses.
    BASE_ORM_TYPE = BaseORM

    def __init__(self, iter_):
        """Convert each item of ``iter_`` into a ``BASE_ORM_TYPE`` instance.

        :param iter_: iterable of items accepted by ``BASE_ORM_TYPE``
            (the default ``BaseORM`` requires dicts).
        :raises Exception: ``"Convert failed ..."`` if ``iter_`` is not iterable or
            an element fails to convert; the element's original error is chained
            as ``__cause__``.
        """
        if not isinstance(iter_, abc.Iterable):
            raise Exception("Convert failed. The object is not an iterable.")

        values = []
        for item in iter_:
            try:
                values.append(self.BASE_ORM_TYPE(item))
            except Exception as e:
                logger.error("Error init", exc_info=True)
                # Chain the cause so the failing element's traceback is preserved.
                raise Exception("Convert failed {}".format(str(e))) from e
        super().__init__(values)