Skip to content

Commit 42e4f5e

Browse files
committed
Add types to tree.Tree
1 parent 6500844 commit 42e4f5e

File tree

5 files changed

+44
-34
lines changed

5 files changed

+44
-34
lines changed

‎git/index/base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -568,7 +568,7 @@ def write_tree(self) -> Tree:
568568
# note: additional deserialization could be saved if write_tree_from_cache
569569
# would return sorted tree entries
570570
root_tree = Tree(self.repo, binsha, path='')
571-
root_tree._cache = tree_items
571+
root_tree._cache = tree_items # type: ignore
572572
return root_tree
573573

574574
def _process_diff_args(self, args: List[Union[str, diff.Diffable, object]]

‎git/index/fun.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ def write_tree_from_cache(entries: List[IndexEntry], odb, sl: slice, si: int = 0
293293
# finally create the tree
294294
sio = BytesIO()
295295
tree_to_stream(tree_items, sio.write) # converts bytes of each item[0] to str
296-
tree_items_stringified = cast(List[Tuple[str, int, str]], tree_items) # type: List[Tuple[str, int, str]]
296+
tree_items_stringified = cast(List[Tuple[str, int, str]], tree_items)
297297
sio.seek(0)
298298

299299
istream = odb.store(IStream(str_tree_type, len(sio.getvalue()), sio))

‎git/objects/fun.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,19 @@
11
"""Module with functions which are supposed to be as fast as possible"""
22
from stat import S_ISDIR
3+
34
from git.compat import (
45
safe_decode,
56
defenc
67
)
78

9+
# typing ----------------------------------------------
10+
11+
from typing import List, Tuple
12+
13+
14+
# ---------------------------------------------------
15+
16+
817
__all__ = ('tree_to_stream', 'tree_entries_from_data', 'traverse_trees_recursive',
918
'traverse_tree_recursive')
1019

@@ -38,7 +47,7 @@ def tree_to_stream(entries, write):
3847
# END for each item
3948

4049

41-
def tree_entries_from_data(data):
50+
def tree_entries_from_data(data: bytes) -> List[Tuple[bytes, int, str]]:
4251
"""Reads the binary representation of a tree and returns tuples of Tree items
4352
:param data: data block with tree data (as bytes)
4453
:return: list(tuple(binsha, mode, tree_relative_path), ...)"""
@@ -72,8 +81,8 @@ def tree_entries_from_data(data):
7281

7382
# default encoding for strings in git is utf8
7483
# Only use the respective unicode object if the byte stream was encoded
75-
name = data[ns:i]
76-
name = safe_decode(name)
84+
name_bytes = data[ns:i]
85+
name = safe_decode(name_bytes)
7786

7887
# byte is NULL, get next 20
7988
i += 1

‎git/objects/tree.py

+28-28
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,23 @@
2020

2121
# typing -------------------------------------------------
2222

23-
from typing import Iterable, Iterator, Tuple, Union, cast, TYPE_CHECKING
23+
from typing import Callable, Dict, Iterable, Iterator, List, Tuple, Type, Union, cast, TYPE_CHECKING
24+
25+
from git.types import PathLike
2426

2527
if TYPE_CHECKING:
28+
from git.repo import Repo
2629
from io import BytesIO
2730

2831
#--------------------------------------------------------
2932

3033

31-
cmp = lambda a, b: (a > b) - (a < b)
34+
cmp: Callable[[int, int], int] = lambda a, b: (a > b) - (a < b)
3235

3336
__all__ = ("TreeModifier", "Tree")
3437

3538

36-
def git_cmp(t1, t2):
39+
def git_cmp(t1: 'Tree', t2: 'Tree') -> int:
3740
a, b = t1[2], t2[2]
3841
len_a, len_b = len(a), len(b)
3942
min_len = min(len_a, len_b)
@@ -45,9 +48,9 @@ def git_cmp(t1, t2):
4548
return len_a - len_b
4649

4750

48-
def merge_sort(a, cmp):
51+
def merge_sort(a: List[int], cmp: Callable[[int, int], int]) -> None:
4952
if len(a) < 2:
50-
return
53+
return None
5154

5255
mid = len(a) // 2
5356
lefthalf = a[:mid]
@@ -182,29 +185,29 @@ class Tree(IndexObject, diff.Diffable, util.Traversable, util.Serializable):
182185
symlink_id = 0o12
183186
tree_id = 0o04
184187

185-
_map_id_to_type = {
188+
_map_id_to_type: Dict[int, Union[Type[Submodule], Type[Blob], Type['Tree']]] = {
186189
commit_id: Submodule,
187190
blob_id: Blob,
188191
symlink_id: Blob
189192
# tree id added once Tree is defined
190193
}
191194

192-
def __init__(self, repo, binsha, mode=tree_id << 12, path=None):
195+
def __init__(self, repo: 'Repo', binsha: bytes, mode: int = tree_id << 12, path: Union[PathLike, None] = None):
193196
super(Tree, self).__init__(repo, binsha, mode, path)
194197

195198
@classmethod
196199
def _get_intermediate_items(cls, index_object: 'Tree', # type: ignore
197-
) -> Tuple['Tree', ...]:
200+
) -> Union[Tuple['Tree', ...], Tuple[()]]:
198201
if index_object.type == "tree":
199202
index_object = cast('Tree', index_object)
200203
return tuple(index_object._iter_convert_to_object(index_object._cache))
201204
return ()
202205

203-
def _set_cache_(self, attr):
206+
def _set_cache_(self, attr: str) -> None:
204207
if attr == "_cache":
205208
# Set the data when we need it
206209
ostream = self.repo.odb.stream(self.binsha)
207-
self._cache = tree_entries_from_data(ostream.read())
210+
self._cache: List[Tuple[bytes, int, str]] = tree_entries_from_data(ostream.read())
208211
else:
209212
super(Tree, self)._set_cache_(attr)
210213
# END handle attribute
@@ -221,7 +224,7 @@ def _iter_convert_to_object(self, iterable: Iterable[Tuple[bytes, int, str]]
221224
raise TypeError("Unknown mode %o found in tree data for path '%s'" % (mode, path)) from e
222225
# END for each item
223226

224-
def join(self, file):
227+
def join(self, file: str) -> Union[Blob, 'Tree', Submodule]:
225228
"""Find the named object in this tree's contents
226229
:return: ``git.Blob`` or ``git.Tree`` or ``git.Submodule``
227230
@@ -254,26 +257,22 @@ def join(self, file):
254257
raise KeyError(msg % file)
255258
# END handle long paths
256259

257-
def __div__(self, file):
258-
"""For PY2 only"""
259-
return self.join(file)
260-
261-
def __truediv__(self, file):
260+
def __truediv__(self, file: str) -> Union['Tree', Blob, Submodule]:
262261
"""For PY3 only"""
263262
return self.join(file)
264263

265264
@property
266-
def trees(self):
265+
def trees(self) -> List['Tree']:
267266
""":return: list(Tree, ...) list of trees directly below this tree"""
268267
return [i for i in self if i.type == "tree"]
269268

270269
@property
271-
def blobs(self):
270+
def blobs(self) -> List['Blob']:
272271
""":return: list(Blob, ...) list of blobs directly below this tree"""
273272
return [i for i in self if i.type == "blob"]
274273

275274
@property
276-
def cache(self):
275+
def cache(self) -> TreeModifier:
277276
"""
278277
:return: An object allowing to modify the internal cache. This can be used
279278
to change the tree's contents. When done, make sure you call ``set_done``
@@ -289,16 +288,16 @@ def traverse(self, predicate=lambda i, d: True,
289288
return super(Tree, self).traverse(predicate, prune, depth, branch_first, visit_once, ignore_self)
290289

291290
# List protocol
292-
def __getslice__(self, i, j):
291+
def __getslice__(self, i: int, j: int) -> List[Union[Blob, 'Tree', Submodule]]:
293292
return list(self._iter_convert_to_object(self._cache[i:j]))
294293

295-
def __iter__(self):
294+
def __iter__(self) -> Iterator[Union[Blob, 'Tree', Submodule]]:
296295
return self._iter_convert_to_object(self._cache)
297296

298-
def __len__(self):
297+
def __len__(self) -> int:
299298
return len(self._cache)
300299

301-
def __getitem__(self, item):
300+
def __getitem__(self, item: Union[str, int, slice]) -> Union[Blob, 'Tree', Submodule]:
302301
if isinstance(item, int):
303302
info = self._cache[item]
304303
return self._map_id_to_type[info[1] >> 12](self.repo, info[0], info[1], join_path(self.path, info[2]))
@@ -310,7 +309,7 @@ def __getitem__(self, item):
310309

311310
raise TypeError("Invalid index type: %r" % item)
312311

313-
def __contains__(self, item):
312+
def __contains__(self, item: Union[IndexObject, PathLike]) -> bool:
314313
if isinstance(item, IndexObject):
315314
for info in self._cache:
316315
if item.binsha == info[0]:
@@ -321,10 +320,11 @@ def __contains__(self, item):
321320
# compatibility
322321

323322
# treat item as repo-relative path
324-
path = self.path
325-
for info in self._cache:
326-
if item == join_path(path, info[2]):
327-
return True
323+
else:
324+
path = self.path
325+
for info in self._cache:
326+
if item == join_path(path, info[2]):
327+
return True
328328
# END for each item
329329
return False
330330

‎git/repo/base.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import os
88
import re
99
import warnings
10+
from gitdb.db.loose import LooseObjectDB
1011

1112
from gitdb.exc import BadObject
1213

@@ -100,7 +101,7 @@ class Repo(object):
100101
# Subclasses may easily bring in their own custom types by placing a constructor or type here
101102
GitCommandWrapperType = Git
102103

103-
def __init__(self, path: Optional[PathLike] = None, odbt: Type[GitCmdObjectDB] = GitCmdObjectDB,
104+
def __init__(self, path: Optional[PathLike] = None, odbt: Type[LooseObjectDB] = GitCmdObjectDB,
104105
search_parent_directories: bool = False, expand_vars: bool = True) -> None:
105106
"""Create a new Repo instance
106107

0 commit comments

Comments
 (0)