• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1from Errors import CompileError, error
2import ExprNodes
3from ExprNodes import IntNode, NameNode, AttributeNode
4import Options
5from Code import UtilityCode, TempitaUtilityCode
6from UtilityCode import CythonUtilityCode
7import Buffer
8import PyrexTypes
9import ModuleNode
10
# Compile-error messages for memoryview axis specifications.
# START_ERR, STOP_ERR, STEP_ERR, BOTH_CF_ERR and INVALID_ERR are raised by
# get_axes_specs(); INVALID_ERR, NOT_CIMPORTED_ERR and EXPR_ERR are raised by
# the _resolve_* helpers.
START_ERR = "Start must not be given."
STOP_ERR = "Axis specification only allowed in the 'step' slot."
STEP_ERR = "Step must be omitted, 1, or a valid specifier."
BOTH_CF_ERR = "Cannot specify an array that is both C and Fortran contiguous."
INVALID_ERR = "Invalid axis specification."
NOT_CIMPORTED_ERR = "Variable was not cimported from cython.view"
EXPR_ERR = "no expressions allowed in axis spec, only names and literals."
# NOTE(review): CF_ERR is not referenced in the visible part of this module.
CF_ERR = "Invalid axis specification for a C/Fortran contiguous array."
# Format template taking the memoryview's name (%s). In the visible code it
# is only mentioned in the disabled check of err_if_nogil_initialized_check().
ERR_UNINITIALIZED = ("Cannot check if memoryview %s is initialized without the "
                     "GIL, consider using initializedcheck(False)")
21
def err_if_nogil_initialized_check(pos, env, name='variable'):
    """
    Intentionally do nothing: this situation raises an exception at runtime
    now, so no compile-time error is reported.

    All arguments are accepted for call compatibility and ignored.  The
    compile-time check that used to live here was:

        if env.nogil and env.directives['initializedcheck']:
            error(pos, ERR_UNINITIALIZED % name)
    """
    return None
27
def concat_flags(*flags):
    """Join the given flag names with '|' and wrap the result in parentheses."""
    joined = "|".join(flags)
    return "(" + joined + ")"
30
# Name of the buffer-protocol request flag for format information.
format_flag = "PyBUF_FORMAT"

# Buffer request flag expressions (as C source text) for memoryview slices.
# get_buf_flags() returns the C-contiguous, Fortran-contiguous, full-access
# or strided-access value depending on the axis specs.
memview_c_contiguous = "(PyBUF_C_CONTIGUOUS | PyBUF_FORMAT | PyBUF_WRITABLE)"
memview_f_contiguous = "(PyBUF_F_CONTIGUOUS | PyBUF_FORMAT | PyBUF_WRITABLE)"
# NOTE(review): not referenced in the visible part of this module.
memview_any_contiguous = "(PyBUF_ANY_CONTIGUOUS | PyBUF_FORMAT | PyBUF_WRITABLE)"
memview_full_access = "PyBUF_FULL"
# Strided access currently requests PyBUF_RECORDS; the previous value
# (PyBUF_STRIDED) is kept below for reference.
#memview_strided_access = "PyBUF_STRIDED"
memview_strided_access = "PyBUF_RECORDS"

# C-level names for the per-axis access/packing specifiers.
# NOTE(review): presumably defined in the memoryview C utility code -- confirm.
MEMVIEW_DIRECT = '__Pyx_MEMVIEW_DIRECT'
MEMVIEW_PTR    = '__Pyx_MEMVIEW_PTR'
MEMVIEW_FULL   = '__Pyx_MEMVIEW_FULL'
MEMVIEW_CONTIG = '__Pyx_MEMVIEW_CONTIG'
MEMVIEW_STRIDED= '__Pyx_MEMVIEW_STRIDED'
MEMVIEW_FOLLOW = '__Pyx_MEMVIEW_FOLLOW'
46
# Maps each axis specifier keyword (access: direct/ptr/full, packing:
# contig/strided/follow) to the name of its C-level constant.
_spec_to_const = {
        'direct' : MEMVIEW_DIRECT,
        'ptr'    : MEMVIEW_PTR,
        'full'   : MEMVIEW_FULL,
        'contig' : MEMVIEW_CONTIG,
        'strided': MEMVIEW_STRIDED,
        'follow' : MEMVIEW_FOLLOW,
        }

# Maps each axis specifier keyword to a one-character abbreviation.
# NOTE(review): presumably used when building mangled/specialization names
# elsewhere -- not used in the visible part of this module; confirm.
_spec_to_abbrev = {
    'direct'  : 'd',
    'ptr'     : 'p',
    'full'    : 'f',
    'contig'  : 'c',
    'strided' : 's',
    'follow'  : '_',
}
64
# C aggregate initializer text.
# NOTE(review): presumably the zero-initializer for a __Pyx_memviewslice
# struct (memview, data, shape, strides, suboffsets) -- confirm.
memslice_entry_init = "{ 0, 0, { 0 }, { 0 }, { 0 } }"

# Names of the memoryview type and of the C-level objects generated for it.
memview_name = u'memoryview'
memview_typeptr_cname = '__pyx_memoryview_type'
memview_objstruct_cname = '__pyx_memoryview_obj'
memviewslice_cname = u'__Pyx_memviewslice'
71
def put_init_entry(mv_cname, code):
    """
    Emit C statements setting the ``data`` and ``memview`` members of the
    slice named ``mv_cname`` to NULL (in that order).
    """
    for member in ("data", "memview"):
        code.putln("%s.%s = NULL;" % (mv_cname, member))
75
def mangle_dtype_name(dtype):
    """
    Return the mangled name for ``dtype``.

    Thin wrapper that delegates to Buffer.mangle_dtype_name; the
    implementation may move into this module later.
    """
    # Resolved at call time, inside the function, as before.
    from Buffer import mangle_dtype_name as buffer_mangle_dtype_name
    return buffer_mangle_dtype_name(dtype)
