
    Ph]                     P   d dl Z d dlZd dlZd dlmZmZmZmZmZm	Z	m
Z
mZmZ d dlZd dlmZ d dlmZ d dlmZmZ d dlmZ ddlmZmZmZ dd	lmZ  ej8                  e      Z G d
 d      ZdedefdZ  e jB                  d      dedefd       Z" G d dejF                        Z$y)    N)	AnyCallableDictIterableListOptionalSetTupleUnion)Expr)ShapeEnv)FloorDivModularIndexing)bound_sympy   )
sympy_subssympy_symbol	VarRanges)Vc            
           e Zd Zd/ fd	ZdefdZdeeegef   fdZd Z	dededefd	Z
d
eej                     fdZdeeef   defdZdededefdZdee   dee   defdZdededefdZdededefdZdededefdZdededefdZdededdfdZdededdfdZdeeej6                  j8                  j:                  f   defdZdededefdZdedefdZ dee   dee   fdZ!dedefdZ"dddede#e   defd Z$ddd!e%e   de#e   de&ed"f   fd#Z'd/d$Z(d% Z)d&ed'eej                     d(eej                     dee   fd)Z*d&ed'eej                     defd*Z+	 d/d&ed'eej                     d(e#eej                        dee   fd+Z,d&ed'eej                     dee   fd,Z-dedej                  fd-Z.de/ej                     fd.Z0 xZ1S )0SizeVarAllocatorNc                 n   t         |           |
t               }|| _        | j                  j                  | _        | j                  j
                  | _        t               | _        t               | _        | j                         | _
        | j                         | _        | j                         | _        y N)super__init__r   	shape_env
var_to_valreplacementsdictprecomputed_replacementsinv_precomputed_replacementsmake_stride_vars_cachestride_varsmake_simplify_with_ranges_cachesimplify_with_rangesmake_simplify_loops_cache_simplify_loops)selfr   	__class__s     cC:\Users\daisl\Desktop\realtime-object-detection\venv\Lib\site-packages\torch/_inductor/sizevars.pyr   zSizeVarAllocator.__init__   s     
I"..336:nn6Q6Q CG&%FJf)668$($H$H$J!#==?    exprc                 ^    t        j                  |      j                  | j                        S r   )sympyexpandxreplacer   )r(   r,   s     r*   simplifyzSizeVarAllocator.simplify.   s"    ||D!**4+<+<==r+   returnc                 ~     t               t         j                        dt        dt        dt        f fd}|S )R
        self._simplify_with_ranges() can be expensive, cache its results
        r,   
var_rangesr2   c                     t        j                        k7  r%j                          t        j                        | g|j                         }j	                  |d       }|j                  | |      }||<   |S r   )lenr   clearitemsget_simplify_with_ranges)r,   r5   keyresultcachereplacement_countr(   s       r*   r%   zNSizeVarAllocator.make_simplify_with_ranges_cache.<locals>.simplify_with_ranges8   s{     C(9(9$::$'(9(9$:!-***,-CYYsD)F~33D*E#c
Mr+   )r   r7   r   r   r   )r(   r%   r>   r?   s   ` @@r*   r$   z0SizeVarAllocator.make_simplify_with_ranges_cache1   s?     .2V 1 12	t 	 	t 	 $#r+   c                 X     t               t         j                         fd}|S )r4   c                     t        j                        k7  r%j                          t        j                        g | ||}j                  |d       }|j	                  | ||      }||<   |S r   )r7   r   r8   r:   _simplify_loops_impl)
index_varssizesindex_formulasr<   r=   r>   r?   r(   s        r*   simplify_loopszBSizeVarAllocator.make_simplify_loops_cache.<locals>.simplify_loopsN   s{     C(9(9$::$'(9(9$:!8J888CYYsD)F~22:unU#c
Mr+   )r   r7   r   )r(   rF   r>   r?   s   ` @@r*   r&   z*SizeVarAllocator.make_simplify_loops_cacheG   s*     -1F 1 12	 r+   r5   c           	          t         j                  |            }|} fdfd} fd}|j                  t              rV|j	                  t        t        j                  d      t        j                  d      t        j                  d            |      }|j                  t              rB|j	                  t        t        j                  d      t        j                  d            |      }||k7  r j                  |      S |S )zk
        Simplify indexing expression with knowledge of the ranges of
        iteration variables.
        c                 *   | j                   D ]  }|v st        j                  d|g      }| j                  ||z         }|s7|||   j                   vsIt        j                  ||   |      }||k(  shj                  |   |      s~||   }  | S )z)Symbols smaller than the divisor are zero_restexclude)free_symbolsr.   Wildmatchgcdstatically_known_leq)basedivisorvrestmrO   r(   r5   s         r*   remove_zero_termszASizeVarAllocator._simplify_with_ranges.<locals>.remove_zero_termsf   s    &&
? !::gs;D

1t8,AQag&:&::#ii$9'>#88AP'(w ' Kr+   c                 *    t         | |      |      S r   )r   )rQ   rR   rV   s     r*   visit_indexing_divzBSizeVarAllocator._simplify_with_ranges.<locals>.visit_indexing_divu   s    -dG<gFFr+   c                     
| |      } d}t        | t              r| j                  d   dz
  }n| j                  t              syj	                         D ci c]  \  }}|d
 }}}t        | |      }j                  d|      rd}nd}j	                         D ci c]  \  }}||dz
   }	}}t        | |	      }n| }j                  |||z        r|rt        | |      S t        | ||      S c c}}w c c}}w )NT   r   r   F)	
