Skip to content

Commit 237966a

Browse files
committed
Type Tree.traverse() better
1 parent b7fe37a commit 237966a

File tree

5 files changed

+28
-23
lines changed

5 files changed

+28
-23
lines changed

‎git/objects/base.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,11 @@
2222
if TYPE_CHECKING:
2323
from git.repo import Repo
2424
from gitdb.base import OStream
25-
# from .tree import Tree, Blob, Commit, TagObject
25+
from .tree import Tree
26+
from .blob import Blob
27+
from .submodule.base import Submodule
28+
29+
IndexObjUnion = Union['Tree', 'Blob', 'Submodule']
2630

2731
# --------------------------------------------------------------------------
2832

‎git/objects/output.txt

Whitespace-only changes.

‎git/objects/tree.py

+18-19
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from git.util import to_bin_sha
1010

1111
from . import util
12-
from .base import IndexObject
12+
from .base import IndexObject, IndexObjUnion
1313
from .blob import Blob
1414
from .submodule.base import Submodule
1515

@@ -28,10 +28,11 @@
2828

2929
if TYPE_CHECKING:
3030
from git.repo import Repo
31-
from git.objects.util import TraversedTup
3231
from io import BytesIO
3332

34-
T_Tree_cache = TypeVar('T_Tree_cache', bound=Union[Tuple[bytes, int, str]])
33+
T_Tree_cache = TypeVar('T_Tree_cache', bound=Tuple[bytes, int, str])
34+
TraversedTreeTup = Union[Tuple[Union['Tree', None], IndexObjUnion,
35+
Tuple['Submodule', 'Submodule']]]
3536

3637
#--------------------------------------------------------
3738

@@ -201,7 +202,7 @@ class Tree(IndexObject, diff.Diffable, util.Traversable, util.Serializable):
201202
symlink_id = 0o12
202203
tree_id = 0o04
203204