80
81#def axes_to_str(axes):
82#    return "".join([access[0].upper()+packing[0] for (access, packing) in axes])
83
def put_acquire_memoryviewslice(lhs_cname, lhs_type, lhs_pos, rhs, code,
                                have_gil=False, first_assignment=True):
    """
    Emit code assigning the memoryview slice ``rhs`` to ``lhs_cname``.

    If ``rhs`` is neither already in a temp nor a simple expression, its
    result is first stored in a freshly allocated temp of ``lhs_type``,
    which is released again after the assignment has been emitted.

    We can avoid decreffing the lhs if we know it is the first assignment.
    ``lhs_pos`` is currently unused.
    """
    assert rhs.type.is_memoryviewslice

    needs_temp = not (rhs.result_in_temp() or rhs.is_simple())
    if needs_temp:
        rhs_cname = code.funcstate.allocate_temp(lhs_type, manage_ref=False)
        code.putln("%s = %s;" % (rhs_cname, rhs.result_as(lhs_type)))
    else:
        rhs_cname = rhs.result()

    # Uninitialized assignment is allowed, so no unbound check is emitted.
    put_assign_to_memviewslice(lhs_cname, rhs, rhs_cname, lhs_type, code,
                               have_gil=have_gil,
                               first_assignment=first_assignment)

    if needs_temp:
        code.funcstate.release_temp(rhs_cname)
103
def put_assign_to_memviewslice(lhs_cname, rhs, rhs_cname, memviewslicetype, code,
                               have_gil=False, first_assignment=False):
    """
    Emit the C assignment ``lhs_cname = rhs_cname;``.

    Before the assignment, the previous value of the lhs is xdecref'ed
    unless this is its first assignment, and ``rhs`` is asked to make its
    slice owned when its result is not in a temp.
    ``memviewslicetype`` is accepted but not used.
    """
    lhs_may_hold_slice = not first_assignment
    if lhs_may_hold_slice:
        code.put_xdecref_memoryviewslice(lhs_cname, have_gil=have_gil)

    rhs_in_temp = rhs.result_in_temp()
    if not rhs_in_temp:
        rhs.make_owned_memoryviewslice(code)

    assignment = "%s = %s;" % (lhs_cname, rhs_cname)
    code.putln(assignment)
113
def get_buf_flags(specs):
    """
    Return the buffer request flags (C source text) for the given list of
    (access, packing) axis specs.

    C- and Fortran-contiguous layouts get their dedicated flags; otherwise
    full access is requested when any axis is 'full' or 'ptr', and strided
    access when all axes are direct.
    """
    is_c_contig, is_f_contig = is_cf_contig(specs)
    if is_c_contig:
        return memview_c_contiguous
    if is_f_contig:
        return memview_f_contiguous

    accesses = [access for access, packing in specs]
    if 'full' in accesses or 'ptr' in accesses:
        return memview_full_access
    return memview_strided_access
128
def insert_newaxes(memoryviewtype, n):
    """
    Return a new memoryview slice type with the same dtype as
    ``memoryviewtype`` and ``n`` ('direct', 'strided') axes prepended.
    """
    leading_axes = [('direct', 'strided')] * n
    all_axes = leading_axes + list(memoryviewtype.axes)
    return PyrexTypes.MemoryViewSliceType(memoryviewtype.dtype, all_axes)
133
def broadcast_types(src, dst):
    """
    Return (src, dst) with leading axes inserted into whichever type has
    fewer dimensions.  When the ndims are equal, dst is still passed through
    insert_newaxes() with zero new axes.
    """
    missing = abs(src.ndim - dst.ndim)
    if src.ndim < dst.ndim:
        src = insert_newaxes(src, missing)
    else:
        dst = insert_newaxes(dst, missing)
    return src, dst
140
def src_conforms_to_dst(src, dst, broadcast=False):
    '''
    Return True if memoryview type src conforms to dst, False otherwise.

    To conform, the dtypes must be equal, the ndims must be equal (or
    ``broadcast`` must be set, in which case the types are broadcast first),
    and every axis spec must be conformable:

    An access or packing spec is conformable to itself.
    Any access spec is accepted where dst requires 'full'.
    Any packing spec is accepted where dst requires 'strided'.
    Any other combination is not conformable.
    '''
    if src.dtype != dst.dtype:
        return False

    if src.ndim != dst.ndim:
        if not broadcast:
            return False
        src, dst = broadcast_types(src, dst)

    for (src_access, src_packing), (dst_access, dst_packing) in zip(src.axes,
                                                                   dst.axes):
        access_ok = src_access == dst_access or dst_access == 'full'
        if not access_ok:
            return False
        packing_ok = src_packing == dst_packing or dst_packing == 'strided'
        if not packing_ok:
            return False

    return True
172
def valid_memslice_dtype(dtype, i=0):
    """
    Return whether type dtype can be used as the base type of a
    memoryview slice.

    Structs (whose members are all valid), numeric types and objects are
    supported, as are error types, fused types, typedefs of valid types and
    arrays of valid types.  ``i`` is the array nesting depth reached so far;
    arrays are only accepted while it is below 8.  Complex types with an
    integer real type and the C bint type are rejected.
    """
    complex_of_ints = dtype.is_complex and dtype.real_type.is_int
    if complex_of_ints or dtype is PyrexTypes.c_bint_type:
        return False

    if dtype.is_struct and dtype.kind == 'struct':
        for member in dtype.scope.var_entries:
            if valid_memslice_dtype(member.type):
                continue
            return False
        return True

    # Pointers are not valid (yet):
    # (dtype.is_ptr and valid_memslice_dtype(dtype.base_type))
    return (
        dtype.is_error or
        (dtype.is_array and i < 8 and
         valid_memslice_dtype(dtype.base_type, i + 1)) or
        dtype.is_numeric or
        dtype.is_pyobject or
        # fused types are accepted; specializations replace them later
        dtype.is_fused or
        (dtype.is_typedef and valid_memslice_dtype(dtype.typedef_base_type))
    )
204
def validate_memslice_dtype(pos, dtype):
    """Report a compile error at ``pos`` if dtype is not a valid memoryview slice base type."""
    if valid_memslice_dtype(dtype):
        return
    error(pos, "Invalid base type for memoryview slice: %s" % dtype)
