''' sql.py: Host SQL on Python Language table * table ==> vview vview * table ==> vview field + field ==> expr expr + field ==> expr ''' # # XXX: Please keep all sequence type attributes as list. # It eases our task that deepcopy an expr. # XXX: Never duplicate SQL_field_type objects when performing # deepcopy. SQL_field_type objects should be reused, or # remap to SQL_field_type objects of new tables. # import string from copy import copy, deepcopy from pythk import conveniences as cnv def get_tables(*args): return tuple([table(x) for x in args]) # ============================================================ class operand(object): def __init__(self): super(operand, self).__init__() pass def get_sql_name(self): raise NotImplementedError() def get_sql_name_no_alias(self): raise NotImplementedError() def __str__(self): return self.op def __apply_binary(self, opstr, other): no = expr(self) no.children.append(op(opstr)) no.children.append(normalize_operand(other)) return no def __apply_unary(self, opstr): no = expr() no.children.append(op(opstr)) no.children.append(self) return no def __add__(self, other): return self.__apply_binary('+', other) def __sub__(self, other): return self.__apply_binary('-', other) def __mul__(self, other): return self.__apply_binary('*', other) def __add__(self, other): return self.__apply_binary('/', other) def __eq__(self, other): return self.__apply_binary('=', other) def __ne__(self, other): return self.__apply_binary('!=', other) def __gt__(self, other): return self.__apply_binary('>', other) def __lt__(self, other): return self.__apply_binary('<', other) def __and__(self, other): return self.__apply_binary('AND', other) def __or__(self, other): return self.__apply_binary('OR', other) def __invert__(self): return self.__apply_unary('NOT') pass class complex_operand(operand): children = [] def __deepcopy__(self, memo): new_operand = copy(self) def only_dup_complex(child): if not isinstance(child, complex_operand): return child return deepcopy(child) new_operand.children = [only_dup_complex(child) for child in self.children] return new_operand pass def normalize_operand(o): if isinstance(o, operand): return o return const(o) class op(object): def __init__(self, op_name): super(op, self).__init__() self.op_name = op_name pass def get_sql_name(self): return self.op_name get_sql_name_no_alias = get_sql_name def __str__(self): return self.op_name pass class func_generic(complex_operand): def __init__(self, *children): super(func_generic, self).__init__() self.children = children pass def get_sql_name(self): args = [c.get_sql_name() for c in self.children] arg_str = string.join(args, ', ') return self.__class__.__name__ + '(' + arg_str + ')' def get_sql_name_no_alias(self): args = [c.get_sql_name_no_alias() for c in self.children] arg_str = string.join(args, ', ') return self.__class__.__name__ + '(' + arg_str + ')' __str__ = get_sql_name pass class foo(func_generic): pass class isnull(func_generic): def get_sql_name(self): arg_str = self.children[0].get_sql_name() return arg_str + ' ISNULL' get_sql_name_no_alias = get_sql_name pass class expr(complex_operand): def __init__(self, ex=None): super(expr, self).__init__() self.children = [] if ex: self.children.append(ex) pass pass def __apply_binary(self, opstr, other): no = self.__class__(self) no.children.append(op(opstr)) no.children.append(normalize_operand(other)) return no def __apply_unary(self, opstr): no = self.__class__() no.children.append(op(opstr)) no.children.append(self) return no def get_sql_name(self): return '(' + string.join(map(lambda x: x.get_sql_name(), self.children), ' ') + ')' def get_sql_name_no_alias(self): return '(' + string.join(map(lambda x: x.get_sql_name_no_alias(), self.children), ' ') + ')' def __str__(self): return self.get_sql_name_no_alias() pass class const(operand): def __init__(self, v): super(const, self).__init__() self.value = v pass def __deepcopy__(self, memo): return const(self.value) @staticmethod def str_escape(txt): return '\'' + txt.replace('\'', '\'\'') + '\'' def get_sql_name(self): import types if self.value is not None: if isinstance(self.value, types.StringTypes): return self.str_escape(self.value) return repr(self.value) return 'NULL' get_sql_name_no_alias = get_sql_name def __str__(self): return repr(self.value) pass # ============================================================ class sys_sym(operand): sym = None def __depcopy__(self, memo): return self def get_sql_name(self): return self.sym get_sql_name_no_alias = get_sql_name def __str__(self): return self.sym pass class all_sym(sys_sym): '''All symbol Used on count(*) to stand as all records.''' sym = '*' pass _a = all_sym() class free_var(sys_sym): '''Free variable (unknow) It is a place holder waiting for user code specifing data when executing the query.''' sym = '?' pass _q = free_var() # ============================================================= class SQL_field_type(operand): '''Define data type of SQL User create a object of a child of SQL_field_type with field name to specify the field name & data type of the field. ''' type_name = 'None' fmt_symbol = '%s' order_seq = 0 def __init__(self, name=None, table=None): super(SQL_field_type, self).__init__() self.field_name = name self.pkey = False self.uniq = False self.autoinc = False self.seq = SQL_field_type.order_seq self.table = table SQL_field_type.order_seq = SQL_field_type.order_seq + 1 pass def get_sql_name(self): return self.table.get_alias() + '.' + self.field_name def get_sql_name_no_alias(self): return self.field_name def primary(self, auto=False): self.pkey = True self.autoinc = auto return self def unique(self): self.uniq = True return self def __str__(self): return self.table.get_alias() + '.' + self.field_name def __deepcopy__(self, memo): return copy(self) @property def schema(self): if self.pkey: ptn = '%s %s primary key' else: ptn = '%s %s' pass if self.autoinc: ptn = ptn + ' autoincrement' pass if self.uniq: ptn = ptn + ' unique' pass return ptn % (self.field_name, self.type_name) pass class int_f(SQL_field_type): type_name = 'integer' fmt_symbol = '%d' pass class float_f(SQL_field_type): type_name = 'float' fmt_symbol = '%f' pass class date_f(SQL_field_type): type_name = 'date' pass class str_f(SQL_field_type): type_name = 'text' pass # ============================================================ # constraints class unique_c(object): def __init__(self, *args): super(unique_c, self).__init__() self.fields = args pass @property def schema(self): from string import join fnames = [f.get_sql_name_no_alias() for f in self.fields] return 'unique(' + join(fnames, ',') + ')' pass def unique(*args): import sys frame = sys._getframe(1) ldict = frame.f_locals for f in args: if not isinstance(f, SQL_field_type): raise ValueError('must be a squence of SQL_field_type instances') pass _uniques = ldict.setdefault('_unique_constraints', []) _uniques.append(unique_c(*args)) ldict['_unique_constraints'] = _uniques pass # ============================================================ class values(object): def __init__(self, table, kv_pairs): super(values, self).__init__() self.table = table self.kv_pairs = kv_pairs pass def gen_update_cmd(self, cond=None): return self.table.gen_update_cmd(self, cond) def gen_insert_cmd(self): return self.table.gen_insert_cmd(self) pass class value_factory(object): def __init__(self, table): super(value_factory, self).__init__() self.table = table pass def __call__(self, **kws): table = self.table for key in kws: if not isinstance(getattr(table, key), SQL_field_type): raise NameError('no such field (%s).' % (key,)) pass return values(self.table, kws.items()) pass # ============================================================ class table_meta(type): def __init__(cls, name, bases, dict): from pythk.types import is_descriptor super(table_meta, cls).__init__(name, bases, dict) cls.tab_name = name for key in dict: o = dict[key] if isinstance(o, SQL_field_type) and o.field_name == None: o.field_name = key pass pass first_obj = cls() for key in dict: o = dict[key] if isinstance(o, SQL_field_type): setattr(cls, key, getattr(first_obj, key)) elif (not key.startswith('_')) and is_descriptor(o): # method is also a descriptor remap_o = cls._remap(o) setattr(cls, key, remap_o) pass pass cls.__objs = [first_obj] pass def __getitem__(cls, idx): objs = cls.__objs if idx >= len(objs): num = (idx - len(objs) + 4) & ~0x3 objs = cls.__objs = objs + ([None] * num) pass obj = objs[idx] if not obj: obj = objs[idx] = cls() pass return obj def __mul__(cls, other): from pythk.types import is_subclass obj = cls[0] if is_subclass(other, table): other = other[0] pass return obj * other @property def _(cls): return cls[0] def __getattr__(cls, name): '''Redirect accessing to variable that not defined in class to first instance object.''' return getattr(cls.__objs[0], name) class _remap(object): '''Remap directly accessing to unbound method to bound method of first instance object, foo_table.__objs[0].''' def __init__(self, src): super(table_meta._remap, self).__init__() self.src = src pass def __get__(self, instance, owner): src = self.src if instance: tgt = src.__get__(instance, owner) else: tgt = src.__get__(owner[0], owner) pass return tgt pass pass class table(object): __metaclass__ = table_meta def __init__(self, name=None, fields=[]): object.__init__(self) cdict = self.__class__.__dict__ predefined = [cdict[key] for key in cdict if isinstance(cdict[key], SQL_field_type)] if name != None: self.tab_name = name pass apply(self.set_fields, tuple(predefined) + tuple(fields)) self.factory = value_factory(self) self.alias = 'A%x' % (id(self),) self.copy_from = self self.copies = 0 pass def get_alias(self): return self.alias def set_fields(self, *fields): self.__fields = [copy(f) for f in fields] for f in self.__fields: setattr(self, f.field_name, f) f.table = self pass def _values_2_assigns(self, values): kv_pairs = values.kv_pairs assigns = [] for k, v in kv_pairs: field = getattr(self, k) if not isinstance(field, SQL_field_type): raise TypeError('%s.%s is not a SQL_field_type object.' % (self.__class__.name, k)) v = normalize_operand(v) ass = '%s=%s' % (field.get_sql_name_no_alias(), v.get_sql_name_no_alias()) assigns.append(ass) pass return assigns def gen_update_cmd(self, values, cond=None): from string import join assigns = self._values_2_assigns(values) a_str = join(assigns, ', ') cmd = 'update %s set %s' % (self.tab_name, a_str) if cond: cmd = cmd + ' where ' + cond.get_sql_name_no_alias() pass return cmd def gen_update_from_vview_cmd(self, values, fields=None, where=None): raise NotImplementedError() def gen_insert_cmd(self, values): from string import join _keys = string.join([getattr(self, k).get_sql_name_no_alias() for k, v in values.kv_pairs], ', ') _values = string.join([normalize_operand(v).get_sql_name_no_alias() for k, v in values.kv_pairs], ', ') cmd = 'insert into %s (%s) values(%s)' % (self.tab_name, _keys, _values) return cmd def gen_delete_cmd(self, where=None): from string import join cmd = 'delete from ' + self.tab_name if where: cmd = cmd + ' where ' + where.get_sql_name_no_alias() pass return cmd def __mul__(self, other): return vview(self, other) def __str__(self): return '' def __deepcopy__(self, memo): new_tab = copy(self) apply(new_tab.set_fields, self.__fields) new_tab.factory = value_factory(self) orig = self.copy_from new_tab.alias = '%s_%04x' % (orig.alias, orig.copies) orig.copies = orig.copies + 1 return new_tab @property def schema(self): fields = list(self.__fields) fields.sort(cmp=cnv.cmp_by_attr('seq')) fields = fields + getattr(self, '_unique_constraints', []) return 'create table ' + self.tab_name + '(\n\t' +string.join([x.schema for x in fields], ',\n\t') + '\n\t);\n' pass # ============================================================ class vview(object): def __init__(self, *args): super(vview, self).__init__() self.__members = list(args) self.__fields = [] self.__cond = [] self.__group = [] self.__order = [] pass def __mul__(self, other): # don't join a table object more than once. self.__members.append(other) return self def __str__(self): return self.make_query_str() def __deepcopy__(self, memo): new_view = vview() new_view.__members = deepcopy(self.__members) new_view.__fields = deepcopy(self.__fields) new_view.__cond = deepcopy(self.__cond) new_view.__group = deepcopy(self.__group) new_view.__order = deepcopy(self.__order) self.__link_fields_to_new_tables(new_view, self.__members, new_view.__members) return new_view def make_query_str(self, restrict=None): tbs = string.join(map(lambda x: x.tab_name + ' ' + x.get_alias(), self.__members), ', ') if self.__fields: fds = string.join(map(lambda x: x.get_sql_name(), self.__fields), ', ') else: fds = '*' pass if self.__cond: conds = [c.get_sql_name() for c in self.__cond] conds = '(' + string.join(conds, ') AND (') + ')' else: conds = '' pass if restrict: if conds: conds = conds + ' AND (' + restrict.get_sql_name() + ')' else: conds = restrict.get_sql_name() pass pass if conds: conds = ' where ' + conds pass if self.__order: orders = ' order by ' + string.join([o.get_sql_name() for o in self.__order], ', ') else: orders = '' pass return 'select ' + fds + ' from ' + tbs + conds + orders def join_on(self, cond): pass def ljoin_on(self, cond): pass def where(self, cond): self.__cond.append(cond) return self def group(self, *fields): self.__group = fields return self def order(self, *order): self.__order = order return self def fields(self, *args): self.__fields = list(args) return self @staticmethod def __link_fields_to_new_tables(view, old_tables, new_tables): from pythk.tourist import isinstance_tourist, isinstance_act from pythk.conveniences import combine_sequences class replace_field_act: def __init__(self, tab_map): self.tab_map = tab_map self.visited = {} pass def __call__(self, visit): field = visit.obj field_name = field.field_name old_table = field.table new_table = self.tab_map[old_table] parent = visit.parent.obj attr_name = visit.name new_field = getattr(new_table, field_name) if isinstance(getattr(parent, attr_name), list): getattr(parent, attr_name)[visit.idx] = new_field else: setattr(parent, attr_name, new_field) pass return None pass tab_map = combine_sequences(old_tables, new_tables) tab_map = dict(tab_map) replacer = replace_field_act(tab_map) def skip(x): return None actions = [] actions.append(isinstance_act((SQL_field_type,), replacer)) actions.append(isinstance_act((table, property), skip)) tour = isinstance_tourist(actions) tour.walk(None, view) pass pass # ============================================================ if __name__ == '__main__': class table1(table): f1 = int_f() f2 = str_f() f3 = float_f() pass class table2(table): pass class table3(table): pass t1 = table1() t2 = table2() t3 = table3() t1.alias = 'A01' t2.alias = 'A02' t3.alias = 'A03' view = t1 * t2 * t3 view.fields(t1.f1, foo(t1.f2)).where((t1.f1 == 'abc') & (foo(t1.f2) < t1.f3) & (t1.f3 > 4) & (t1.f2 == _q)) print 'Virtual view ==============================' print str(view) print 'Schema ====================================' print t1.schema print 'Deepcopy table ============================' print deepcopy(t1).schema print 'Deepcopy vview ============================' print str(deepcopy(view)) print str(view) print 'Gen_update_cmd ============================' print t1.factory(f1=_q, f2='aa', f3=11.5).gen_update_cmd(t1.f2==_q) print 'Gen_insert_cmd ============================' print t1.factory(f1=1, f2='aa', f3=11.5).gen_insert_cmd() pass