'''
sql.py: Host SQL on Python Language
table * table ==> vview
vview * table ==> vview
field + field ==> expr
expr + field ==> expr
'''
#
# XXX: Please keep all sequence type attributes as list.
# It eases our task that deepcopy an expr.
# XXX: Never duplicate SQL_field_type objects when performing
# deepcopy. SQL_field_type objects should be reused, or
# remap to SQL_field_type objects of new tables.
#
import string
from copy import copy, deepcopy
from pythk import conveniences as cnv
def get_tables(*args):
return tuple([table(x) for x in args])
# ============================================================
class operand(object):
def __init__(self):
super(operand, self).__init__()
pass
def get_sql_name(self):
raise NotImplementedError()
def get_sql_name_no_alias(self):
raise NotImplementedError()
def __str__(self):
return self.op
def __apply_binary(self, opstr, other):
no = expr(self)
no.children.append(op(opstr))
no.children.append(normalize_operand(other))
return no
def __apply_unary(self, opstr):
no = expr()
no.children.append(op(opstr))
no.children.append(self)
return no
def __add__(self, other):
return self.__apply_binary('+', other)
def __sub__(self, other):
return self.__apply_binary('-', other)
def __mul__(self, other):
return self.__apply_binary('*', other)
def __add__(self, other):
return self.__apply_binary('/', other)
def __eq__(self, other):
return self.__apply_binary('=', other)
def __ne__(self, other):
return self.__apply_binary('!=', other)
def __gt__(self, other):
return self.__apply_binary('>', other)
def __lt__(self, other):
return self.__apply_binary('<', other)
def __and__(self, other):
return self.__apply_binary('AND', other)
def __or__(self, other):
return self.__apply_binary('OR', other)
def __invert__(self):
return self.__apply_unary('NOT')
pass
class complex_operand(operand):
children = []
def __deepcopy__(self, memo):
new_operand = copy(self)
def only_dup_complex(child):
if not isinstance(child, complex_operand):
return child
return deepcopy(child)
new_operand.children = [only_dup_complex(child) for child in self.children]
return new_operand
pass
def normalize_operand(o):
if isinstance(o, operand):
return o
return const(o)
class op(object):
def __init__(self, op_name):
super(op, self).__init__()
self.op_name = op_name
pass
def get_sql_name(self):
return self.op_name
get_sql_name_no_alias = get_sql_name
def __str__(self):
return self.op_name
pass
class func_generic(complex_operand):
def __init__(self, *children):
super(func_generic, self).__init__()
self.children = children
pass
def get_sql_name(self):
args = [c.get_sql_name() for c in self.children]
arg_str = string.join(args, ', ')
return self.__class__.__name__ + '(' + arg_str + ')'
def get_sql_name_no_alias(self):
args = [c.get_sql_name_no_alias() for c in self.children]
arg_str = string.join(args, ', ')
return self.__class__.__name__ + '(' + arg_str + ')'
__str__ = get_sql_name
pass
class foo(func_generic):
pass
class isnull(func_generic):
def get_sql_name(self):
arg_str = self.children[0].get_sql_name()
return arg_str + ' ISNULL'
get_sql_name_no_alias = get_sql_name
pass
class expr(complex_operand):
def __init__(self, ex=None):
super(expr, self).__init__()
self.children = []
if ex:
self.children.append(ex)
pass
pass
def __apply_binary(self, opstr, other):
no = self.__class__(self)
no.children.append(op(opstr))
no.children.append(normalize_operand(other))
return no
def __apply_unary(self, opstr):
no = self.__class__()
no.children.append(op(opstr))
no.children.append(self)
return no
def get_sql_name(self):
return '(' + string.join(map(lambda x: x.get_sql_name(), self.children), ' ') + ')'
def get_sql_name_no_alias(self):
return '(' + string.join(map(lambda x: x.get_sql_name_no_alias(), self.children), ' ') + ')'
def __str__(self):
return self.get_sql_name_no_alias()
pass
class const(operand):
def __init__(self, v):
super(const, self).__init__()
self.value = v
pass
def __deepcopy__(self, memo):
return const(self.value)
@staticmethod
def str_escape(txt):
return '\'' + txt.replace('\'', '\'\'') + '\''
def get_sql_name(self):
import types
if self.value is not None:
if isinstance(self.value, types.StringTypes):
return self.str_escape(self.value)
return repr(self.value)
return 'NULL'
get_sql_name_no_alias = get_sql_name
def __str__(self):
return repr(self.value)
pass
# ============================================================
class sys_sym(operand):
sym = None
def __depcopy__(self, memo):
return self
def get_sql_name(self):
return self.sym
get_sql_name_no_alias = get_sql_name
def __str__(self):
return self.sym
pass
class all_sym(sys_sym):
'''All symbol
Used on count(*) to stand as all records.'''
sym = '*'
pass
_a = all_sym()
class free_var(sys_sym):
'''Free variable (unknow)
It is a place holder waiting for user code specifing
data when executing the query.'''
sym = '?'
pass
_q = free_var()
# =============================================================
class SQL_field_type(operand):
'''Define data type of SQL
User create a object of a child of SQL_field_type with field name to
specify the field name & data type of the field.
'''
type_name = 'None'
fmt_symbol = '%s'
order_seq = 0
def __init__(self, name=None, table=None):
super(SQL_field_type, self).__init__()
self.field_name = name
self.pkey = False
self.uniq = False
self.autoinc = False
self.seq = SQL_field_type.order_seq
self.table = table
SQL_field_type.order_seq = SQL_field_type.order_seq + 1
pass
def get_sql_name(self):
return self.table.get_alias() + '.' + self.field_name
def get_sql_name_no_alias(self):
return self.field_name
def primary(self, auto=False):
self.pkey = True
self.autoinc = auto
return self
def unique(self):
self.uniq = True
return self
def __str__(self):
return self.table.get_alias() + '.' + self.field_name
def __deepcopy__(self, memo):
return copy(self)
@property
def schema(self):
if self.pkey:
ptn = '%s %s primary key'
else:
ptn = '%s %s'
pass
if self.autoinc:
ptn = ptn + ' autoincrement'
pass
if self.uniq:
ptn = ptn + ' unique'
pass
return ptn % (self.field_name, self.type_name)
pass
class int_f(SQL_field_type):
type_name = 'integer'
fmt_symbol = '%d'
pass
class float_f(SQL_field_type):
type_name = 'float'
fmt_symbol = '%f'
pass
class date_f(SQL_field_type):
type_name = 'date'
pass
class str_f(SQL_field_type):
type_name = 'text'
pass
# ============================================================
# constraints
class unique_c(object):
def __init__(self, *args):
super(unique_c, self).__init__()
self.fields = args
pass
@property
def schema(self):
from string import join
fnames = [f.get_sql_name_no_alias() for f in self.fields]
return 'unique(' + join(fnames, ',') + ')'
pass
def unique(*args):
import sys
frame = sys._getframe(1)
ldict = frame.f_locals
for f in args:
if not isinstance(f, SQL_field_type):
raise ValueError('must be a squence of SQL_field_type instances')
pass
_uniques = ldict.setdefault('_unique_constraints', [])
_uniques.append(unique_c(*args))
ldict['_unique_constraints'] = _uniques
pass
# ============================================================
class values(object):
def __init__(self, table, kv_pairs):
super(values, self).__init__()
self.table = table
self.kv_pairs = kv_pairs
pass
def gen_update_cmd(self, cond=None):
return self.table.gen_update_cmd(self, cond)
def gen_insert_cmd(self):
return self.table.gen_insert_cmd(self)
pass
class value_factory(object):
def __init__(self, table):
super(value_factory, self).__init__()
self.table = table
pass
def __call__(self, **kws):
table = self.table
for key in kws:
if not isinstance(getattr(table, key), SQL_field_type):
raise NameError('no such field (%s).' % (key,))
pass
return values(self.table, kws.items())
pass
# ============================================================
class table_meta(type):
def __init__(cls, name, bases, dict):
from pythk.types import is_descriptor
super(table_meta, cls).__init__(name, bases, dict)
cls.tab_name = name
for key in dict:
o = dict[key]
if isinstance(o, SQL_field_type) and o.field_name == None:
o.field_name = key
pass
pass
first_obj = cls()
for key in dict:
o = dict[key]
if isinstance(o, SQL_field_type):
setattr(cls, key, getattr(first_obj, key))
elif (not key.startswith('_')) and is_descriptor(o):
# method is also a descriptor
remap_o = cls._remap(o)
setattr(cls, key, remap_o)
pass
pass
cls.__objs = [first_obj]
pass
def __getitem__(cls, idx):
objs = cls.__objs
if idx >= len(objs):
num = (idx - len(objs) + 4) & ~0x3
objs = cls.__objs = objs + ([None] * num)
pass
obj = objs[idx]
if not obj:
obj = objs[idx] = cls()
pass
return obj
def __mul__(cls, other):
from pythk.types import is_subclass
obj = cls[0]
if is_subclass(other, table):
other = other[0]
pass
return obj * other
@property
def _(cls):
return cls[0]
def __getattr__(cls, name):
'''Redirect accessing to variable that not defined in class
to first instance object.'''
return getattr(cls.__objs[0], name)
class _remap(object):
'''Remap directly accessing to unbound method to
bound method of first instance object, foo_table.__objs[0].'''
def __init__(self, src):
super(table_meta._remap, self).__init__()
self.src = src
pass
def __get__(self, instance, owner):
src = self.src
if instance:
tgt = src.__get__(instance, owner)
else:
tgt = src.__get__(owner[0], owner)
pass
return tgt
pass
pass
class table(object):
__metaclass__ = table_meta
def __init__(self, name=None, fields=[]):
object.__init__(self)
cdict = self.__class__.__dict__
predefined = [cdict[key] for key in cdict if isinstance(cdict[key], SQL_field_type)]
if name != None:
self.tab_name = name
pass
apply(self.set_fields, tuple(predefined) + tuple(fields))
self.factory = value_factory(self)
self.alias = 'A%x' % (id(self),)
self.copy_from = self
self.copies = 0
pass
def get_alias(self):
return self.alias
def set_fields(self, *fields):
self.__fields = [copy(f) for f in fields]
for f in self.__fields:
setattr(self, f.field_name, f)
f.table = self
pass
def _values_2_assigns(self, values):
kv_pairs = values.kv_pairs
assigns = []
for k, v in kv_pairs:
field = getattr(self, k)
if not isinstance(field, SQL_field_type):
raise TypeError('%s.%s is not a SQL_field_type object.' % (self.__class__.name, k))
v = normalize_operand(v)
ass = '%s=%s' % (field.get_sql_name_no_alias(), v.get_sql_name_no_alias())
assigns.append(ass)
pass
return assigns
def gen_update_cmd(self, values, cond=None):
from string import join
assigns = self._values_2_assigns(values)
a_str = join(assigns, ', ')
cmd = 'update %s set %s' % (self.tab_name, a_str)
if cond:
cmd = cmd + ' where ' + cond.get_sql_name_no_alias()
pass
return cmd
def gen_update_from_vview_cmd(self, values, fields=None, where=None):
raise NotImplementedError()
def gen_insert_cmd(self, values):
from string import join
_keys = string.join([getattr(self, k).get_sql_name_no_alias() for k, v in values.kv_pairs], ', ')
_values = string.join([normalize_operand(v).get_sql_name_no_alias() for k, v in values.kv_pairs], ', ')
cmd = 'insert into %s (%s) values(%s)' % (self.tab_name, _keys, _values)
return cmd
def gen_delete_cmd(self, where=None):
from string import join
cmd = 'delete from ' + self.tab_name
if where:
cmd = cmd + ' where ' + where.get_sql_name_no_alias()
pass
return cmd
def __mul__(self, other):
return vview(self, other)
def __str__(self):
return '<table ' + self.tab_name + '>'
def __deepcopy__(self, memo):
new_tab = copy(self)
apply(new_tab.set_fields, self.__fields)
new_tab.factory = value_factory(self)
orig = self.copy_from
new_tab.alias = '%s_%04x' % (orig.alias, orig.copies)
orig.copies = orig.copies + 1
return new_tab
@property
def schema(self):
fields = list(self.__fields)
fields.sort(cmp=cnv.cmp_by_attr('seq'))
fields = fields + getattr(self, '_unique_constraints', [])
return 'create table ' + self.tab_name + '(\n\t' +string.join([x.schema for x in fields], ',\n\t') + '\n\t);\n'
pass
# ============================================================
class vview(object):
def __init__(self, *args):
super(vview, self).__init__()
self.__members = list(args)
self.__fields = []
self.__cond = []
self.__group = []
self.__order = []
pass
def __mul__(self, other):
# don't join a table object more than once.
self.__members.append(other)
return self
def __str__(self):
return self.make_query_str()
def __deepcopy__(self, memo):
new_view = vview()
new_view.__members = deepcopy(self.__members)
new_view.__fields = deepcopy(self.__fields)
new_view.__cond = deepcopy(self.__cond)
new_view.__group = deepcopy(self.__group)
new_view.__order = deepcopy(self.__order)
self.__link_fields_to_new_tables(new_view, self.__members, new_view.__members)
return new_view
def make_query_str(self, restrict=None):
tbs = string.join(map(lambda x: x.tab_name + ' ' + x.get_alias(), self.__members), ', ')
if self.__fields:
fds = string.join(map(lambda x: x.get_sql_name(), self.__fields), ', ')
else:
fds = '*'
pass
if self.__cond:
conds = [c.get_sql_name() for c in self.__cond]
conds = '(' + string.join(conds, ') AND (') + ')'
else:
conds = ''
pass
if restrict:
if conds:
conds = conds + ' AND (' + restrict.get_sql_name() + ')'
else:
conds = restrict.get_sql_name()
pass
pass
if conds:
conds = ' where ' + conds
pass
if self.__order:
orders = ' order by ' + string.join([o.get_sql_name() for o in self.__order], ', ')
else:
orders = ''
pass
return 'select ' + fds + ' from ' + tbs + conds + orders
def join_on(self, cond):
pass
def ljoin_on(self, cond):
pass
def where(self, cond):
self.__cond.append(cond)
return self
def group(self, *fields):
self.__group = fields
return self
def order(self, *order):
self.__order = order
return self
def fields(self, *args):
self.__fields = list(args)
return self
@staticmethod
def __link_fields_to_new_tables(view, old_tables, new_tables):
from pythk.tourist import isinstance_tourist, isinstance_act
from pythk.conveniences import combine_sequences
class replace_field_act:
def __init__(self, tab_map):
self.tab_map = tab_map
self.visited = {}
pass
def __call__(self, visit):
field = visit.obj
field_name = field.field_name
old_table = field.table
new_table = self.tab_map[old_table]
parent = visit.parent.obj
attr_name = visit.name
new_field = getattr(new_table, field_name)
if isinstance(getattr(parent, attr_name), list):
getattr(parent, attr_name)[visit.idx] = new_field
else:
setattr(parent, attr_name, new_field)
pass
return None
pass
tab_map = combine_sequences(old_tables, new_tables)
tab_map = dict(tab_map)
replacer = replace_field_act(tab_map)
def skip(x):
return None
actions = []
actions.append(isinstance_act((SQL_field_type,), replacer))
actions.append(isinstance_act((table, property), skip))
tour = isinstance_tourist(actions)
tour.walk(None, view)
pass
pass
# ============================================================
if __name__ == '__main__':
class table1(table):
f1 = int_f()
f2 = str_f()
f3 = float_f()
pass
class table2(table):
pass
class table3(table):
pass
t1 = table1()
t2 = table2()
t3 = table3()
t1.alias = 'A01'
t2.alias = 'A02'
t3.alias = 'A03'
view = t1 * t2 * t3
view.fields(t1.f1, foo(t1.f2)).where((t1.f1 == 'abc') & (foo(t1.f2) < t1.f3) & (t1.f3 > 4) & (t1.f2 == _q))
print 'Virtual view =============================='
print str(view)
print 'Schema ===================================='
print t1.schema
print 'Deepcopy table ============================'
print deepcopy(t1).schema
print 'Deepcopy vview ============================'
print str(deepcopy(view))
print str(view)
print 'Gen_update_cmd ============================'
print t1.factory(f1=_q, f2='aa', f3=11.5).gen_update_cmd(t1.f2==_q)
print 'Gen_insert_cmd ============================'
print t1.factory(f1=1, f2='aa', f3=11.5).gen_insert_cmd()
pass
syntax highlighted by Code2HTML, v. 0.9.1