208
209
class MemoryViewSliceBufferEntry(Buffer.BufferEntry):
    """
    Buffer entry for a variable of memoryview slice type.

    Builds C expressions that index or slice the slice struct named
    ``entry.cname`` through its ``data``, ``shape``, ``strides`` and
    ``suboffsets`` members.
    """
    def __init__(self, entry):
        # NOTE(review): Buffer.BufferEntry.__init__ is not called here.
        self.entry = entry
        self.type = entry.type
        self.cname = entry.cname
        self.buf_ptr = "%s.data" % self.cname

        # The buffer pointer type is a pointer to the slice's element type.
        dtype = self.entry.type.dtype
        dtype = PyrexTypes.CPtrType(dtype)

        self.buf_ptr_type = dtype

    # The three getters below rely on _for_all_ndim, which is not defined in
    # this class (presumably inherited from Buffer.BufferEntry -- confirm).
    def get_buf_suboffsetvars(self):
        return self._for_all_ndim("%s.suboffsets[%d]")

    def get_buf_stridevars(self):
        return self._for_all_ndim("%s.strides[%d]")

    def get_buf_shapevars(self):
        return self._for_all_ndim("%s.shape[%d]")

    def generate_buffer_lookup_code(self, code, index_cnames):
        """
        Return a C expression for the element addressed by index_cnames,
        which holds one index expression per axis of the slice type.
        """
        axes = [(dim, index_cnames[dim], access, packing)
                    for dim, (access, packing) in enumerate(self.type.axes)]
        return self._generate_buffer_lookup_code(code, axes)

    def _generate_buffer_lookup_code(self, code, axes, cast_result=True):
        """
        Build the lookup expression for the given axes, a sequence of
        (dim, index, access, packing) tuples.

        The expression is nested one dimension at a time, starting from the
        slice's data pointer.  With cast_result the final expression is cast
        to a pointer to the element type; otherwise it is returned uncast.
        """
        bufp = self.buf_ptr
        type_decl = self.type.dtype.declaration_code("")

        for dim, index, access, packing in axes:
            # NOTE(review): 'shape' is computed but not used below.
            shape = "%s.shape[%d]" % (self.cname, dim)
            stride = "%s.strides[%d]" % (self.cname, dim)
            suboffset = "%s.suboffsets[%d]" % (self.cname, dim)

            flag = get_memoryview_flag(access, packing)

            if flag in ("generic", "generic_contiguous"):
                # Note: we cannot do cast tricks to avoid stride multiplication
                #       for generic_contiguous, as we may have to do (dtype *)
                #       or (dtype **) arithmetic, we won't know which unless
                #       we check suboffsets
                code.globalstate.use_utility_code(memviewslice_index_helpers)
                bufp = ('__pyx_memviewslice_index_full(%s, %s, %s, %s)' %
                                            (bufp, index, stride, suboffset))

            elif flag == "indirect":
                # Step by the stride, dereference the pointer stored there,
                # then add the suboffset.
                bufp = "(%s + %s * %s)" % (bufp, index, stride)
                bufp = ("(*((char **) %s) + %s)" % (bufp, suboffset))

            elif flag == "indirect_contiguous":
                # Note: we do char ** arithmetic
                bufp = "(*((char **) %s + %s) + %s)" % (bufp, index, suboffset)

            elif flag == "strided":
                bufp = "(%s + %s * %s)" % (bufp, index, stride)

            else:
                # Contiguous: index with element-typed pointer arithmetic
                # instead of multiplying by the stride.
                assert flag == 'contiguous', flag
                bufp = '((char *) (((%s *) %s) + %s))' % (type_decl, bufp, index)

            # Tag each nesting level with its dimension in the generated C.
            bufp = '( /* dim=%d */ %s )' % (dim, bufp)

        if cast_result:
            return "((%s *) %s)" % (type_decl, bufp)

        return bufp

    def generate_buffer_slice_code(self, code, indices, dst, have_gil,
                                   have_slices, directives):
        """
        Slice a memoryviewslice.

        indices     - list of index nodes. If not a SliceNode, or NoneNode,
                      then it must be coercible to Py_ssize_t
        dst         - C name of the slice struct that receives the result
        have_slices - whether the index list contains slices; together with
                      the axis access kinds it decides whether a
                      suboffset_dim temp is allocated
        directives  - mapping read for 'wraparound' and 'boundscheck'

        The code for each index is produced by the SimpleSlice, ToughSlice
        and SliceIndex templates in MemoryView_C.c, which receive this
        function's local variables as their context.
        NOTE(review): have_gil is not used directly in this method, though
        it is part of the locals() passed to the templates.
        """
        new_ndim = 0
        src = self.cname

        def load_slice_util(name, dict):
            # Only the implementation part of the utility code is emitted.
            proto, impl = TempitaUtilityCode.load_as_string(
                        name, "MemoryView_C.c", context=dict)
            return impl

        all_dimensions_direct = True
        for access, packing in self.type.axes:
            if access != 'direct':
                all_dimensions_direct = False
                break

        # A suboffset_dim temp is only needed when some dimension is
        # indirect or slices are involved.
        no_suboffset_dim = all_dimensions_direct and not have_slices
        if not no_suboffset_dim:
            suboffset_dim = code.funcstate.allocate_temp(
                             PyrexTypes.c_int_type, False)
            code.putln("%s = -1;" % suboffset_dim)

        # dst starts out sharing src's data pointer and memoryview object.
        code.putln("%(dst)s.data = %(src)s.data;" % locals())
        code.putln("%(dst)s.memview = %(src)s.memview;" % locals())
        code.put_incref_memoryviewslice(dst)

        # 'dim' tracks the source dimension, 'new_ndim' the destination one.
        dim = -1
        for index in indices:
            error_goto = code.error_goto(index.pos)
            if not index.is_none:
                # None (newaxis) does not consume a source dimension.
                dim += 1
                access, packing = self.type.axes[dim]

            if isinstance(index, ExprNodes.SliceNode):
                # slice, unspecified dimension, or part of ellipsis
                # The locals() dict doubles as the template context; the
                # start/stop/step values and have_* flags are added to it.
                d = locals()
                for s in "start stop step".split():
                    idx = getattr(index, s)
                    have_idx = d['have_' + s] = not idx.is_none
                    if have_idx:
                        d[s] = idx.result()
                    else:
                        d[s] = "0"

                if (not d['have_start'] and
                    not d['have_stop'] and
                    not d['have_step']):
                    # full slice (:), simply copy over the extent, stride
                    # and suboffset. Also update suboffset_dim if needed
                    d['access'] = access
                    code.put(load_slice_util("SimpleSlice", d))
                else:
                    code.put(load_slice_util("ToughSlice", d))

                new_ndim += 1

            elif index.is_none:
                # newaxis
                attribs = [('shape', 1), ('strides', 0), ('suboffsets', -1)]
                for attrib, value in attribs:
                    code.putln("%s.%s[%d] = %d;" % (dst, attrib, new_ndim, value))

                new_ndim += 1

            else:
                # normal index
                # A plain index does not add a destination dimension, so
                # new_ndim is left unchanged in this branch.
                idx = index.result()

                if access == 'direct':
                    indirect = False
                else:
                    indirect = True
                    generic = (access == 'full')
                    if new_ndim != 0:
                        # NOTE(review): this early return skips the
                        # release_temp() call at the end of the method.
                        return error(index.pos,
                                     "All preceding dimensions must be "
                                     "indexed and not sliced")

                wraparound = int(directives['wraparound'])
                boundscheck = int(directives['boundscheck'])
                d = locals()
                code.put(load_slice_util("SliceIndex", d))

        if not no_suboffset_dim:
            code.funcstate.release_temp(suboffset_dim)
