''' sql.py: Host SQL on Python Language table * table ==> vview vview * table ==> vview field + field ==> expr expr + field ==> expr ''' # # XXX: Please keep all sequence type attributes as list. # It eases our task that deepcopy an expr. # XXX: Never duplicate SQL_field_type objects when performing # deepcopy. SQL_field_type objects should be reused, or # remap to SQL_field_type objects of new tables. # import string from copy import copy, deepcopy from pythk import conveniences as cnv def get_tables(*args): return tuple([table(x) for x in args]) # ============================================================ class operand(object): def __init__(self): super(operand, self).__init__() pass def get_sql_name(self): raise NotImplementedError() def get_sql_name_no_alias(self): raise NotImplementedError() def __str__(self): return self.op def __apply_binary(self, opstr, other): no = expr(self) no.children.append(op(opstr)) no.children.append(normalize_operand(other)) return no def __apply_unary(self, opstr): no = expr() no.children.append(op(opstr)) no.children.append(self) return no def __add__(self, other): return self.__apply_binary('+', other) def __sub__(self, other): return self.__apply_binary('-', other) def __mul__(self, other): return self.__apply_binary('*', other) def __add__(self, other): return self.__apply_binary('/', other) def __eq__(self, other): return self.__apply_binary('=', other) def __ne__(self, other): return self.__apply_binary('!=', other) def __gt__(self, other): return self.__apply_binary('>', other) def __lt__(self, other): return self.__apply_binary('<', other) def __and__(self, other): return self.__apply_binary('AND', other) def __or__(self, other): return self.__apply_binary('OR', other) def __invert__(self): return self.__apply_unary('NOT') pass class complex_operand(operand): children = [] def __deepcopy__(self, memo): new_operand = copy(self) def only_dup_complex(child): if not isinstance(child, complex_operand): return child return deepcopy(child) new_operand.children = [only_dup_complex(child) for child in self.children] return new_operand pass def normalize_operand(o): if isinstance(o, operand): return o return const(o) class op(object): def __init__(self, op_name): super(op, self).__init__() self.op_name = op_name pass def get_sql_name(self): return self.op_name get_sql_name_no_alias = get_sql_name def __str__(self): return self.op_name pass class func_generic(complex_operand): def __init__(self, *children): super(func_generic, self).__init__() self.children = children pass def get_sql_name(self): args = [c.get_sql_name() for c in self.children] arg_str = string.join(args, ', ') return self.__class__.__name__ + '(' + arg_str + ')' def get_sql_name_no_alias(self): args = [c.get_sql_name_no_alias() for c in self.children] arg_str = string.join(args, ', ') return self.__class__.__name__ + '(' + arg_str + ')' __str__ = get_sql_name pass class foo(func_generic): pass class isnull(func_generic): def get_sql_name(self): arg_str = self.children[0].get_sql_name() return arg_str + ' ISNULL' get_sql_name_no_alias = get_sql_name pass class expr(complex_operand): def __init__(self, ex=None): super(expr, self).__init__() self.children = [] if ex: self.children.append(ex) pass pass def __apply_binary(self, opstr, other): no = self.__class__(self) no.children.append(op(opstr)) no.children.append(normalize_operand(other)) return no def __apply_unary(self, opstr): no = self.__class__() no.children.append(op(opstr)) no.children.append(self) return no def get_sql_name(self): return '(' + string.join(map(lambda x: x.get_sql_name(), self.children), ' ') + ')' def get_sql_name_no_alias(self): return '(' + string.join(map(lambda x: x.get_sql_name_no_alias(), self.children), ' ') + ')' def __str__(self): return self.get_sql_name_no_alias() pass class const(operand): def __init__(self, v): super(const, self).__init__() self.value = v pass def __deepcopy__(self, memo): return const(self.value) @staticmethod def str_escape(txt): return '\'' + txt.replace('\'', '\'\'') + '\'' def get_sql_name(self): import types if self.value is not None: if isinstance(self.value, types.StringTypes): return self.str_escape(self.value) return repr(self.value) return 'NULL' get_sql_name_no_alias = get_sql_name def __str__(self): return repr(self.value) pass # ============================================================ class sys_sym(operand): sym = None def __depcopy__(self, memo): return self def get_sql_name(self): return self.sym get_sql_name_no_alias = get_sql_name def __str__(self): return self.sym pass class all_sym(sys_sym): '''All symbol Used on count(*) to stand as all records.''' sym = '*' pass _a = all_sym() class free_var(sys_sym): '''Free variable (unknow) It is a place holder waiting for user code specifing data when executing the query.''' sym = '?' pass _q = free_var() # ============================================================= class SQL_field_type(operand): '''Define data type of SQL User create a object of a child of SQL_field_type with field name to specify the field name & data type of the field. ''' type_name = 'None' fmt_symbol = '%s' order_seq = 0 def __init__(self, name=None, table=None): super(SQL_field_type, self).__init__() self.field_name = name self.pkey = False self.uniq = False self.autoinc = False self.seq = SQL_field_type.order_seq self.table = table SQL_field_type.order_seq = SQL_field_type.order_seq + 1 pass def get_sql_name(self): return self.table.get_alias() + '.' + self.field_name def get_sql_name_no_alias(self): return self.field_name def primary(self, auto=False): self.pkey = True self.autoinc = auto return self def unique(self): self.uniq = True return self def __str__(self): return self.table.get_alias() + '.' + self.field_name def __deepcopy__(self, memo): return copy(self) @property def schema(self): if self.pkey: ptn = '%s %s primary key' else: ptn = '%s %s' pass if self.autoinc: ptn = ptn + ' autoincrement' pass if self.uniq: ptn = ptn + ' unique' pass return ptn % (self.field_name, self.type_name) pass class int_f(SQL_field_type): type_name = 'integer' fmt_symbol = '%d' pass class float_f(SQL_field_type): type_name = 'float' fmt_symbol = '%f' pass class date_f(SQL_field_type): type_name = 'date' pass class str_f(SQL_field_type): type_name = 'text' pass # ============================================================ # constraints class unique_c(object): def __init__(self, *args): super(unique_c, self).__init__() self.fields = args pass @property def schema(self): from string import join fnames = [f.get_sql_name_no_alias() for f in self.fields] return 'unique(' + join(fnames, ',') + ')' pass def unique(*args): import sys frame = sys._getframe(1) ldict = frame.f_locals for f in args: if not isinstance(f, SQL_field_type): raise ValueError('must be a squence of SQL_field_type instances') pass _uniques = ldict.setdefault('_unique_constraints', []) _uniques.append(unique_c(*args)) ldict['_unique_constraints'] = _uniques pass # ============================================================ class values(object): def __init__(self, table, kv_pairs): super(values, self).__init__() self.table = table self.kv_pairs = kv_pairs pass def gen_update_cmd(self, cond=None): return self.table.gen_update_cmd(self, cond) def gen_insert_cmd(self): return self.table.gen_insert_cmd(self) pass class value_factory(object): def __init__(self, table): super(value_factory, self).__init__() self.table = table pass def __call__(self, **kws): table = self.table for key in kws: if not isinstance(getattr(table, key), SQL_field_type): raise NameError('no such field (%s).' % (key,)) pass return values(self.table, kws.items()) pass # ============================================================ class table_meta(type): def __init__(cls, name, bases, dict): from pythk.types import is_descriptor super(table_meta, cls).__init__(name, bases, dict) cls.tab_name = name for key in dict: o = dict[key] if isinstance(o, SQL_field_type) and o.field_name == None: o.field_name = key pass pass first_obj = cls() for key in dict: o = dict[key] if isinstance(o, SQL_field_type): setattr(cls, key, getattr(first_obj, key)) elif (not key.startswith('_')) and is_descriptor(o): # method is also a descriptor remap_o = cls._remap(o) setattr(cls, key, remap_o) pass pass cls.__objs = [first_obj] pass def __getitem__(cls, idx): objs = cls.__objs if idx >= len(objs): num = (idx - len(objs) + 4) & ~0x3 objs = cls.__objs = objs + ([None] * num) pass obj = objs[idx] if not obj: obj = objs[idx] = cls() pass return obj def __mul__(cls, other): from pythk.types import is_subclass obj = cls[0] if is_subclass(other, table): other = other[0] pass return obj * other @property def _(cls): return cls[0] def __getattr__(cls, name): '''Redirect accessing to variable that not defined in class to first instance object.''' return getattr(cls.__objs[0], name) class _remap(object): '''Remap directly accessing to unbound method to bound method of first instance object, foo_table.__objs[0].''' def __init__(self, src): super(table_meta._remap, self).__init__() self.src = src pass def __get__(self, instance, owner): src = self.src if instance: tgt = src.__get__(instance, owner) else: tgt = src.__get__(owner[0], owner) pass return tgt pass pass class table(object): __metaclass__ = table_meta def __init__(self, name=None, fields=[]): object.__init__(self) cdict = self.__class__.__dict__ predefined = [cdict[key] for key in cdict if isinstance(cdict[key], SQL_field_type)] if name != None: self.tab_name = name pass apply(self.set_fields, tuple(predefined) + tuple(fields)) self.factory = value_factory(self) self.alias = 'A%x' % (id(self),) self.copy_from = self self.copies = 0 pass def get_alias(self): return self.alias def set_fields(self, *fields): self.__fields = [copy(f) for f in fields] for f in self.__fields: setattr(self, f.field_name, f) f.table = self pass def _values_2_assigns(self, values): kv_pairs = values.kv_pairs assigns = [] for k, v in kv_pairs: field = getattr(self, k) if not isinstance(field, SQL_field_type): raise TypeError('%s.%s is not a SQL_field_type object.' % (self.__class__.name, k)) v = normalize_operand(v) ass = '%s=%s' % (field.get_sql_name_no_alias(), v.get_sql_name_no_alias()) assigns.append(ass) pass return assigns def gen_update_cmd(self, values, cond=None): from string import join assigns = self._values_2_assigns(values) a_str = join(assigns, ', ') cmd = 'update %s set %s' % (self.tab_name, a_str) if cond: cmd = cmd + ' where ' + cond.get_sql_name_no_alias() pass return cmd def gen_update_from_vview_cmd(self, values, fields=None, where=None): raise NotImplementedError() def gen_insert_cmd(self, values): from string import join _keys = string.join([getattr(self, k).get_sql_name_no_alias() for k, v in values.kv_pairs], ', ') _values = string.join([normalize_operand(v).get_sql_name_no_alias() for k, v in values.kv_pairs], ', ') cmd = 'insert into %s (%s) values(%s)' % (self.tab_name, _keys, _values) return cmd def gen_delete_cmd(self, where=None): from string import join cmd = 'delete from ' + self.tab_name if where: cmd = cmd + ' where ' + where.get_sql_name_no_alias() pass return cmd def __mul__(self, other): return vview(self, other) def __str__(self): return '