isinstancer   argshasr9   r   rP   statically_known_ltr   )rQ   rR   modulusbase_posbase_skrS   iter_ranges_zerobase_lowestiter_rangesrV   r(   r5   s             r*   visit_modular_indexingzFSizeVarAllocator._simplify_with_ranges.<locals>.visit_modular_indexingx   s   $T73DH$0 1)XXo.5?5E5E5G#H5GTQAqD5G #H(/?@,,Q<#H$H4>4D4D4FG4FDAqq!a%x4FG#D+6'''0ABxg.."4':: $I Hs   C0C6rQ   rR   r_   )	join_dimensionsr1   r]   r   replacer.   rM   r   r;   )r(   r,   r5   original_exprrX   rf   rV   s   ` `   @r*   r;   z&SizeVarAllocator._simplify_with_ranges]   s     t}}T23		G	;0 88O$<<JJv&JJy)JJy)
 'D 88H<<JJv&JJy) #D = --dJ??r+   rC   c           
          t        t         j                              D cg c]  } j                  |       c}t	              t	        d         k(  sJ t	              t	        d         f       t        t	                    D ]  }|   dk(  sd|<     fd}d}|rd}t        j                  t        t        t	                          t        t        t	                                D ]4  \  }}||k(  s
|   |    |||      s d}|   |   z  |<   d|<   6 |rfd}	fd}
D cg c]  }||	 c}|	|
fS c c}w c c}w )	a  
        Try to remove as many axis from loop iterations as possible, by:
            1) removing size==1 dimensions
            2) fuse contiguous dimensions into a single loop
            If channel_last = True, we will prevent the last dim fused with other dims
        r   r   Nc                 n   t        t                    D ]  }
j                  |   |    |    z        
j                  |   |         k(  re	|    }	|   }t        d      }t	        |   |||    z  |di      }t	        |   |d||i      }
j                  |      
j                  |      k(  r y y)N_merge_testerr   FT)ranger7   r1   r   r   )abrb   vavbrS   expr1expr2rE   rC   r(   rD   stridess           r*   can_merge_dimsz=SizeVarAllocator._simplify_loops_impl.<locals>.can_merge_dims   s    3w<(==Aq!9:dmmAJqM?  $AB#AB$_5A&~a'82q58|RQR:STE&~a'82q"a.IE}}U+t}}U/CC  ) r+   TFc                     t        t        |             }g }D ]H  }|%|j                  t        j                  d             *|j                  |j                                J |rJ |S Nr   )listreversedappendr.   Integerpop)indexit	new_indexsizerD   s       r*   reindexz6SizeVarAllocator._simplify_loops_impl.<locals>.reindex   s_    huo&BI<$$U]]1%56$$RVVX.	 
 M6r+   c                     t        |       t              k(  sJ t        |       D cg c]
  \  }}|	| c}}S c c}}w r   )r7   zip)r}   isrD   s      r*   prunez4SizeVarAllocator._simplify_loops_impl.<locals>.prune   sB    u:U+++"%eU"3E"3$!Qq}A"3EEEs   