372
373
def empty_slice(pos):
    """
    Return a SliceNode at ``pos`` whose start, stop and step all refer to
    one shared NoneNode, i.e. the full slice ``:``.
    """
    none_node = ExprNodes.NoneNode(pos)
    bounds = dict.fromkeys(('start', 'stop', 'step'), none_node)
    return ExprNodes.SliceNode(pos, **bounds)
378
def unellipsify(indices, newaxes, ndim):
    """
    Expand ellipses in ``indices`` and pad with full slices up to ``ndim``.

    The first ellipsis is replaced by as many full slices as are needed to
    reach ``ndim`` dimensions; any later ellipsis becomes a single full
    slice.  ``newaxes`` entries do not count towards the dimensions.

    Returns (have_slices, expanded_indices), where have_slices tells whether
    the result involves any slice, ellipsis or None index, or had to be
    padded.
    """
    expanded = []
    ellipsis_seen = False
    have_slices = False

    n_newaxes = len(newaxes)
    n_indices = len(indices) - n_newaxes

    for index in indices:
        if not isinstance(index, ExprNodes.EllipsisNode):
            if not have_slices:
                have_slices = (isinstance(index, ExprNodes.SliceNode) or
                               index.is_none)
            expanded.append(index)
            continue

        have_slices = True
        full_slice = empty_slice(index.pos)
        if ellipsis_seen:
            expanded.append(full_slice)
        else:
            # The ellipsis itself was counted in n_indices, hence the + 1.
            expanded.extend([full_slice] * (ndim - n_indices + 1))
            ellipsis_seen = True

    missing = ndim - (len(expanded) - n_newaxes)
    if missing > 0:
        # Fewer indices than dimensions: pad with trailing full slices.
        have_slices = True
        expanded.extend([empty_slice(indices[-1].pos)] * missing)

    return have_slices, expanded
410
def get_memoryview_flag(access, packing):
    """
    Return the name of the indexing mode for one axis, given its access
    ('full', 'ptr' or 'direct') and packing ('strided', 'follow' or
    'contig') spec.

    Any combination other than the five looked up below is asserted to be
    ('direct', 'contig') and yields 'contiguous'.
    """
    if packing in ('strided', 'follow'):
        by_access = {'full': 'generic',
                     'ptr': 'indirect',
                     'direct': 'strided'}
    elif packing == 'contig':
        by_access = {'full': 'generic_contiguous',
                     'ptr': 'indirect_contiguous'}
    else:
        by_access = {}

    if access in by_access:
        return by_access[access]

    assert (access, packing) == ('direct', 'contig'), (access, packing)
    return 'contiguous'
425
def get_is_contig_func_name(c_or_f, ndim):
    """Return the C function name of the contiguity check for the given order ('c' or 'f') and ndim."""
    name_parts = {'order': c_or_f, 'ndim': ndim}
    return "__pyx_memviewslice_is_%(order)s_contig%(ndim)d" % name_parts
428
def get_is_contig_utility(c_contig, ndim):
    """
    Load the C utility code of the contiguity check for ``ndim`` dimensions:
    the C-order variant if ``c_contig`` is true, the Fortran-order one
    otherwise.  Both require ``is_contig_utility``.
    """
    C = dict(context, ndim=ndim)
    if c_contig:
        utility_name = "MemviewSliceIsCContig"
    else:
        utility_name = "MemviewSliceIsFContig"

    return load_memview_c_utility(utility_name, C,
                                  requires=[is_contig_utility])
439
def copy_src_to_dst_cname():
    """Return the C name of the function that copies one slice's contents into another."""
    cname = "__pyx_memoryview_copy_contents"
    return cname
442
def verify_direct_dimensions(node):
    """
    Report a compile error for every axis of ``node.type`` whose access
    spec is not 'direct'.

    The error is reported at ``node.pos``.  (This is a module-level
    function, so there is no ``self``; the position must come from the
    node that was passed in.)
    """
    for access, packing in node.type.axes:
        if access != 'direct':
            error(node.pos, "All dimensions must be direct")
