Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Downloads stock prices (from Yahoo website) and other prices. 

4""" 

5import os 

6import urllib.request 

7import urllib.error 

8import datetime 

9from io import StringIO 

10import pandas 

11import numpy 

12import requests 

13from pyquickhelper.filehelper import is_file_string 

14 

15 

class StockPricesException(Exception):
    """
    Generic error raised by the stock price classes of this module,
    for example when the cache folder cannot be created or when
    no trading date is left after filtering.
    """

21 

22 

class StockPricesHTTPException(StockPricesException):
    """
    Specialization of *StockPricesException* raised by the stock price
    classes when the data cannot be obtained or has an unexpected
    content (failed download, unknown provider, missing column ``Date``).
    """

28 

29 

class StockPrices:

    """
    Defines a class containing stock prices, provides basic functions,
    the class uses :epkg:`pandas` to load the data.

    .. exref::
        :title: Retrieve stock prices from the Yahoo source

        ::

            from pyensae.finance import StockPrices
            prices = StockPrices(tick="NASDAQ:MSFT")
            print(prices.dataframe.head())

    The class loads a stock price from either a url or a folder
    where the data was cached. If a filename
    ``<folder>/<tick>.<day1>.<day2>.txt`` already exists,
    it takes it from here. Otherwise, it downloads it.

    A couple of providers have been implemented but it is not
    easy to keep them up to date as policies from websites
    change on a regular basis.
    If *url* is ``'yahoo'``, the data will be downloaded using
    `CAC 40 <http://finance.yahoo.com/q/cp?s=^FCHI+Components>`_.
    The CAC40 composition is described by
    `Wikipedia CAC 40 <http://fr.wikipedia.org/wiki/CAC_40>`_.
    However `Yahoo Finance <https://fr.finance.yahoo.com/>`_
    introduced the use of cookies in May 2017
    and it is not so easy to automate.
    The default provider could be
    *Google Finance* which has now been integrated into the
    search engine.
    Tick names depend on the data provider. More details:
    `European Markets Information <https://www.stockmarketeye.com/users-guide/ticker-symbols-and-data-providers/euro-stocks.html>`_.
    You can also go to `quandl <https://www.quandl.com/data/EURONEXT/BNP-Bnp-Paribas-Act-A-BNP>`_
    and get the tick for the module `quandl <https://www.quandl.com/tools/python>`_.
    As of May 14th, the following error appears when using
    ``url='yahoo'`` which comes from an error in
    :epkg:`pandas_reader`::

        ImmediateDeprecationError(DEP_ERROR_MSG.format('Yahoo Daily'))
        pandas_datareader.exceptions.ImmediateDeprecationError:
        Yahoo Daily has been immediately deprecated due to large breaks in the API without the
        introduction of a stable replacement. Pull Requests to re-enable these data
        connectors are welcome.

        See https://github.com/pydata/pandas-datareader/issues

    ``url='yahoo_new'`` should solve the issue.
    It relies on :epkg:`yahoo_historial`.
    Data can be downloaded for a specific period of time.
    If not specified, it takes the largest available.

    .. exref::
        :title: Compute the average returns and correlation matrix

        ::

            import pyensae, pandas
            from pyensae.finance import StockPrices
            from pyensae.datasource import download_data

            # download the CAC 40 composition from my website (for Yahoo)
            download_data('cac40_2013_11_11.txt', website='xd')

            # download all the prices (if not already done) and store them into files
            actions = pandas.read_csv("cac40_2013_11_11.txt", sep="\\t")

            # we remove stocks with not enough historical data
            # (method missing returns None when no date is missing)
            def nb_missing(stock, dates):
                miss = stock.missing(dates)
                return 0 if miss is None else len(miss)

            stocks = { k:StockPrices(tick = k) for k,v in actions.values }
            dates = StockPrices.available_dates(stocks.values())
            stocks = {k:v for k,v in stocks.items() if nb_missing(v, dates) <= 10}
            print("nb left", len(stocks))

            # we remove dates with missing prices
            dates = StockPrices.available_dates(stocks.values())
            ok = dates[dates["missing"] == 0]
            print("all dates before", len(dates), " after:" , len(ok))
            for k in stocks:
                stocks[k] = stocks[k].keep_dates(ok)

            # we compute correlation matrix and returns
            ret, cor = StockPrices.covariance(stocks.values(), cov = False, ret = True)

    You should also look at
    `pyensae et notebook <http://www.xavierdupre.fr/blog/notebooks/example%20pyensae.html>`_.
    If you use `Google Finance <https://www.google.com/finance>`_
    as a provider, the tick name is usually
    prefixed by the market places (NASDAQ for example). The export
    does not work for all markets places.
    Another provider was added, ``yahoo_new`` which delegates the task
    of getting data from `Yahoo Finance <https://finance.yahoo.com/>`_ to module
    `yahoo-historical <https://github.com/AndrewRPorter/yahoo-historical>`_.
    """

    def __init__(self, tick, url="google", folder="cache",
                 begin=None, end=None, sep=",",
                 intern=False, use_dtime=False):
        """
        Loads the prices from a dataframe, an existing file, the cache
        or downloads them from a provider.

        @param      tick        tick name, ex ``NASDAQ:MSFT``, or the name of an existing file
        @param      url         provider, ``'google'``, ``'yahoo_new'``, ``'quandl'``,
                                ``'yahoo'``, ``'fred'``, ``'famafrench'`` are predefined values,
                                it can also be a dataframe with a column ``Date``,
                                the data is then taken from it
        @param      folder      cache folder (created if it does not exist)
        @param      begin       first day (datetime), 2000-01-03 if None
        @param      end         last day (datetime), yesterday if None
        @param      sep         column separator
        @param      intern      do not use unless you know what to do
                                (see :meth:`__getitem__ <pyensae.finance.astock.StockPrices.__getitem__>`),
                                if True, the data is neither sorted nor indexed by date
        @param      use_dtime   if True, use DateTime instead of string

        The constructor raises *StockPricesHTTPException* if the dataframe has
        no column ``Date``, if the provider is unknown or if the download fails,
        and *StockPricesException* if the cache cannot be written
        or if the data cannot be interpreted.
        """
        if isinstance(url, pandas.DataFrame):
            self.datadf = url
            self.tickname = tick
            if "Date" not in url.columns:
                raise StockPricesHTTPException(
                    "the dataframe does not contain any column 'Date': {0}".format(
                        ",".join(str(_) for _ in url.columns)))
        elif isinstance(tick, str) and is_file_string(tick) and os.path.exists(tick):
            self.tickname = os.path.split(tick)[-1]
            # only the first line is needed to detect an HTML error page
            with open(tick, "r") as f:
                first_line = f.readline()
            if first_line.startswith('<!DOCTYPE html PUBLIC'):
                raise StockPricesHTTPException(
                    "pandas cannot parse the file, check your have access to internet: " + str(tick))
            self.datadf = StockPrices._read_csv_checked(
                tick, sep,
                "pandas cannot parse the file, check your have access to internet: " + str(tick))
        else:
            if not os.path.exists(folder):
                try:
                    os.makedirs(folder, exist_ok=True)
                except PermissionError as e:
                    raise StockPricesException(("PermissionError, unable to create directory '{0}', " +
                                                "check you execute the program in a folder you have " +
                                                "permission to modify ({1})").format(folder, os.getcwd())) from e
            self.tickname = tick

            if begin is None:
                begin = datetime.datetime(2000, 1, 3)
            if end is None:
                end = datetime.datetime.now() - datetime.timedelta(1)

            sbeg = begin.strftime("%Y-%m-%d")
            send = end.strftime("%Y-%m-%d")
            # characters which cannot be part of a file name are replaced,
            # "\\" is a single backslash
            safe_tick = tick.replace(":", "_").replace(
                "/", "_").replace("\\", "_")
            name = os.path.join(
                folder, safe_tick + ".{0}.{1}.txt".format(sbeg, send))

            date_format = None
            if not os.path.exists(name):
                if url == "google":
                    use_url = True
                    url_string = "https://finance.google.com/finance/historical?q={0}".format(
                        self.tickname)
                    url_string += "&startdate={0}&enddate={1}&output=csv".format(
                        begin.strftime('%b %d, %Y'), end.strftime('%b %d, %Y'))
                    url = url_string.replace(" ", "+").replace(",", "%2C")
                    date_format = "%b-%d-%Y"
                elif url == "quandl":
                    import quandl  # pylint: disable=C0415
                    # the requested tick is downloaded, not a fixed one
                    df = quandl.get(
                        self.tickname, start_date=sbeg, end_date=send)
                    df.reset_index(drop=False).to_csv(
                        name, sep=sep, index=False)
                    use_url = False
                elif url == 'yahoo_new':
                    from yahoo_historical import Fetcher

                    def _get(self, events):
                        if self.interval not in ["1d", "1wk", "1mo"]:
                            raise ValueError(
                                "Incorrect interval: valid intervals are 1d, 1wk, 1mo")

                        url = self.api_url % (
                            self.ticker, self.start, self.end, self.interval, events)

                        headers = {'User-Agent': ''}
                        data = requests.get(
                            url, cookies={"User-agent": "Mozilla/5.0"},
                            headers=headers)
                        content = StringIO(data.content.decode("utf-8"))
                        return pandas.read_csv(content, sep=",")

                    # See issue https://github.com/AndrewRPorter/yahoo-historical/issues/19.
                    Fetcher._get = _get
                    data = Fetcher(tick, [begin.year, begin.month, begin.day],
                                   [end.year, end.month, end.day])
                    df = data.get_historical()
                    df.to_csv(name, sep=sep, index=False)
                    use_url = False
                elif url in ("yahoo", "google", "fred", "famafrench"):
                    import pandas_datareader.data as web  # pylint: disable=C0415
                    df = web.DataReader(self.tickname, url,
                                        begin, end).reset_index(drop=False)
                    df.to_csv(name, sep=sep, index=False)
                    use_url = False
                else:
                    raise StockPricesHTTPException(
                        "Unable to download data '{0}' from the following website '{1}'".format(tick, url))

                if use_url:
                    self.url_ = url
                    try:
                        # the connection is closed even if the reading fails
                        with urllib.request.urlopen(url) as u:
                            text = u.read()
                    except urllib.error.HTTPError as e:
                        raise StockPricesHTTPException(
                            "HTTPError, unable to load tick '{0}'\nURL: {1}".format(tick, url)) from e

                    if len(text) < 10:
                        raise StockPricesHTTPException(
                            "nothing to download for '{0}' less than 10 downloaded bytes".format(tick))

                    try:
                        with open(name, "wb") as f:
                            f.write(text)
                    except PermissionError as e:
                        raise StockPricesException(("PermissionError, unable to write file '{0}', " +
                                                    "check you execute the program in a folder you have " +
                                                    "permission to modify ({1})").format(name, os.getcwd())) from e
            else:
                self.url_ = name

            # the cached file is inspected if pandas fails (and not the tick
            # which is not a filename in this branch)
            self.datadf = StockPrices._read_csv_checked(
                name, sep,
                "pandas cannot parse the file, check your have access to internet '{0}'".format(tick))

            if date_format is not None:
                # dates are normalized to YYYY-MM-DD and the cache is updated
                self.datadf["Date"] = pandas.to_datetime(self.datadf["Date"])
                self.datadf["Date"] = self.datadf["Date"].apply(
                    lambda x: x.strftime('%Y-%m-%d'))
                self.datadf.to_csv(name, sep=sep, index=False)

        if use_dtime:
            self.datadf["Date"] = pandas.to_datetime(self.datadf["Date"])

        if not intern:
            if "Date" in self.datadf.columns and "Date" in self.datadf.index.names:
                # pandas refuses to sort by a name which is both a column
                # and an index level, the index is used if both are equal
                if not numpy.array_equal(self.datadf['Date'], self.datadf.index):
                    raise StockPricesException(
                        "Columns Date and index are different.")
                self.datadf = self.datadf.sort_index()
            else:
                try:
                    self.datadf = self.datadf.sort_values("Date")
                except KeyError as e:
                    raise StockPricesException("schema: {}".format(
                        ",".join(str(_) for _ in self.datadf.columns))) from e
            self.datadf = self.datadf.reset_index(drop=True)
            self.datadf = self.datadf.set_index("Date", drop=False)

    @staticmethod
    def _read_csv_checked(filename, sep, message):
        """
        Reads a text file with :epkg:`pandas`. If the parsing fails and
        the file contains ``Firewall Authentication``, the function raises
        *StockPricesException* with *message*, otherwise the original
        exception is raised again.

        @param      filename    file to read
        @param      sep         column separator
        @param      message     message of the exception
        @return                 dataframe
        """
        try:
            return pandas.read_csv(filename, sep=sep)
        except Exception as e:
            with open(filename, "r") as t:
                content = t.read()
            if "Firewall Authentication" in content:
                raise StockPricesException(message) from e
            raise

    @staticmethod
    def _trading_dates(trading_dates):
        """
        Returns the dates stored in *trading_dates*: column ``Date``
        of a dataframe if it exists, its index otherwise. Any other
        container (dictionary, list, set) is returned as is.

        @param      trading_dates   dataframe or iterable of dates
        @return                     iterable of dates
        """
        if isinstance(trading_dates, pandas.DataFrame):
            if "Date" in trading_dates.columns:
                return trading_dates["Date"]
            return trading_dates.index
        return trading_dates

    def __getitem__(self, key):
        """
        Overloads the ``getitem`` operator to get a @see cl StockPrices object.
        The key is given to the underlying dataframe.

        @param      key     key
        @return             StockPrices
        """
        return StockPrices(self.tick, self.datadf[key], intern=True)

    def __len__(self):
        """
        @return         number of observations
        """
        return len(self.datadf)

    @property
    def shape(self):
        """
        @return         shape of the dataframe (number of observations, number of columns)
        """
        return self.datadf.shape

    @property
    def tick(self):
        """
        Returns the tick name.
        """
        return self.tickname

    @property
    def dataframe(self):
        """
        Returns the dataframe.
        """
        return self.datadf

    def df(self):
        """
        Returns the dataframe.
        """
        return self.datadf

    def FirstDate(self):
        """
        Returns the first date.
        """
        return self.datadf["Date"].min()

    def LastDate(self):
        """
        Returns the last date.
        """
        return self.datadf["Date"].max()

    def missing(self, trading_dates):
        """
        Returns the list of missing dates from an overset of trading dates.

        @param      trading_dates   trading dates (DataFrame having the dates
                                    in column ``Date`` or in the index,
                                    or any iterable of dates)
        @return                     missing dates as a dataframe sorted by date
                                    or None if no date is missing
        """
        known = set(self.dataframe["Date"])
        absent = [{"Date": v}
                  for v in StockPrices._trading_dates(trading_dates)
                  if v not in known]
        if not absent:
            return None
        return pandas.DataFrame(absent).sort_values("Date")

    @staticmethod
    def available_dates(listStockPrices, missing=True, field="Close"):
        """
        Returns the list of values (Open or High or Low or Close or Volume) from each stock
        for all the available dates for a list of stock prices.

        A missing date is a date for which there is at least one stock price and one missing stock price.

        If ``missing`` is true, a column is added which gives the number of missing stock prices for each date.

        @param      listStockPrices     list of StockPrices
        @param      missing             True or False
        @param      field               which field to use to fill the matrix, a column name,
                                        a list or tuple of column names,
                                        ``'ohlc'`` stands for Open, High, Low, Close
        @return                         matrix with the available dates for each stock,
                                        indexed by date, one column per stock named
                                        after the tick if *field* is a string,
                                        ``<tick>,<field>`` otherwise
        """
        if field == "ohlc":
            field = ["Open", "High", "Low", "Close"]
        if isinstance(field, str):
            fields = [field]
        elif isinstance(field, (tuple, list)):
            fields = list(field)
        else:
            raise TypeError("field must be a string, a tuple or a list")

        rows = []
        for st in listStockPrices:
            columns = list(st.dataframe.columns)
            # the date is not necessarily the first column,
            # the first column is only used if there is no column Date
            date_pos = columns.index("Date") if "Date" in columns else 0
            positions = [columns.index(f) for f in fields]
            for row in st.dataframe.values:
                rec = {"Date": row[date_pos], "tick": st.tick}
                for pos, f in zip(positions, fields):
                    rec[f] = row[pos]
                rows.append(rec)

        df = pandas.DataFrame(rows)
        # pivot only accepts keyword arguments in recent versions of pandas
        if isinstance(field, str):
            piv = df.pivot(index="Date", columns="tick", values=field)
        else:
            piv = None
            for f in fields:
                part = df.pivot(index="Date", columns="tick", values=f)
                part.columns = [c + "," + f for c in part.columns]
                if piv is None:
                    piv = part
                else:
                    piv = piv.merge(
                        part, how="outer", left_index=True, right_index=True)

        if missing:
            # number of stocks without any value for each date
            piv["missing"] = piv.isna().sum(axis=1)

        return piv.sort_index()

    def head(self):
        """
        Returns the first rows of the dataframe.
        """
        return self.dataframe.head()

    def tail(self):
        """
        Returns the last rows of the dataframe.
        """
        return self.dataframe.tail()

    def keep_dates(self, trading_dates):
        """
        Removes undesired dates.

        @param      trading_dates   dates to keep (DataFrame having the dates
                                    in column ``Date`` or in the index,
                                    or any iterable of dates)
        @return                     new series (StockPrices)

        The method raises *StockPricesException* if no date is left.
        """
        wanted = set(StockPrices._trading_dates(trading_dates))
        keep = self.dataframe["Date"].isin(wanted).values
        if not keep.any():
            raise StockPricesException("no trading dates left")
        return StockPrices(self.tickname, self.dataframe.loc[keep, :])

    def returns(self):
        """
        Builds the series of returns. Columns ``Date`` and ``Volume``
        (if present) are copied, every other column *c* is replaced by
        ``(c[t] - c[t-1]) / c[t-1]``. The first date is dropped.

        @return         StockPrices
        """
        df = self.dataframe
        fd = self.FirstDate()
        ld = self.LastDate()

        plus = (df["Date"] > fd).values   # dates from FirstDate+1 to LastDate
        moins = (df["Date"] < ld).values  # dates from FirstDate to LastDate-1

        copied = [c for c in ("Date", "Volume") if c in df.columns]
        # a copy is needed to safely add columns
        res = df.loc[plus, copied].copy()

        for k in df.columns:
            if k in copied:
                continue
            m = numpy.array(df.loc[moins, k])
            p = numpy.array(df.loc[plus, k])
            res[k] = (p - m) / m

        return StockPrices(self.tickname, res)

    @staticmethod
    def covariance(
            listStockPrices, missing=True, field="Close", cov=True, ret=False):
        """
        Computes the covariances matrix (of returns).

        @param      listStockPrices     list of StockPrices
        @param      missing             unused, kept for backward compatibility
        @param      field               which field to use to fill the matrix
        @param      cov                 if True, returns the covariance, otherwise, the correlations
        @param      ret                 if True, also add the returns
        @return                         square dataframe or 2 dataframes (returns, correlation)
        """
        listStockPrices = [v.returns() for v in listStockPrices]
        mat = StockPrices.available_dates(listStockPrices, False, field)

        npmat = numpy.array(mat).transpose()
        coef = numpy.cov(npmat) if cov else numpy.corrcoef(npmat)
        # numpy returns a scalar if there is only one series
        coef = numpy.atleast_2d(coef)
        # rows and columns follow the columns of the matrix
        # (sorted by pivot), not the order of the input list
        names = list(mat.columns)
        ret_mat = pandas.DataFrame(coef, columns=names, index=names)

        if ret:
            rows = [{"tick": v.tick, "return": v.dataframe[field].mean()}
                    for v in listStockPrices]
            ret = pandas.DataFrame(rows)
            ret.set_index("tick", drop=True, inplace=True)
            return ret, ret_mat
        return ret_mat

    def plot(self, begin=None, end=None,
             field="Close", date_format=None,
             existing=None, axis=1, ax=None,
             label_prefix=None, color=None, **args):
        """
        See :meth:`draw <pyensae.finance.astock.StockPrices.draw>`.
        """
        return StockPrices.draw(self, begin=begin, end=end,
                                field=field, date_format=date_format,
                                existing=existing, axis=axis, ax=ax,
                                label_prefix=label_prefix, color=color,
                                **args)

    @staticmethod
    def draw(listStockPrices, begin=None, end=None,
             field="Close", date_format=None,
             existing=None, axis=1, ax=None,
             label_prefix=None, color=None, **args):
        """
        Draws a graph showing one or several time series.
        The example was taken
        `date_demo.py <https://matplotlib.org/examples/api/date_demo.html>`_.

        @param      listStockPrices     list of @see cl StockPrices (or one @see cl StockPrices if it is the only one)
        @param      begin               first date (datetime or string ``YYYY-MM-DD``) or None to take the first one
        @param      end                 last included date (datetime or string ``YYYY-MM-DD``) or None to take the last one
        @param      field               Open, High, Low, Close, Adj Close, Volume, or ``'ohlc'`` for a candlestick graph
        @param      date_format         ``%Y`` or ``%Y-%m`` or ``%Y-%m-%d`` or None if you prefer the function to choose
        @param      existing            deprecated, if not None, the function does not modify the format of the date axis
        @param      axis                1 or 2, it only works if *ax* is not None.
                                        If axis is 2, the function draws the curves on the second axis.
        @param      ax                  use existing `axes <http://matplotlib.org/api/axes_api.html>`_
        @param      label_prefix        to prefix curve label
        @param      color               curve color
        @param      args                other parameters to give method ``plt.subplots``
        @return                         `axes <http://matplotlib.org/api/axes_api.html>`_

        The parameter ``figsize`` of the method
        `subplots <https://matplotlib.org/api/pyplot_api.html?highlight=subplots#matplotlib.pyplot.subplots>`_
        can change the graph size (see the example below).

        .. exref::
            :title: graph of a financial series

            ::

                from pyensae.finance import StockPrices
                stocks = [ StockPrices("NASDAQ:MSFT", folder = cache),
                           StockPrices("NASDAQ:GOOGL", folder = cache),
                           StockPrices("NASDAQ:AAPL", folder = cache)]
                ax = StockPrices.draw(stocks)
                ax.figure.savefig("image.png")
                ax = StockPrices.draw(stocks, begin="2010-01-01", figsize=(16,8))

            You can also chain the graphs and add a series on a second graph:

            ::

                from pyensae.finance import StockPrices
                stock = StockPrices("NASDAQ:MSFT", folder = cache)
                stock2 = StockPrices("NASDAQ:GOOGL", folder = cache)
                ax = stock.plot(figsize=(16,8))
                ax = stock2.plot(ax=ax, axis=2)

        .. versionchanged:: 1.1
            Parameter *existing* was removed and parameter *ax* was added.
            If the date overlaps, the method
            `autofmt_xdate <https://matplotlib.org/api/figure_api.html#matplotlib.figure.Figure.autofmt_xdate>`_
            should be called.
        """
        if isinstance(listStockPrices, StockPrices):
            listStockPrices = [listStockPrices]

        data = StockPrices.available_dates(
            listStockPrices, missing=False, field=field)

        def as_bound(value):
            "converts a datetime into a string if dates are stored as strings"
            if (isinstance(value, (datetime.datetime, datetime.date)) and
                    len(data.index) > 0 and isinstance(data.index[0], str)):
                return value.strftime('%Y-%m-%d')
            return value

        begin = as_bound(begin)
        end = as_bound(end)
        if begin is not None:
            data = data[data.index >= begin]
        if end is not None:
            data = data[data.index <= end]

        # works for dates stored as strings or as timestamps
        dates = [d.to_pydatetime() for d in pandas.to_datetime(data.index)]
        begin = dates[0]
        end = dates[-1]

        def price(x):
            "local formatting"
            return '%1.2f' % x

        import matplotlib.pyplot as plt  # pylint: disable=C0415
        import matplotlib.dates as mdates  # pylint: disable=C0415

        if ax is not None:
            ex_h, ex_l = ax.get_legend_handles_labels()
            ex_l = tuple(ex_l)
            ex_h = tuple(ex_h)
            if axis == 2:
                ax = ax.twinx()
            fig = None
        else:
            # label is not a parameter of plt.subplots
            args_ = {k: v for k, v in args.items() if k != 'label'}
            fig, ax = plt.subplots(**args_)
            ex_h, ex_l = tuple(), tuple()

        if label_prefix is None:
            label_prefix = ""
        labels = tuple(label_prefix + str(c) for c in data.columns)

        curve = []
        if field == "ohlc":
            from mplfinance.original_flavor import candlestick_ohlc  # pylint: disable=E0401
            ohlc = [[mdates.date2num(t)] + list(data.iloc[i, :4])
                    for i, t in enumerate(dates)]
            candlestick_ohlc(ax, ohlc, colorup="g")
        else:
            for stock, label in zip(data.columns, labels):
                style = {'linestyle': 'solid', 'label': label}
                if color:
                    # the color is given only once to matplotlib
                    style['c'] = color
                elif axis == 2:
                    # default color for the second axis
                    style['c'] = "r"
                curve.append(ax.plot(dates, data[stock], **style))

        if existing is None:
            ax.format_xdata = mdates.DateFormatter('%Y-%m-%d')
            # the granularity of the ticks depends on the number of dates
            if len(dates) < 30:
                major, minor = mdates.DayLocator(), mdates.DayLocator()
                default_format = "%Y-%m-%d"
            elif len(dates) < 500:
                major, minor = mdates.MonthLocator(), mdates.DayLocator()
                default_format = "%Y-%m"
            else:
                major, minor = mdates.YearLocator(), mdates.MonthLocator()
                default_format = "%Y"
            ax.xaxis.set_major_locator(major)
            ax.xaxis.set_minor_locator(minor)
            ax.xaxis.set_major_formatter(mdates.DateFormatter(
                default_format if date_format is None else date_format))

        ax.set_xlim(begin, end)
        ax.format_ydata = price
        if fig is not None:
            fig.autofmt_xdate()

        if axis == 2:
            # ax.plot returns a list of lines, only the first one is kept
            curve = [_[0] for _ in curve]
            ax.legend(ex_h + tuple(curve), ex_l + labels)
        else:
            ax.grid(True)
            ax.legend(ex_l + labels)

        return ax

    def to_csv(self, filename, sep="\t", index=False, **params):
        """
        Saves the file in text format,
        see `to_csv <https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.to_csv.html>`_

        @param      filename        filename
        @param      sep             separator
        @param      index           to keep or drop the index
        @param      params          other parameters
        """
        self.dataframe.to_csv(filename, sep=sep, index=index, **params)

    def to_excel(self, excel_writer, **params):
        """
        Saves the file in Excel format,
        see `to_excel <https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.to_excel.html>`_

        @param      excel_writer    filename or writer
        @param      params          other parameters
        """
        self.dataframe.to_excel(excel_writer, **params)