==)	rx   mapr1   r#   r7   rm   	itertoolsproductry   )r(   rC   rD   rE   xr   ru   changedjr   r   rt   s   ````       @r*   rB   z%SizeVarAllocator._simplify_loops_impl   sp    S./<JKNq4##Az2NK5zS_,Ks5z3wqz?.KK,s5z"AQx1}a #
	 	  G!))s5z*+XeCJ6G-H1 6U1X-q1A!!Q'"G$Qx%(2E!H#E!H 			F !25aAM52GUBBi Lh 3s   EEEc                     |dv rt        |      S 	 | j                  j                  |      }|t        |      S 	 y# t        $ r t        j                  d|       Y yw xY w)N)TFzCould not simplify %sF)boolr   _maybe_evaluate_static	Exceptionlogdebug)r(   r,   
simplifieds      r*   is_expr_static_and_truez(SizeVarAllocator.is_expr_static_and_true  sk    = :	5>>tDJ%J'' &
   	5II-t4	5s   '; AAleftrightc                 L    | j                  t        j                  ||            S )zf
        Returns a bool indicating if it is sound to optimize as if left and right are equal.
        )r   r.   Eqr(   r   r   s      r*   statically_known_equalsz(SizeVarAllocator.statically_known_equals  s      ++EHHT5,ABBr+   c                 t     t        |      t        |      k7  ryt         fdt        ||      D              ryy)zl
        Returns a bool indicating if it is sound to optimize as if left and right lists are equal.
        Fc              3   H   K   | ]  \  }}j                  ||        y wr   )r   ).0lrr(   s      r*   	<genexpr>z@SizeVarAllocator.statically_known_list_equals.<locals>.<genexpr>  s%     O>Ndat++Aq1>Ns   "T)r7   allr   r   s   `  r*   statically_known_list_equalsz-SizeVarAllocator.statically_known_list_equals  s3     t9E
"Oc$>NOOr+   c                 .    ||k  }| j                  |      S )zq
        Returns a bool indicating if it is sound to optimize as if left is less than or equal to right.
        r   r(   r   r   r,   s       r*   rP   z%SizeVarAllocator.statically_known_leq!  s     u}++D11r+   c                 .    ||k  }| j                  |      S )ze
        Returns a bool indicating if it is sound to optimize as if left is less than right.
        r   r   s       r*   r^   z$SizeVarAllocator.statically_known_lt)  s     e|++D11r+   	numeratordenominatorc                 V    t        j                  ||z  d      }| j                  |      S )z|
        Return a bool indicating if it is sound to optimize for the numerator being a multiple of the denominator.
        r   )r.   r   r   )r(   r   r   r,   s       r*   statically_known_multiple_ofz-SizeVarAllocator.statically_known_multiple_of1  s*     xx	K/3++D11r+   c                     t        |t              rt        || j                        }t        |t              rt        || j                        }| j                  j                  t        j                  ||            sJ |S r   )r[   r   r   r!   r   evaluate_exprr.   r   r   s      r*   guard_equalszSizeVarAllocator.guard_equals=  sb    dD!dD$E$EFDeT"ud&G&GHE~~++EHHT5,ABBBr+   c                 ,    | j                  ||dz         S )Nr   )guard_ltr   s      r*   	guard_leqzSizeVarAllocator.guard_leqE  s    }}T519--r+   c                 f    | j                   j                  t        j                  ||            sJ y r   )r   r   r.   Ltr   s      r*   r   zSizeVarAllocator.guard_ltH  s%    ~~++EHHT5,ABBBr+   c                     t        |t        t        j                  j                  j
                  f      sJ t        |             | j                  j                  t        j                  |            S r   )