447
def copy_broadcast_memview_src_to_dst(src, dst, code):
    """
    Emit a call that copies the contents of slice src into slice dst,
    jumping to the error label (at dst's position) if the call returns a
    negative value.  Indirect slices are not supported.
    """
    for operand in (src, dst):
        verify_direct_dimensions(operand)

    copy_call = "%s(%s, %s, %d, %d, %d)" % (copy_src_to_dst_cname(),
                                            src.result(), dst.result(),
                                            src.type.ndim, dst.type.ndim,
                                            dst.type.dtype.is_pyobject)
    code.putln(code.error_goto_if_neg(copy_call, dst.pos))
462
def get_1d_fill_scalar_func(type, code):
    """
    Register the FillStrided1DScalar utility code for the dtype of the
    memoryview type ``type`` and return the C name of the resulting fill
    function.
    """
    element_type = type.dtype
    element_decl = element_type.declaration_code("")
    mangled = mangle_dtype_name(element_type)

    fill_utility = load_memview_c_utility(
        "FillStrided1DScalar",
        dict(dtype_name=mangled, type_decl=element_decl))
    code.globalstate.use_utility_code(fill_utility)
    return '__pyx_fill_slice_%s' % mangled
472
def assign_scalar(dst, scalar, code):
    """
    Emit code assigning a scalar to every element of slice dst.

    dst must be a temp, scalar will be assigned to a correct type and not
    just something assignable.  The scalar is stored in a C local first;
    for object dtypes each overwritten element is decref'ed and the scalar
    is incref'ed once per element written.
    """
    verify_direct_dimensions(dst)
    slice_type = dst.type
    dtype = slice_type.dtype
    type_decl = dtype.declaration_code("")
    slice_decl = slice_type.declaration_code("")

    code.begin_block()
    code.putln("%s __pyx_temp_scalar = %s;" % (type_decl, scalar.result()))

    usable_directly = dst.result_in_temp() or (
        dst.base.is_name and isinstance(dst.index, ExprNodes.EllipsisNode))
    if usable_directly:
        dst_temp = dst.result()
    else:
        # Copy the slice struct into a C local and iterate over that.
        code.putln("%s __pyx_temp_slice = %s;" % (slice_decl, dst.result()))
        dst_temp = "__pyx_temp_slice"

    loops = slice_iter(slice_type, dst_temp, slice_type.ndim, code)
    element_ptr = loops.start_loops()

    is_object = dtype.is_pyobject
    if is_object:
        code.putln("Py_DECREF(*(PyObject **) %s);" % element_ptr)

    code.putln("*((%s *) %s) = __pyx_temp_scalar;" % (type_decl, element_ptr))

    if is_object:
        code.putln("Py_INCREF(__pyx_temp_scalar);")

    loops.end_loops()
    code.end_block()
506
def slice_iter(slice_type, slice_temp, ndim, code):
    """
    Return a loop generator for the slice: a ContigSliceIter when the slice
    type is C or Fortran contiguous, a StridedSliceIter otherwise.
    """
    contiguous = slice_type.is_c_contig or slice_type.is_f_contig
    if contiguous:
        iter_class = ContigSliceIter
    else:
        iter_class = StridedSliceIter
    return iter_class(slice_type, slice_temp, ndim, code)
512
class SliceIter(object):
    """
    Base class of the slice loop generators; it only stores the slice type,
    the C name of the slice, the number of dimensions and the code writer.
    """

    def __init__(self, slice_type, slice_temp, ndim, code):
        (self.slice_type,
         self.slice_temp,
         self.code,
         self.ndim) = slice_type, slice_temp, code, ndim
519
class ContigSliceIter(SliceIter):
    """
    Loop generator for contiguous slices: emits one flat C loop over the
    product of all extents, advancing a typed element pointer.
    """

    def start_loops(self):
        """Open the block and the loop; return the element pointer's C name."""
        out = self.code
        out.begin_block()

        type_decl = self.slice_type.dtype.declaration_code("")

        extents = ["%s.shape[%d]" % (self.slice_temp, dim)
                   for dim in range(self.ndim)]
        out.putln("Py_ssize_t __pyx_temp_extent = %s;" % ' * '.join(extents))
        out.putln("Py_ssize_t __pyx_temp_idx;")
        out.putln("%s *__pyx_temp_pointer = (%s *) %s.data;" % (
                            type_decl, type_decl, self.slice_temp))
        loop_header = ("for (__pyx_temp_idx = 0; "
                       "__pyx_temp_idx < __pyx_temp_extent; "
                       "__pyx_temp_idx++) {")
        out.putln(loop_header)

        return "__pyx_temp_pointer"

    def end_loops(self):
        """Advance the pointer by one element, then close the loop and the block."""
        out = self.code
        out.putln("__pyx_temp_pointer += 1;")
        out.putln("}")
        out.end_block()
543
class StridedSliceIter(SliceIter):
    """
    Loop generator for strided slices: emits one nested C loop per
    dimension, each with its own extent, stride, index and char pointer.
    """

    def start_loops(self):
        """Open the block and all loops; return the innermost pointer's C name."""
        out = self.code
        out.begin_block()
        ndim = self.ndim

        # Declare the per-dimension variables up front.
        for dim in range(ndim):
            dim_args = (dim, self.slice_temp, dim)
            out.putln("Py_ssize_t __pyx_temp_extent_%d = %s.shape[%d];" % dim_args)
            out.putln("Py_ssize_t __pyx_temp_stride_%d = %s.strides[%d];" % dim_args)
            out.putln("char *__pyx_temp_pointer_%d;" % dim)
            out.putln("Py_ssize_t __pyx_temp_idx_%d;" % dim)

        out.putln("__pyx_temp_pointer_0 = %s.data;" % self.slice_temp)

        for dim in range(ndim):
            if dim > 0:
                # Each inner pointer starts where its enclosing one points.
                out.putln("__pyx_temp_pointer_%d = __pyx_temp_pointer_%d;" % (dim, dim - 1))

            out.putln("for (__pyx_temp_idx_%d = 0; "
                      "__pyx_temp_idx_%d < __pyx_temp_extent_%d; "
                      "__pyx_temp_idx_%d++) {" % (dim, dim, dim, dim))

        return "__pyx_temp_pointer_%d" % (ndim - 1)

    def end_loops(self):
        """Close the loops innermost first, advancing each pointer by its stride."""
        out = self.code
        for dim in reversed(range(self.ndim)):
            out.putln("__pyx_temp_pointer_%d += __pyx_temp_stride_%d;" % (dim, dim))
            out.putln("}")

        out.end_block()
