Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add tz info to all naive datetime object #120

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions examples/order_margins.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@

margin_detail = kite.order_margins(order_param_multi)
logging.info("Required margin for order_list: {}".format(margin_detail))
# Compact margin response
margin_detail_compt = kite.order_margins(order_param_multi, mode='compact')
logging.info("Required margin for order_list in compact form: {}".format(margin_detail_compt))

# Basket orders
order_param_basket = [
Expand Down
79 changes: 34 additions & 45 deletions kiteconnect/connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
from six.moves.urllib.parse import urljoin
import csv
import json
import dateutil.parser
from dateutil.parser import parse as datetimeparse
from dateutil.tz import tzoffset
from dateutil.utils import default_tzinfo
import hashlib
import logging
import datetime
Expand Down Expand Up @@ -263,18 +265,15 @@ def generate_session(self, request_token, api_secret):
h = hashlib.sha256(self.api_key.encode("utf-8") + request_token.encode("utf-8") + api_secret.encode("utf-8"))
checksum = h.hexdigest()

resp = self._post("api.token", params={
resp = self._format_response(self._post("api.token", params={
"api_key": self.api_key,
"request_token": request_token,
"checksum": checksum
})
}))

if "access_token" in resp:
self.set_access_token(resp["access_token"])

if resp["login_time"] and len(resp["login_time"]) == 19:
ranjanrak marked this conversation as resolved.
Show resolved Hide resolved
resp["login_time"] = dateutil.parser.parse(resp["login_time"])

return resp

def invalidate_access_token(self, access_token=None):
Expand Down Expand Up @@ -413,9 +412,10 @@ def _format_response(self, data):

for item in _list:
# Convert date time string to datetime object
for field in ["order_timestamp", "exchange_timestamp", "created", "last_instalment", "fill_timestamp", "timestamp", "last_trade_time"]:
if item.get(field) and len(item[field]) == 19:
item[field] = dateutil.parser.parse(item[field])
for field in ["order_timestamp", "exchange_timestamp", "created", "last_instalment", "fill_timestamp",
"timestamp", "last_trade_time", "login_time", "expiry", "last_price_date"]:
if item.get(field) and self.is_timestamp(item[field]):
item[field] = self.parseDateTime(item[field])

return _list[0] if type(data) == dict else _list

Expand Down Expand Up @@ -650,7 +650,7 @@ def _format_historical(self, data):
records = []
for d in data["candles"]:
record = {
"date": dateutil.parser.parse(d[0]),
"date": self.parseDateTime(d[0]),
"open": d[1],
"high": d[2],
"low": d[3],
Expand Down Expand Up @@ -771,13 +771,14 @@ def delete_gtt(self, trigger_id):
"""Delete a GTT order."""
return self._delete("gtt.delete", url_args={"trigger_id": trigger_id})

def order_margins(self, params):
def order_margins(self, params, mode=None):
"""
Calculate margins for requested order list considering the existing positions and open orders

- `params` is list of orders to retrive margins detail
- `mode` is margin response mode type. compact - Compact mode will only give the total margins
"""
return self._post("order.margins", params=params, is_json=True)
return self._post("order.margins", params=params, is_json=True, query_params={'mode': mode})

def basket_order_margins(self, params, consider_positions=True, mode=None):
"""
Expand All @@ -804,50 +805,33 @@ def _parse_instruments(self, data):
if not PY2 and type(d) == bytes:
d = data.decode("utf-8").strip()

records = []
reader = csv.DictReader(StringIO(d))

for row in reader:
row["instrument_token"] = int(row["instrument_token"])
row["last_price"] = float(row["last_price"])
row["strike"] = float(row["strike"])
row["tick_size"] = float(row["tick_size"])
row["lot_size"] = int(row["lot_size"])

# Parse date
if len(row["expiry"]) == 10:
ranjanrak marked this conversation as resolved.
Show resolved Hide resolved
row["expiry"] = dateutil.parser.parse(row["expiry"]).date()

records.append(row)

return records
return self._format_response(list(reader))

def _parse_mf_instruments(self, data):
# decode to string for Python 3
d = data
if not PY2 and type(d) == bytes:
d = data.decode("utf-8").strip()

records = []
reader = csv.DictReader(StringIO(d))

for row in reader:
row["minimum_purchase_amount"] = float(row["minimum_purchase_amount"])
row["purchase_amount_multiplier"] = float(row["purchase_amount_multiplier"])
row["minimum_additional_purchase_amount"] = float(row["minimum_additional_purchase_amount"])
row["minimum_redemption_quantity"] = float(row["minimum_redemption_quantity"])
row["redemption_quantity_multiplier"] = float(row["redemption_quantity_multiplier"])
row["purchase_allowed"] = bool(int(row["purchase_allowed"]))
row["redemption_allowed"] = bool(int(row["redemption_allowed"]))
row["last_price"] = float(row["last_price"])
return self._format_response(list(reader))

# Parse date
if len(row["last_price_date"]) == 10:
row["last_price_date"] = dateutil.parser.parse(row["last_price_date"]).date()

records.append(row)
def is_timestamp(self, string):
"""Checks if string is timestamp"""
try:
datetimeparse(string)
return True
except ValueError:
return False

return records
def parseDateTime(self, string):
"""Set default timezone to IST for naive time object"""
# Default timezone for all datetime object
default_tz = tzoffset("Asia/Kolkata", 19800)
return default_tzinfo(datetimeparse(string), default_tz)

def _user_agent(self):
return (__title__ + "-python/").capitalize() + __version__
Expand Down Expand Up @@ -929,8 +913,13 @@ def _request(self, route, method, url_args=None, params=None, is_json=False, que
self.session_expiry_hook()

# native Kite errors
exp = getattr(ex, data.get("error_type"), ex.GeneralException)
raise exp(data["message"], code=r.status_code)
# mf error response don't have error_type field
if data.get("error_type"):
exp = getattr(ex, data.get("error_type"), ex.GeneralException)
raise exp(data["message"], code=r.status_code)
else:
# Throw general exception for such undefined error type
raise ex.GeneralException(data["message"], code=r.status_code)

return data["data"]
elif "csv" in r.headers["content-type"]:
Expand Down