r[   r   r.   logicboolalgBooleantyper   r   sympify)r(   r   s     r*   r   zSizeVarAllocator.evaluate_exprS  sN    $u{{':':'B'B CDPd4jPD~~++EMM$,?@@r+   c                     | j                  |      }| j                  |      }||k  r| j                  ||       |S | j                  ||       |S )z>return the smaller of left and right, and guard on that choice)	size_hintr   )r(   r   r   lvrvs        r*   evaluate_minzSizeVarAllocator.evaluate_minW  sK    ^^D!^^E"8NN4'KNN5$'Lr+   c                     | j                  |      }| j                  |t        j                  |             t	        |      S r   )r   r   r.   r{   intr   s      r*   evaluate_static_shapez&SizeVarAllocator.evaluate_static_shapeb  s3    t$$e 455zr+   c                 J    |D cg c]  }| j                  |       c}S c c}w r   )r   )r(   r   r   s      r*   evaluate_static_shapesz'SizeVarAllocator.evaluate_static_shapesg  s%    7;<t!**1-t<<<s    c                 6   t        |t              st        |t              sJ |S |j                  }|st        |      S t	        d |D              r5t        || j                        }|j                  }t	        d |D              r5t        || j                        S )Nc              3   R   K   | ]  }|j                   j                  d        ! yw)psN)name
startswith)r   r   s     r*   r   z1SizeVarAllocator.symbolic_hint.<locals>.<genexpr>r  s      @<a!&&##D)<s   %')r[   r   r   rL   anyr   r!   r   )r(   r,   rL   s      r*   symbolic_hintzSizeVarAllocator.symbolic_hintj  s    $%dC(((K((t9@<@@dD$E$EFD,,L @<@@ $00r+   fallbackr   c                B   | j                  |      }t        |t        t        j                  f      s||j
                  D ci c])  }|| j                  j                  j                  |d       + }}t        d |j                         D              rXt        ||      }| j                  |j                        }| j                  |j                        }t        t!        ||      |      }|S 	 t        |      S c c}w # t"        $ r t$        j'                  d|        w xY w)Nc              3   $   K   | ]  }|d u 
 y wr    )r   vrs     r*   r   z-SizeVarAllocator.size_hint.<locals>.<genexpr>~  s     =,<b2T>,<s   zfailed on: %s)r   r[   r   r.   r{   rL   r   var_to_ranger:   r   valuesr   r   lowerupperminmaxr   r   r   )	r(   r,   r   outr   sym_vrsexpr_vrr   r   s	            r*   r   zSizeVarAllocator.size_hintw  s     &#U]]349M FJEVEVEV4>>..221d;;EV   =GNN,<==%dG4w}}5w}}5s8U3U;O	s8O  	IIos+	s   .C8-
C= =!Dexprs.c                0     t         fd|D              S )Nc              3   D   K   | ]  }j                  |         yw)r   N)r   )r   r   r   r(   s     r*   r   z.SizeVarAllocator.size_hints.<locals>.<genexpr>  s     I5aT^^A^95s    tuple)r(   r   r   s   ` `r*   
size_hintszSizeVarAllocator.size_hints  s     I5IIIr+   c                       t        j                  |      |      t         j                        t        j                  |       fd       }|S )zp
        Wrapper around functools.lru_cache that clears when replacements
        has been invalidated.
        c                      t        j                        k7  r%t        j                        j                           | i |S r   )r7   r   cache_clear)r\   kwargsfn_cache	prior_lenr(   s     r*   wrapperz,SizeVarAllocator._lru_cache.<locals>.wrapper  sD     C 1 122 1 12	$$&T,V,,r+   )	functools	lru_cacher7   r   wraps)r(   fnmaxsizer   r   r   s   `   @@r*   
_lru_cachezSizeVarAllocator._lru_cache  sN    
 09&&w/3))*				- 
	- r+   c           
          | j                  | j                        	 ddt        dt        t        j
                     dt        t        t        j
                        dt        t           ffd}|S )Nr}   varssupport_varsr2   c                 D    |s|} | t        |      t        |            S r   r   )r}   r   r   r>   s      r*   r#   z<SizeVarAllocator.make_stride_vars_cache.<locals>.stride_vars  s%    
  #dU<-@AAr+   r   )r   _stride_varsr   r   r.   Symbolr   )r(   r#   r>   s     @r*   r"   z'SizeVarAllocator.make_stride_vars_cache  sm     1 12
 :>	B	Bu||$	B #4#56	B $Z		B r+   r}   r   r   c                    g }| j                  |      }|t        ||D ci c]  }|dk7  s	|t        j                  d        c}      z
  }t	        t        |            D ]  }t        |t	        t        |            D ci c].  }||   ||   k7  r!||   dk7  r||   t        j                  d      0 c}      }||   }|dk(  r%|j                  t        j                  d             |j                  t        ||t        j                  d      i      t        ||t        j                  d      i      z
          |S c c}w c c}w )a  Convert an indexing expression back into strides

        NOTE: This is only valid if the index is a standard strided offset
        calculation. e.g. 10 * ModularIndexing(i0 + 1, 1, 2) would give a
        stride of -10 because the index wraps around after the first element

        r   r   )r1   r   r.   r{   rm   r7   rz   )	r(   r}   r   r   rt   rS   r   r   	index_dims	            r*   r   zSizeVarAllocator._stride_vars  sC    e$
HAaAu}}Q''H
 
 s4y!A" #3|#455Aw,q/1l1o6J !OU]]1%555I QAAvu}}Q/0 y1emmA.>*?@ Qa0@,ABC "& + Is   
D=D=>3Ec           
          | j                  |      }t        ||D ci c]  }|dk7  s	|t        j                  d        c}      S c c}w )z-Extract offset part of an indexing expressionr   )r1   r   r.   r{   )r(   r}   r   rS   s       r*   
offset_varzSizeVarAllocator.offset_var  sC    e$%t!Nt!qAv!U]]1%5"5t!NOO!Ns
   
A
A
c                 2   |j                   D ],  }|j                  j                  d      st        ||di      }. g }| j	                  |||      D ]#  }	 |j                  | j                  |             % |S # t        $ r |j                  d       Y Dw xY w)Nindirectr   )rL   r   r   r   r#   rz   r   	TypeError)r(   r}   r   r   rS   r=   r   s          r*   stride_hintszSizeVarAllocator.stride_hints  s     ##Avv  ,"51a&1 $ !!%|<A!dnnQ/0 =
   !a !s    A99BBc           	          t        t        t        | j                  ||                  t	        t        t                          }|j                  fd       |S )Nc                     |    dk(  |    fS rw   r   )r   rt   s    r*   <lambda>z/SizeVarAllocator.stride_order.<locals>.<lambda>  s    '!*/71:!>r+   )r<   )r   r   absr   rx   rm   r7   sort)r(   r}   r   orderrt   s       @r*   stride_orderzSizeVarAllocator.stride_order  sH    C!2!25$!?@AU3w<()

>
?r+   c                     || j                   vr?t        dt        | j                                }|| j                   |<   || j                  |<   | j                   |   S )Nr   )r    r   r7   r!   )r(   r,   syms      r*   lookup_precomputed_sizez(SizeVarAllocator.lookup_precomputed_size  s^    t444C(E(E$F#GHIC25D))$/59D--c2,,T22r+   c                     t        | j                  j                               t        | j                  j                               z
  S r   )setr   keysr   )r(   s    r*   rL   zSizeVarAllocator.free_symbols  s3    4??'')*S1B1B1G1G1I-JJJr+   r   )2__name__
__module____qualname__r   r   r1   r   r   r$   r&   r;   r   r.   r   rB   r   r   r   r   r   r   rP   r^   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r
   r   r   r"   r   r   r   r   r  r	   rL   __classcell__r)   s   @r*   r   r      sb   @.>T >$4:KT:Q1R $,,H$ HI H$ HT?Cu||,?CvE$),<  CD C C$ Cd DJ SW 2 2d 2t 22 2T 2d 22d 2 2RV 2 d t .d .4 .D .CT C$ C4 CA%ekk.A.A.I.I(I"J At A	 	d 	t 	$ 3 
=4: =$s) =1$ 14 1 BF d # # . #'	J~J 3-	J
 
sCxJ$##!%ell!3#CGCU#	d#JP PD,> P4 P 6:	 5<<  tELL12	
 
c"$ d5<<.@ T#Y 3D 3U\\ 3Kc%,,/ Kr+   r   r,   r2   c                 z    t        | t        j                        r| j                  t              s| S t        |       S r   )r[   r.   Addr]   r   _join_dimensions_cached)r,   s    r*   rg   rg     s+    dEII&dhh.G"4((r+      c                    t        | t        j                        sJ t        j                  ddg      }t        j                  d      }t        j                  d      }t        j                  d      }t        j                  d      }| j                  D ]  }|j                  |t        |||      z        }|s%| j                  D ]x  }|j                  ||   ||   z  t        ||   ||   ||   z  |      z        }	|	s:||k7  s@t        | |z
  |z
  ||   t        ||   ||   ||   |	|   z        z  z         } | c c S   | j                  D ]  }|j                  |t        |||      z        }|s%| j                  D ]g  }|j                  ||   ||   z  t        ||   ||   ||   z        z        }	|	9t        | |z
  |z
  ||   t        ||   ||         z  z         } | c c S   | S )z
    ModularIndexing(i0, 1, 32) + 32 * ModularIndexing(i0, 32, 4)
    becomes
    ModularIndexing(i0, 1, 128)
    ModularIndexing(i0, 1, 32) + 32 * FloorDiv(i0, 32)
    becomes i0


    This type of pattern can come from view operations
    scaler   rJ   rQ   rR   r_   modulus2)	r[   r.   r  rM   r\   rN   r   rg   r   )
r,   r  rQ   rR   mod1mod2term1m1term2m2s
             r*   r  r    s    dEII&&&JJw,E::fDjj#G::i D::j!D[[w!EEF[[uIh%bh7bh0FMN
 %5.*   U))"T(BwKDBtHATUVVD  K # $ [[w!EEF[[uI4(8BtHbkBtH>T+UU >*   U)hr$xG&EEFD  K #  Kr+   c                   `     e Zd ZdZdef fdZdedej                  fdZ	d
dZ
d Zd	 Z xZS )SimplifyIndexingzt
    A wrapper around .virtualize.ops that uses var range information to
    simplify ModularIndexing/FloorDiv.
    r5   c                 H    t         |   |       d| _        fd| _        y )Nr  c                 X    t         j                  j                  j                  |       S r   )r   graphsizevarsr%   )r}   r5   s    r*   r   z+SimplifyIndexing.__init__.<locals>.<lambda>F  s    !''**??zRr+   )r   r   r   	_simplify)r(   innerr5   r)   s     `r*   r   zSimplifyIndexing.__init__A  s$    &	 S 	r+   r   r}   c                 X    | j                   j                  || j                  |            S r   )_innerloadr  )r(   r   r}   s      r*   r#  zSimplifyIndexing.loadH  s"    {{dnnU&;<<r+   c                 ^    | j                   j                  || j                  |      ||      S )N)mode)r"  storer  )r(   r   r}   valuer%  s        r*   r&  zSimplifyIndexing.storeK  s)    {{  t~~e'<e$ OOr+   c                 Z    | j                   j                  || j                  |      |      S r   )r"  store_reductionr  )r(   r   r}   r'  s       r*   r)  z SimplifyIndexing.store_reductionN  s$    {{**41FNNr+   c                 X    | j                   j                  | j                  |      |      S r   )r"  
index_exprr  )r(   r}   dtypes      r*   r+  zSimplifyIndexing.index_exprQ  s"    {{%%dnnU&;UCCr+   r   )r  r  r	  __doc__r   r   strr.   r   r#  r&  r)  r+  r
  r  s   @r*   r  r  ;  s?    
S) S= =UZZ =PODr+   r  )%r   r   loggingtypingr   r   r   r   r   r   r	   r
   r   r.   r   %torch.fx.experimental.symbolic_shapesr   torch.utils._sympy.functionsr   r   torch.utils._sympy.value_rangesr   utilsr   r   r   virtualizedr   	getLoggerr  r   r   rg   r   r  WrapperHandlerr  r   r+   r*   <module>r8     s       S S S   : B 7 6 6 g!eK eKP)$ )4 ) S3$ 34 3 3lDq'' Dr+   