575
576
def copy_c_or_fortran_cname(memview):
    """
    Return the C name of the copy function for this memoryview type: its
    specialization suffix followed by 'c' if the type is C contiguous and
    'f' otherwise.
    """
    order = 'f'
    if memview.is_c_contig:
        order = 'c'

    suffix = memview.specialization_suffix()
    return "__pyx_memoryview_copy_slice_%s_%s" % (suffix, order)
585
def get_copy_new_utility(pos, from_memview, to_memview):
    """
    Load the CopyContentsUtility C utility code for copying from_memview
    into a new slice of type to_memview.

    Reports an error at ``pos`` (and returns error()'s result) when the
    dtypes differ, the number of axes differs, to_memview is neither C nor
    Fortran contiguous, or from_memview has a non-direct axis.
    """
    if from_memview.dtype != to_memview.dtype:
        return error(pos, "dtypes must be the same!")
    if len(from_memview.axes) != len(to_memview.axes):
        return error(pos, "number of dimensions must be same")
    if not (to_memview.is_c_contig or to_memview.is_f_contig):
        return error(pos, "to_memview must be c or f contiguous.")

    for (access, packing) in from_memview.axes:
        if access == 'direct':
            continue
        return error(
                pos, "cannot handle 'full' or 'ptr' access at this time.")

    if to_memview.is_c_contig:
        mode, contig_flag = 'c', memview_c_contiguous
    elif to_memview.is_f_contig:
        mode, contig_flag = 'fortran', memview_f_contiguous

    template_context = dict(
        context,
        mode=mode,
        dtype_decl=to_memview.dtype.declaration_code(''),
        contig_flag=contig_flag,
        ndim=to_memview.ndim,
        func_cname=copy_c_or_fortran_cname(to_memview),
        dtype_is_object=int(to_memview.dtype.is_pyobject))

    return load_memview_c_utility("CopyContentsUtility",
                                  context=template_context,
                                  requires=[copy_contents_new_utility])
617
def get_axes_specs(env, axes):
    '''
    get_axes_specs(env, axes) -> list of (access, packing) specs for each axis.
    access is one of 'full', 'ptr' or 'direct'
    packing is one of 'contig', 'strided' or 'follow'

    ``axes`` is a sequence of slice nodes, one per dimension.  Only the
    'step' slot of each may be given: it is either omitted (direct,
    strided), the integer 1 (the contiguous dimension), or a name or
    attribute resolving to one of the keys of
    view_constant_to_access_packing.

    Raises CompileError for any invalid specification.  The resulting specs
    are checked with validate_axes_specs() before being returned.
    '''

    cythonscope = env.global_scope().context.cython_scope
    cythonscope.load_cythonscope()
    viewscope = cythonscope.viewscope

    # NOTE(review): the results of these lookups are not used below; the
    # calls are kept as they were -- confirm they are safe to remove.
    access_specs = tuple([viewscope.lookup(name)
                    for name in ('full', 'direct', 'ptr')])
    packing_specs = tuple([viewscope.lookup(name)
                    for name in ('contig', 'strided', 'follow')])

    is_f_contig, is_c_contig = False, False
    default_access, default_packing = 'direct', 'strided'
    cf_access, cf_packing = default_access, 'follow'

    axes_specs = []
    # analyse all axes.
    for idx, axis in enumerate(axes):
        if not axis.start.is_none:
            raise CompileError(axis.start.pos,  START_ERR)

        if not axis.stop.is_none:
            raise CompileError(axis.stop.pos, STOP_ERR)

        if axis.step.is_none:
            axes_specs.append((default_access, default_packing))

        elif isinstance(axis.step, IntNode):
            # the packing for the ::1 axis is contiguous,
            # all others are cf_packing.
            if axis.step.compile_time_value(env) != 1:
                raise CompileError(axis.step.pos, STEP_ERR)

            # 'cfcontig' is a temporary marker, replaced by 'contig' below.
            axes_specs.append((cf_access, 'cfcontig'))

        elif isinstance(axis.step, (NameNode, AttributeNode)):
            entry = _get_resolved_spec(env, axis.step)
            if entry.name in view_constant_to_access_packing:
                axes_specs.append(view_constant_to_access_packing[entry.name])
            else:
                # Was 'CompilerError', an undefined name that turned this
                # diagnostic into a NameError.
                raise CompileError(axis.step.pos, INVALID_ERR)

        else:
            raise CompileError(axis.step.pos, INVALID_ERR)

    # First, find out if we have a ::1 somewhere
    contig_dim = 0
    is_contig = False
    for idx, (access, packing) in enumerate(axes_specs):
        if packing == 'cfcontig':
            if is_contig:
                # Report at the second ::1 axis itself rather than at
                # whichever axis the analysis loop above ended on.
                raise CompileError(axes[idx].step.pos, BOTH_CF_ERR)

            contig_dim = idx
            axes_specs[idx] = (access, 'contig')
            is_contig = True

    if is_contig:
        # We have a ::1 somewhere, see if we're C or Fortran contiguous
        if contig_dim == len(axes) - 1:
            is_c_contig = True
        else:
            is_f_contig = True

            if contig_dim and not axes_specs[contig_dim - 1][0] in ('full', 'ptr'):
                raise CompileError(axes[contig_dim].pos,
                                   "Fortran contiguous specifier must follow an indirect dimension")

        if is_c_contig:
            # Contiguous in the last dimension, find the last indirect dimension
            contig_dim = -1
            for idx, (access, packing) in enumerate(reversed(axes_specs)):
                if access in ('ptr', 'full'):
                    contig_dim = len(axes) - idx - 1

        # Replace 'strided' with 'follow' for any dimension following the last
        # indirect dimension, the first dimension or the dimension following
        # the ::1.
        #               int[::indirect, ::1, :, :]
        #                                    ^  ^
        #               int[::indirect, :, :, ::1]
        #                               ^  ^
        start = contig_dim + 1
        # For C contiguity the last (::1) axis is excluded from the loop.
        stop = len(axes) - is_c_contig
        for idx, (access, packing) in enumerate(axes_specs[start:stop]):
            idx = contig_dim + 1 + idx
            if access != 'direct':
                raise CompileError(axes[idx].pos,
                                   "Indirect dimension may not follow "
                                   "Fortran contiguous dimension")
            if packing == 'contig':
                raise CompileError(axes[idx].pos,
                                   "Dimension may not be contiguous")
            axes_specs[idx] = (access, cf_packing)

        if is_c_contig:
            # For C contiguity, we need to fix the 'contig' dimension
            # after the loop
            a, p = axes_specs[-1]
            axes_specs[-1] = a, 'contig'

    validate_axes_specs([axis.start.pos for axis in axes],
                        axes_specs,
                        is_c_contig,
                        is_f_contig)

    return axes_specs