204-
_map_id_to_type: Dict[int, Union[Type[Submodule], Type[Blob], Type['Tree']]] = {
205+
_map_id_to_type: Dict[int, Type[IndexObjUnion]] = {
205206
commit_id: Submodule,
206207
blob_id: Blob,
207208
symlink_id: Blob
@@ -229,7 +230,7 @@ def _set_cache_(self, attr: str) -> None:
229230
# END handle attribute
230231

231232
def _iter_convert_to_object(self, iterable: Iterable[Tuple[bytes, int, str]]
232-
) -> Iterator[Union[Blob, 'Tree', Submodule]]:
233+
) -> Iterator[IndexObjUnion]:
233234
"""Iterable yields tuples of (binsha, mode, name), which will be converted
234235
to the respective object representation"""
235236
for binsha, mode, name in iterable:
@@ -240,7 +241,7 @@ def _iter_convert_to_object(self, iterable: Iterable[Tuple[bytes, int, str]]
240241
raise TypeError("Unknown mode %o found in tree data for path '%s'" % (mode, path)) from e
241242
# END for each item
242243

243-
def join(self, file: str) -> Union[Blob, 'Tree', Submodule]:
244+
def join(self, file: str) -> IndexObjUnion:
244245
"""Find the named object in this tree's contents
245246
:return: ``git.Blob`` or ``git.Tree`` or ``git.Submodule``
246247
@@ -273,7 +274,7 @@ def join(self, file: str) -> Union[Blob, 'Tree', Submodule]:
273274
raise KeyError(msg % file)
274275
# END handle long paths
275276

276-
def __truediv__(self, file: str) -> Union['Tree', Blob, Submodule]:
277+
def __truediv__(self, file: str) -> IndexObjUnion:
277278
"""For PY3 only"""
278279
return self.join(file)
279280

@@ -296,17 +297,16 @@ def cache(self) -> TreeModifier:
296297
See the ``TreeModifier`` for more information on how to alter the cache"""
297298
return TreeModifier(self._cache)
298299

299-
def traverse(self,
300-
predicate: Callable[[Union['Tree', 'Submodule', 'Blob',
301-
'TraversedTup'], int], bool] = lambda i, d: True,
302-
prune: Callable[[Union['Tree', 'Submodule', 'Blob', 'TraversedTup'], int], bool] = lambda i, d: False,
300+
def traverse(self, # type: ignore # overrides super()
301+
predicate: Callable[[Union[IndexObjUnion, TraversedTreeTup], int], bool] = lambda i, d: True,
302+
prune: Callable[[Union[IndexObjUnion, TraversedTreeTup], int], bool] = lambda i, d: False,
303303
depth: int = -1,
304304
branch_first: bool = True,
305305
visit_once: bool = False,
306306
ignore_self: int = 1,
307307
as_edge: bool = False
308-
) -> Union[Iterator[Union['Tree', 'Blob', 'Submodule']],
309-
Iterator[Tuple[Union['Tree', 'Submodule', None], Union['Tree', 'Blob', 'Submodule']]]]:
308+
) -> Union[Iterator[IndexObjUnion],
309+
Iterator[TraversedTreeTup]]:
310310
"""For documentation, see util.Traversable._traverse()
311311
Trees are set to visit_once = False to gain more performance in the traversal"""
312312

@@ -320,23 +320,22 @@ def traverse(self,
320320
# ret_tup = itertools.tee(ret, 2)
321321
# assert is_tree_traversed(ret_tup), f"Type is {[type(x) for x in list(ret_tup[0])]}"
322322
# return ret_tup[0]"""
323-
return cast(Union[Iterator[Union['Tree', 'Blob', 'Submodule']],
324-
Iterator[Tuple[Union['Tree', 'Submodule', None], Union['Tree', 'Blob', 'Submodule']]]],
323+
return cast(Union[Iterator[IndexObjUnion], Iterator[TraversedTreeTup]],
325324
super(Tree, self).traverse(predicate, prune, depth, # type: ignore
326325
branch_first, visit_once, ignore_self))
327326

328327
# List protocol
329328

330-
def __getslice__(self, i: int, j: int) -> List[Union[Blob, 'Tree', Submodule]]:
329+
def __getslice__(self, i: int, j: int) -> List[IndexObjUnion]:
331330
return list(self._iter_convert_to_object(self._cache[i:j]))
332331

333-
def __iter__(self) -> Iterator[Union[Blob, 'Tree', Submodule]]:
332+
def __iter__(self) -> Iterator[IndexObjUnion]:
334333
return self._iter_convert_to_object(self._cache)
335334

336335
def __len__(self) -> int:
337336
return len(self._cache)
338337

339-
def __getitem__(self, item: Union[str, int, slice]) -> Union[Blob, 'Tree', Submodule]:
338+
def __getitem__(self, item: Union[str, int, slice]) -> IndexObjUnion:
340339
if isinstance(item, int):
341340
info = self._cache[item]
342341
return self._map_id_to_type[info[1] >> 12](self.repo, info[0], info[1], join_path(self.path, info[2]))
@@ -348,7 +347,7 @@ def __getitem__(self, item: Union[str, int, slice]) -> Union[Blob, 'Tree', Submo
348347

349348
raise TypeError("Invalid index type: %r" % item)
350349

351-
def __contains__(self, item: Union[IndexObject, PathLike]) -> bool:
350+
def __contains__(self, item: Union[IndexObjUnion, PathLike]) -> bool:
352351
if isinstance(item, IndexObject):
353352
for info in self._cache:
354353
if item.binsha == info[0]:

‎git/objects/util.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,14 @@
3030
from .commit import Commit
3131
from .blob import Blob
3232
from .tag import TagObject
33-
from .tree import Tree
33+
from .tree import Tree, TraversedTreeTup
3434
from subprocess import Popen
3535

3636

3737
T_TIobj = TypeVar('T_TIobj', bound='TraversableIterableObj') # for TraversableIterableObj.traverse()
38-
TraversedTup = Tuple[Union['Traversable', None], Union['Traversable', 'Blob']] # for Traversable.traverse()
38+
39+
TraversedTup = Union[Tuple[Union['Traversable', None], 'Traversable'], # for commit, submodule
40+
TraversedTreeTup] # for tree.traverse()
3941

4042
# --------------------------------------------------------------------
4143

‎git/util.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1072,7 +1072,7 @@ class IterableObj():
10721072
Subclasses = [Submodule, Commit, Reference, PushInfo, FetchInfo, Remote]"""
10731073

10741074
__slots__ = ()
1075-
_id_attribute_ = "attribute that most suitably identifies your instance"
1075+
_id_attribute_: str
10761076

10771077
@classmethod
10781078
def list_items(cls, repo: 'Repo', *args: Any, **kwargs: Any) -> IterableList[T_IterableObj]:

0 commit comments

Comments
 (0)