@@ -36,7 +36,7 @@ class Color:
3636 CYAN = "#06B6D4"
3737 WHITE = "white"
3838 BLACK = "black"
39- LIGHT_GRAY = "#E5E7EB "
39+ LIGHT_GRAY = "#D1D5DB "
4040
4141
4242ELEMENT_COLOR : Dict [type , str ] = {
@@ -125,19 +125,16 @@ def draw_sub_lattices(
125125 lattice : Lattice ,
126126 * ,
127127 labels : bool = True ,
128- location : str = "bottom " ,
128+ location : str = "top " ,
129129):
130130 x_min , x_max = ax .get_xlim ()
131131 length_gen = [0.0 , * (obj .length for obj in lattice .children )]
132132 position_list = np .add .accumulate (length_gen )
133133 i_min = np .searchsorted (position_list , x_min )
134- i_max = np .searchsorted (position_list , x_max )
135- ticks = position_list [i_min : i_max + 1 ]
134+ i_max = np .searchsorted (position_list , x_max , side = "right" )
135+ ticks = position_list [i_min : i_max ]
136136 ax .set_xticks (ticks )
137- # if len(ticks) < 5:
138- # ax.xaxis.set_minor_locator(AutoMinorLocator())
139- # ax.xaxis.set_minor_formatter(ScalarFormatter())
140- ax .grid (axis = "x" , color = Color .LIGHT_GRAY , linestyle = "--" , linewidth = 1 )
137+ ax .grid (color = Color .LIGHT_GRAY , linestyle = "--" , linewidth = 1 )
141138
142139 if labels :
143140 y_min , y_max = ax .get_ylim ()
@@ -154,13 +151,12 @@ def draw_sub_lattices(
154151 if not isinstance (obj , Lattice ) or start >= x_max or end <= x_min :
155152 continue
156153
157- x0 = end - obj . length / 2
154+ x0 = ( max ( start , x_min ) + min ( end , x_max )) / 2
158155 ax .annotate (
159156 obj .name ,
160157 xy = (x0 , y0 ),
161158 fontsize = FONT_SIZE + 2 ,
162159 fontstyle = "oblique" ,
163- alpha = 0.5 ,
164160 va = "center" ,
165161 ha = "center" ,
166162 clip_on = True ,
@@ -229,16 +225,16 @@ def _twiss_plot_section(
229225 annotate_lattices = True ,
230226 line_style = "solid" ,
231227 line_width = 1.3 ,
232- ref_twiss = None ,
228+ twiss_ref = None ,
233229 scales = {"eta_x" : 10 },
234230 overwrite = False ,
235231):
236232 if overwrite :
237233 ax .clear ()
238- if ref_twiss :
234+ if twiss_ref :
239235 plot_twiss (
240236 ax ,
241- ref_twiss ,
237+ twiss_ref ,
242238 line_style = "dashed" ,
243239 line_width = 2.5 ,
244240 alpha = 0.5 ,
@@ -254,8 +250,8 @@ def _twiss_plot_section(
254250
255251 ax .set_xlim ((x_min , x_max ))
256252 ax .set_ylim ((y_min , y_max ))
257- draw_elements (ax , twiss .lattice , labels = annotate_elements )
258253 draw_sub_lattices (ax , twiss .lattice , labels = annotate_lattices )
254+ draw_elements (ax , twiss .lattice , labels = annotate_elements )
259255
260256
261257# TODO:
@@ -272,7 +268,7 @@ class TwissPlot:
272268 :param y_min float: Minimum y-limit
273269 :param main bool: Wheter to plot whole ring or only given sections
274270 :param scales Dict[str, int]: Optional scaling factors for optical functions
275- :param Twiss ref_twiss : Reference twiss values. Will be plotted as dashed lines.
271+ :param Twiss twiss_ref : Reference twiss values. Will be plotted as dashed lines.
276272 :param pairs: List of (element, attribute)-pairs to create interactice sliders for.
277273 :type pairs: List[Tuple[Element, str]]
278274 """
@@ -287,7 +283,8 @@ def __init__(
287283 y_max = None ,
288284 main = True ,
289285 scales = {"eta_x" : 10 },
290- ref_twiss = None ,
286+ twiss_ref = None ,
287+ title = None ,
291288 pairs : Optional [List [Tuple [Element , str ]]] = None ,
292289 ):
293290 self .fig = plt .figure ()
@@ -300,6 +297,7 @@ def __init__(
300297 len (height_ratios ), 1 , self .fig , height_ratios = height_ratios
301298 )
302299 self .axs_sections = [] # TODO: needed for update function
300+ self .title = self .lattice .name if title is None else title
303301
304302 if pairs :
305303 _ , axs = plt .subplots (nrows = len (pairs ))
@@ -323,7 +321,7 @@ def __init__(
323321 _twiss_plot_section (
324322 self .ax_main ,
325323 self .twiss ,
326- ref_twiss = ref_twiss ,
324+ twiss_ref = twiss_ref ,
327325 y_min = y_min ,
328326 y_max = y_max ,
329327 annotate_elements = False ,
@@ -349,7 +347,7 @@ def __init__(
349347 _twiss_plot_section (
350348 self .axs_sections [i ],
351349 self .twiss ,
352- ref_twiss = ref_twiss ,
350+ twiss_ref = twiss_ref ,
353351 x_min = x_min ,
354352 x_max = x_max ,
355353 y_min = y_min ,
@@ -359,12 +357,11 @@ def __init__(
359357 )
360358
361359 handles , labels = self .fig .axes [0 ].get_legend_handles_labels ()
360+ if twiss_ref :
361+ handles = handles [len (twiss_functions ) :]
362+ labels = labels [len (twiss_functions ) :]
362363 self .fig .legend (handles , labels , loc = "upper left" , ncol = 10 , frameon = False )
363- title = self .lattice .name
364- if self .lattice .info != "" :
365- title += f"({ self .lattice .info } )"
366-
367- self .fig .suptitle (title , ha = "right" , x = 0.98 )
364+ self .fig .suptitle (self .title , ha = "right" , x = 0.98 )
368365 self .fig .tight_layout ()
369366 self .fig .subplots_adjust (top = 0.9 )
370367
@@ -395,7 +392,6 @@ def floor_plan(
395392 * ,
396393 start_angle : float = 0 ,
397394 labels : bool = True ,
398- direction : str = "clockwise" ,
399395):
400396 ax .set_aspect ("equal" )
401397 codes = Path .MOVETO , Path .LINETO
0 commit comments