730
def validate_axes(pos, axes):
    """
    Return True if the number of axes is below Options.buffer_max_dims;
    otherwise report an error at ``pos`` and return False.
    """
    within_limit = len(axes) < Options.buffer_max_dims
    if not within_limit:
        error(pos, "More dimensions than the maximum number"
                   " of buffer dimensions were used.")
    return within_limit
738
def all(it):
    """
    Return False as soon as a false item is found in the iterable ``it``,
    True if there is none (including when ``it`` is empty).
    """
    for item in it:
        if item:
            continue
        return False
    return True
744
def is_cf_contig(specs):
    """
    Return (is_c_contig, is_f_contig) for a sequence of (access, packing)
    axis specs.

    C contiguous:       'follow', 'follow', ..., 'follow', 'contig'
    Fortran contiguous: 'contig', 'follow', 'follow', ..., 'follow'
    (all axes direct; a single contiguous axis counts as C contiguous).
    """
    contig = ('direct', 'contig')
    follow = ('direct', 'follow')

    def only_follow(part):
        return all([axis == follow for axis in part])

    if len(specs) == 1 and specs == [contig]:
        return True, False

    if specs[-1] == contig and only_follow(specs[:-1]):
        return True, False

    if len(specs) > 1 and specs[0] == contig and only_follow(specs[1:]):
        return False, True

    return False, False
763
def get_mode(specs):
    """
    Return the mode name for the given axis specs: 'c' or 'fortran' for
    contiguous layouts, 'full' if any axis has 'ptr' or 'full' access, and
    'strided' otherwise.
    """
    is_c_contig, is_f_contig = is_cf_contig(specs)
    if is_c_contig:
        return 'c'
    if is_f_contig:
        return 'fortran'

    indirect_axes = [access for access, packing in specs
                     if access in ('ptr', 'full')]
    if indirect_axes:
        return 'full'
    return 'strided'
777
# Maps the name of an axis constant to its (access, packing) spec.
# get_axes_specs() looks up the name of the resolved entry found in an
# axis' step slot in this table.
view_constant_to_access_packing = {
    'generic':              ('full',   'strided'),
    'strided':              ('direct', 'strided'),
    'indirect':             ('ptr',    'strided'),
    'generic_contiguous':   ('full',   'contig'),
    'contiguous':           ('direct', 'contig'),
    'indirect_contiguous':  ('ptr',    'contig'),
}
786
def validate_axes_specs(positions, specs, is_c_contig, is_f_contig):
    """
    Check a list of (access, packing) axis specs for consistency.

    ``positions`` holds one source position per spec, used for error
    reporting; positions and specs are paired up, stopping at the shorter
    of the two.  ``is_c_contig`` / ``is_f_contig`` say whether the caller
    determined the layout to be C / Fortran contiguous; 'follow' packing is
    only accepted when one of them is set.

    Returns None.  Raises CompileError on the first invalid axis.
    """
    packing_specs = ('contig', 'strided', 'follow')
    access_specs = ('direct', 'ptr', 'full')

    has_contig = has_strided = False

    # Only 'ptr' access counts as indirect here ('full' does not).
    last_indirect_dimension = -1
    for idx, (access, packing) in enumerate(specs):
        if access == 'ptr':
            last_indirect_dimension = idx

    for idx, (pos, (access, packing)) in enumerate(zip(positions, specs)):

        if not (access in access_specs and
                packing in packing_specs):
            raise CompileError(pos, "Invalid axes specification.")

        if packing == 'strided':
            has_strided = True
        elif packing == 'contig':
            if has_contig:
                raise CompileError(pos, "Only one direct contiguous "
                                        "axis may be specified.")

            # A direct contiguous axis is only allowed right after the last
            # indirect dimension or in the last dimension.
            valid_contig_dims = last_indirect_dimension + 1, len(specs) - 1
            if idx not in valid_contig_dims and access != 'ptr':
                if last_indirect_dimension + 1 != len(specs) - 1:
                    dims = "dimensions %d and %d" % valid_contig_dims
                else:
                    dims = "dimension %d" % valid_contig_dims[0]

                raise CompileError(pos, "Only %s may be contiguous and direct" % dims)

            has_contig = access != 'ptr'
        elif packing == 'follow':
            if has_strided:
                raise CompileError(pos, "A memoryview cannot have both follow and strided axis specifiers.")
            if not (is_c_contig or is_f_contig):
                raise CompileError(pos, "Invalid use of the follow specifier.")

        # An indirect axis resets the strided/follow conflict tracking.
        if access in ('ptr', 'full'):
            has_strided = False
832
def _get_resolved_spec(env, spec):
    """
    Resolve an axis specifier node to its entry.

    spec must be a NameNode or an AttributeNode; anything else raises
    CompileError(INVALID_ERR) at the node's position.
    """
    resolvers = ((NameNode, _resolve_NameNode),
                 (AttributeNode, _resolve_AttributeNode))
    for node_class, resolve in resolvers:
        if isinstance(spec, node_class):
            return resolve(env, spec)
    raise CompileError(spec.pos, INVALID_ERR)
841
def _resolve_NameNode(env, node):
    """Resolve a plain name used as an axis specifier.

    The name is first looked up in env to obtain the name of the entry it
    refers to; that name is then looked up in the cython scope's viewscope.

    Raises CompileError(INVALID_ERR) if the first lookup fails with an
    AttributeError, and CompileError(NOT_CIMPORTED_ERR) if the viewscope
    has no entry for the resolved name.
    """
    try:
        view_name = env.lookup(node.name).name
    except AttributeError:
        raise CompileError(node.pos, INVALID_ERR)

    cython_scope = env.global_scope().context.cython_scope
    view_entry = cython_scope.viewscope.lookup(view_name)
    if view_entry is not None:
        return view_entry

    raise CompileError(node.pos, NOT_CIMPORTED_ERR)
854
def _resolve_AttributeNode(env, node):
    """Resolve a dotted name (e.g. ``a.b.c``) used as an axis specifier.

    Every component except the last must name a module reachable from the
    previous scope (starting at env); the last component is looked up in
    the innermost module's scope and its entry is returned.

    Raises CompileError if the base of the attribute chain is not a plain
    name, if a module component cannot be resolved, or if the final
    attribute does not exist.
    """
    # Walk down the chain collecting components innermost-last, then flip.
    components = []
    while isinstance(node, AttributeNode):
        components.append(node.attribute)
        node = node.obj
    if not isinstance(node, NameNode):
        raise CompileError(node.pos, EXPR_ERR)
    components.append(node.name)
    components.reverse()

    modnames, attr_name = components[:-1], components[-1]
    # must be at least 1 module name, o/w not an AttributeNode.
    assert modnames

    # From here on 'node' is the base NameNode; errors are reported there.
    scope = env
    for modname in modnames:
        mod = scope.lookup(modname)
        if not (mod and mod.as_module):
            raise CompileError(
                    node.pos, "undeclared name not builtin: %s" % modname)
        scope = mod.as_module

    entry = scope.lookup(attr_name)
    if entry:
        return entry

    raise CompileError(node.pos, "No such attribute '%s'" % attr_name)
881
882#
883### Utility loading
884#
885
def load_memview_cy_utility(util_code_name, context=None, **kwargs):
    """Load the utility ``util_code_name`` from MemoryView.pyx.

    context and any extra keyword arguments are forwarded unchanged to
    CythonUtilityCode.load.
    """
    source_file = "MemoryView.pyx"
    return CythonUtilityCode.load(
        util_code_name, source_file, context=context, **kwargs)
889
def load_memview_c_utility(util_code_name, context=None, **kwargs):
    """Load the utility ``util_code_name`` from MemoryView_C.c.

    When a context is given the utility is loaded through
    TempitaUtilityCode and receives that context; otherwise it is loaded
    through plain UtilityCode.  Extra keyword arguments are forwarded
    to whichever loader is used.
    """
    source_file = "MemoryView_C.c"
    if context is not None:
        return TempitaUtilityCode.load(util_code_name, source_file,
                                       context=context, **kwargs)
    return UtilityCode.load(util_code_name, source_file, **kwargs)
896
def use_cython_array_utility_code(env):
    """Flag the viewscope's 'array_cwrapper' entry as used.

    The cython scope is loaded first, then the entry is looked up in its
    viewscope and its ``used`` attribute is set to True.
    """
    scope = env.global_scope().context.cython_scope
    scope.load_cythonscope()
    wrapper_entry = scope.viewscope.lookup('array_cwrapper')
    wrapper_entry.used = True
901
# Template context handed to the utilities below that are loaded with one.
context = {
    'memview_struct_name': memview_objstruct_cname,
    'max_dims': Options.buffer_max_dims,
    'memviewslice_name': memviewslice_cname,
    'memslice_init': memslice_entry_init,
}

# --- C utilities (MemoryView_C.c) ---

# Created with an empty requires list; view_utility_code is appended at the
# bottom of this section, once it exists.
memviewslice_declare_code = load_memview_c_utility(
    "MemviewSliceStruct",
    context=context,
    proto_block='utility_code_proto_before_types',
    requires=[],
)

atomic_utility = load_memview_c_utility(
    "Atomics",
    context=context,
    proto_block='utility_code_proto_before_types',
)

memviewslice_init_code = load_memview_c_utility(
    "MemviewSliceInit",
    context=dict(context, BUF_MAX_NDIMS=Options.buffer_max_dims),
    requires=[
        memviewslice_declare_code,
        Buffer.acquire_utility_code,
        atomic_utility,
    ],
)

# Loaded without a context.
memviewslice_index_helpers = load_memview_c_utility("MemviewSliceIndex")

# --- Cython utility (MemoryView.pyx) ---
typeinfo_to_format_code = load_memview_cy_utility(
    "BufferFormatFromTypeInfo",
    requires=[Buffer._typeinfo_to_format_code],
)

is_contig_utility = load_memview_c_utility(
    "MemviewSliceIsContig", context=context)
overlapping_utility = load_memview_c_utility(
    "OverlappingSlices", context=context)
copy_contents_new_utility = load_memview_c_utility(
    "MemviewSliceCopyTemplate",
    context=context,
    requires=[],  # require cython_array_utility_code
)

view_utility_code = load_memview_cy_utility(
    "View.MemoryView",
    context=context,
    requires=[
        Buffer.GetAndReleaseBufferUtilityCode(),
        Buffer.buffer_struct_declare_code,
        Buffer.empty_bufstruct_utility,
        memviewslice_init_code,
        is_contig_utility,
        overlapping_utility,
        copy_contents_new_utility,
        ModuleNode.capsule_utility_code,
    ],
)
view_utility_whitelist = (
    'array', 'memoryview', 'array_cwrapper',
    'generic', 'strided', 'indirect', 'contiguous',
    'indirect_contiguous',
)

# view_utility_code is only defined above, after these two utilities were
# created, so it is added to their requires lists here.
memviewslice_declare_code.requires.append(view_utility_code)
copy_contents_new_utility.requires.append(view_utility_code)