diff --git a/internal/inference/parameter.go b/internal/inference/parameter.go index cc2916e6..203cc0e0 100644 --- a/internal/inference/parameter.go +++ b/internal/inference/parameter.go @@ -366,24 +366,69 @@ func ParentAlias(join *query.Join) string { } func ExtractRelationColumns(join *query.Join) (string, string) { - relColumn := "" - refColumn := "" - sqlparser.Traverse(join.On, func(n node.Node) bool { - switch actual := n.(type) { - case *qexpr.Selector: - column := sqlparser.Stringify(actual.X) - if actual.Name == join.Alias { - if refColumn == "" { - refColumn = column - } - } else if relColumn == "" { - relColumn = column + pairs := ExtractRelationColumnPairs(join) + if len(pairs) == 0 { + return "", "" + } + return pairs[0][0], pairs[0][1] +} + +func ExtractRelationColumnPairs(join *query.Join) [][2]string { + if join == nil || join.On == nil || join.On.X == nil { + return nil + } + return collectRelationColumnPairs(join.On.X, join.Alias) +} + +func collectRelationColumnPairs(n node.Node, refAlias string) [][2]string { + switch actual := n.(type) { + case *qexpr.Binary: + actual = actual.Normalize() + op := strings.ToUpper(strings.TrimSpace(actual.Op)) + if op == "AND" { + left := collectRelationColumnPairs(actual.X, refAlias) + right := collectRelationColumnPairs(actual.Y, refAlias) + return append(left, right...) + } + if op != "=" { + return nil + } + leftAlias, leftColumn, leftOK := selectorParts(actual.X) + rightAlias, rightColumn, rightOK := selectorParts(actual.Y) + if !leftOK || !rightOK { + return nil + } + switch { + case leftAlias == refAlias: + return [][2]string{{rightColumn, leftColumn}} + case rightAlias == refAlias: + return [][2]string{{leftColumn, rightColumn}} + } + case *qexpr.Parenthesis: + return collectRelationColumnPairs(actual.X, refAlias) + } + return nil +} + +func selectorParts(n node.Node) (string, string, bool) { + switch actual := n.(type) { + case *qexpr.Selector: + return actual.Name, sqlparser.Stringify(actual.X), true + case *qexpr.Parenthesis: + return selectorParts(actual.X) + case *qexpr.Collate: + return selectorParts(actual.X) + case *qexpr.Call: + if alias, column, ok := selectorParts(actual.X); ok { + return alias, column, true + } + for _, arg := range actual.Args { + if alias, column, ok := selectorParts(arg); ok { + return alias, column, true } - return true } - return true - }) - return relColumn, refColumn + } + return "", "", false } func (p *Parameter) EnsureCodec() { diff --git a/internal/inference/spec.go b/internal/inference/spec.go index e6a56219..c5227ffe 100644 --- a/internal/inference/spec.go +++ b/internal/inference/spec.go @@ -20,12 +20,18 @@ import ( ) type ( + RelationPair struct { + ParentField *Field + KeyField *Field + } + //Relation defines relation Relation struct { Name string Join *query.Join ParentField *Field KeyField *Field + Pairs []*RelationPair Cardinality state.Cardinality *Spec } @@ -163,28 +169,35 @@ func (s *Spec) AddRelation(name string, join *query.Join, spec *Spec, cardinalit if IsToOne(join) { cardinality = state.One } - relColumn, refColumn := ExtractRelationColumns(join) - parentField := s.Type.ByColumn(relColumn) - if parentField == nil { - var available []string - for _, item := range s.Type.columnFields { - available = append(available, item.Column.Name) + pairColumns := ExtractRelationColumnPairs(join) + if len(pairColumns) == 0 { + return fmt.Errorf("failed to extract relation columns for %v", join.Alias) + } + pairs := make([]*RelationPair, 0, len(pairColumns)) + for _, pair := range pairColumns { + parentField := s.Type.ByColumn(pair[0]) + if parentField == nil { + var available []string + for _, item := range s.Type.columnFields { + available = append(available, item.Column.Name) + } + return fmt.Errorf("failed to match rel field for %v, available: %v %v", pair[0], s.Type.Name, available) } - return fmt.Errorf("failed to match rel field for %v, available: %v %v", relColumn, s.Type.Name, available) - } - - keyField := spec.Type.ByColumn(refColumn) - if keyField == nil { - var available []string - for _, item := range spec.Type.columnFields { - available = append(available, item.Column.Name) + keyField := spec.Type.ByColumn(pair[1]) + if keyField == nil { + var available []string + for _, item := range spec.Type.columnFields { + available = append(available, item.Column.Name) + } + return fmt.Errorf("failed to ref field for %v, available: %v on %v", pair[1], available, join.Alias) } - return fmt.Errorf("failed to ref field for %v, available: %v on %v", refColumn, available, join.Alias) + pairs = append(pairs, &RelationPair{ParentField: parentField, KeyField: keyField}) } rel := &Relation{Spec: spec, - KeyField: keyField, - ParentField: parentField, + KeyField: pairs[0].KeyField, + ParentField: pairs[0].ParentField, + Pairs: pairs, Name: name, Join: join, Cardinality: cardinality} diff --git a/internal/inference/state.go b/internal/inference/state.go index fa28c984..880f0be7 100644 --- a/internal/inference/state.go +++ b/internal/inference/state.go @@ -327,6 +327,13 @@ func removeBuilinExpr(query string) string { } query = strings.ReplaceAll(query, fragment, "") } + if index := strings.Index(query, "$View.ParentCompositeJoinOn"); index != -1 { + fragment := query[index:] + if endIndex := strings.Index(fragment, ")"); endIndex != -1 { + fragment = fragment[:endIndex+1] + } + query = strings.ReplaceAll(query, fragment, "") + } if !strings.Contains(query, "${predicate.") { return query diff --git a/internal/inference/tag.go b/internal/inference/tag.go index 2075a485..6e0560ac 100644 --- a/internal/inference/tag.go +++ b/internal/inference/tag.go @@ -180,19 +180,29 @@ func (t *Tags) buildRelation(spec *Spec, relation *Relation) { Table: spec.Table, } joinTag := tags.LinkOn{} - - parentColumn := relation.ParentField.Column.Name - if ns := relation.ParentField.Column.Namespace; ns != "" { - parentColumn = ns + "." + parentColumn + if len(relation.Pairs) == 0 { + relation.Pairs = []*RelationPair{{ + ParentField: relation.ParentField, + KeyField: relation.KeyField, + }} } - keyColumn := relation.KeyField.Column.Name - if ns := relation.KeyField.Column.Namespace; ns != "" { - keyColumn = ns + "." + keyColumn + for _, pair := range relation.Pairs { + if pair == nil || pair.ParentField == nil || pair.KeyField == nil { + continue + } + parentColumn := relationColumnName(pair.ParentField.Column) + if ns := pair.ParentField.Column.Namespace; ns != "" { + parentColumn = ns + "." + parentColumn + } + keyColumn := relationColumnName(pair.KeyField.Column) + if ns := pair.KeyField.Column.Namespace; ns != "" { + keyColumn = ns + "." + keyColumn + } + joinTag = joinTag.Append( + tags.WithRelLink(pair.ParentField.Name, parentColumn, nil), + tags.WithRefLink(pair.KeyField.Name, keyColumn), + ) } - joinTag = joinTag.Append( - tags.WithRelLink(relation.ParentField.Name, parentColumn, nil), - tags.WithRefLink(relation.KeyField.Name, keyColumn), - ) sqlTag := TagValue{} if rawSQL := strings.Trim(sqlparser.Stringify(join.With), " )("); rawSQL != "" { rawSQL = strings.Replace(rawSQL, "("+spec.Table+")", spec.Table, 1) @@ -205,6 +215,16 @@ func (t *Tags) buildRelation(spec *Spec, relation *Relation) { t.Set(tags.SQLTag, sqlTag) } +func relationColumnName(column *sqlparser.Column) string { + if column == nil { + return "" + } + if column.Name != "" { + return column.Name + } + return column.Alias +} + // Stringify return text representation of struct tag func (t *Tags) Stringify() string { if len(t.order) == 0 { diff --git a/internal/inference/tag_relation_test.go b/internal/inference/tag_relation_test.go new file mode 100644 index 00000000..342d2f17 --- /dev/null +++ b/internal/inference/tag_relation_test.go @@ -0,0 +1,69 @@ +package inference + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/require" + "github.com/viant/datly/view" + vstate "github.com/viant/datly/view/state" + "github.com/viant/datly/view/tags" + "github.com/viant/sqlparser" + "github.com/viant/sqlparser/query" +) + +func TestTags_buildRelation_UsesColumnAliasWhenNameIsEmpty(t *testing.T) { + selectQuery, err := sqlparser.ParseQuery("SELECT 'app' AS SITE_TYPE_VALUE") + require.NoError(t, err) + + parentColumn := &sqlparser.Column{Name: "SITE_TYPE_VALUES"} + childColumn := &sqlparser.Column{Alias: "SITE_TYPE_VALUE"} + + parentField := &Field{ + Field: view.Field{Name: "SiteTypeValues", Schema: &vstate.Schema{}}, + Column: parentColumn, + } + keyField := &Field{ + Field: view.Field{Name: "SiteTypeValue", Schema: &vstate.Schema{}}, + Column: childColumn, + } + + spec := &Spec{Table: "ignored"} + relation := &Relation{ + Name: "siteType", + Join: &query.Join{Alias: "siteType", With: selectQuery}, + ParentField: parentField, + KeyField: keyField, + Pairs: []*RelationPair{{ + ParentField: parentField, + KeyField: keyField, + }}, + } + + field := &Field{Field: view.Field{Name: "SiteType", Schema: &vstate.Schema{}}, Tags: Tags{}} + field.Tags.buildRelation(spec, relation) + tagString := field.Tags.Stringify() + + parsed, err := tags.Parse(reflect.StructTag(tagString), nil, tags.LinkOnTag) + require.NoError(t, err) + require.Len(t, parsed.LinkOn, 1) + + var relField, relColumn, refField, refColumn string + require.NoError(t, parsed.LinkOn.ForEach(func(rf, rc, kf, kc string, include *bool) error { + relField, relColumn, refField, refColumn = rf, rc, kf, kc + return nil + })) + + require.Equal(t, "SiteTypeValues", relField) + require.Equal(t, "SITE_TYPE_VALUES", relColumn) + require.Equal(t, "SiteTypeValue", refField) + require.Equal(t, "SITE_TYPE_VALUE", refColumn) +} + +func TestType_ByColumn_MatchesAlias(t *testing.T) { + typ := &Type{columnFields: []*Field{{ + Field: view.Field{Name: "SiteTypeValue", Schema: &vstate.Schema{}}, + Column: &sqlparser.Column{Alias: "SITE_TYPE_VALUE"}, + }}} + require.NotNil(t, typ.ByColumn("SITE_TYPE_VALUE")) +} diff --git a/internal/translator/view.go b/internal/translator/view.go index 09732a30..20db586c 100644 --- a/internal/translator/view.go +++ b/internal/translator/view.go @@ -299,6 +299,9 @@ func (v *View) buildRelations(parentNamespace *Viewlet, rule *Rule) error { relNamespace.Holder = viewRelation.Holder refViewName := relNamespace.View.Name refColumn := relation.KeyField.Column.Name + if refColumn == "" { + refColumn = relation.KeyField.Column.Alias + } if ns := relation.KeyField.Column.Namespace; ns != "" { refColumn = ns + "." + refColumn } diff --git a/service/executor/expand/parent.go b/service/executor/expand/parent.go index 97787429..0a51b93e 100644 --- a/service/executor/expand/parent.go +++ b/service/executor/expand/parent.go @@ -1,9 +1,12 @@ package expand import ( + "context" "database/sql" "fmt" "github.com/viant/datly/utils/types" + sqlxconfig "github.com/viant/sqlx/io/config" + "github.com/viant/sqlx/metadata/info" "github.com/viant/xunsafe" "os" "reflect" @@ -21,6 +24,7 @@ type ( ColIn(prefix, column string) (string, error) In(prefix string) (string, error) ParentJoinOn(column string, prepend ...string) (string, error) + ParentCompositeJoinOn(prefix string, columns ...string) (string, error) AndParentJoinOn(column string) (string, error) } @@ -45,18 +49,22 @@ type ( ParentBatch interface { ColIn() []interface{} ColInBatch() []interface{} + CompositeIn() [][]interface{} + CompositeInBatch() [][]interface{} + HasComposite() bool } ViewContext struct { - Name string - Alias string - Table string - Limit int - Offset int - Page int - Args []interface{} - NonWindowSQL string - ParentValues []interface{} + Name string + Alias string + Table string + Limit int + Offset int + Page int + Args []interface{} + NonWindowSQL string + ParentValues []interface{} + ParentCompositeValues [][]interface{} expander Expander `velty:"-"` DataUnit *DataUnit `velty:"-"` @@ -102,6 +110,10 @@ func (e *MockExpander) AndParentJoinOn(column string) (string, error) { return e.ColIn("", column) } +func (e *MockExpander) ParentCompositeJoinOn(prefix string, columns ...string) (string, error) { + return "", nil +} + func (e *MockExpander) ColIn(prefix, column string) (string, error) { return "", nil } @@ -111,16 +123,48 @@ func (e *MockExpander) In(prefix string) (string, error) { } func (m *ViewContext) ParentJoinOn(column string, prepend ...string) (string, error) { + prefix := "AND" + columns := []string{column} if len(prepend) > 0 { - return m.ColIn(column, prepend[0]) + prefix = column + columns = prepend } - return m.ColIn("AND", column) + if len(columns) > 1 { + return m.parentCompositeJoinOn(prefix, columns...) + } + return m.ColIn(prefix, columns[0]) } func (m *ViewContext) AndParentJoinOn(column string) (string, error) { return m.ColIn("AND", column) } +func (m *ViewContext) ParentCompositeJoinOn(prefix string, columns ...string) (string, error) { + return m.parentCompositeJoinOn(prefix, columns...) +} + +func (m *ViewContext) parentCompositeJoinOn(prefix string, columns ...string) (string, error) { + if len(columns) == 0 { + return prefix + " 1 = 0 ", nil + } + if m.expander != nil { + return m.expander.ParentCompositeJoinOn(prefix, columns...) + } + rowCount := len(m.ParentCompositeValues) + if rowCount == 0 { + return prefix + " 1 = 0 ", nil + } + dialect, err := m.dialect() + if err != nil { + return "", err + } + if prefix != "" && !strings.HasSuffix(prefix, " ") { + prefix += " " + } + m.addCompositeBindings(m.ParentCompositeValues) + return prefix + renderCompositePredicate(dialect, columns, rowCount), nil +} + func (m *ViewContext) ColIn(prefix, column string) (string, error) { if m.expander != nil { return m.expander.ColIn(prefix, column) @@ -144,6 +188,26 @@ func (m *ViewContext) addBindings(args []interface{}) string { return bindings } +func (m *ViewContext) addCompositeBindings(rows [][]interface{}) { + for _, row := range rows { + m.DataUnit.addAll(row...) + } +} + +func (m *ViewContext) dialect() (*info.Dialect, error) { + if m == nil || m.DataUnit == nil || m.DataUnit.MetaSource == nil { + return nil, nil + } + db, err := m.DataUnit.MetaSource.Db() + if err != nil { + return nil, err + } + if db == nil { + return nil, nil + } + return sqlxconfig.Dialect(context.Background(), db) +} + func (m *ViewContext) In(prefix string) (string, error) { return m.ColIn(prefix, "") } @@ -177,6 +241,40 @@ func AsBindings(key string, values []interface{}) (column string, bindings strin } } +func defaultCompositePredicate(columns []string, rowCount int) string { + if len(columns) == 0 || rowCount <= 0 { + return "1 = 0" + } + builder := &strings.Builder{} + builder.WriteByte('(') + builder.WriteString(strings.Join(columns, ", ")) + builder.WriteString(") IN (") + for row := 0; row < rowCount; row++ { + if row > 0 { + builder.WriteString(", ") + } + builder.WriteByte('(') + for col := range columns { + if col > 0 { + builder.WriteString(", ") + } + builder.WriteByte('?') + } + builder.WriteByte(')') + } + builder.WriteByte(')') + return builder.String() +} + +func renderCompositePredicate(dialect *info.Dialect, columns []string, rowCount int) string { + if renderer, ok := any(dialect).(interface { + CompositeIn([]string, int) string + }); ok { + return renderer.CompositeIn(columns, rowCount) + } + return defaultCompositePredicate(columns, rowCount) +} + func NewViewContext(metaSource ParentSource, aSelector ParentExtras, batchData ParentBatch, options ...interface{}) *ViewContext { if metaSource == nil { return nil @@ -185,6 +283,7 @@ func NewViewContext(metaSource ParentSource, aSelector ParentExtras, batchData P var sanitizer *DataUnit var expander Expander var colInArgs []interface{} + var compositeArgs [][]interface{} for _, option := range options { switch actual := option.(type) { @@ -197,6 +296,7 @@ func NewViewContext(metaSource ParentSource, aSelector ParentExtras, batchData P if batchData != nil { colInArgs = batchData.ColInBatch() + compositeArgs = batchData.CompositeInBatch() } limit := metaSource.ResultLimit() offset := 0 @@ -215,17 +315,18 @@ func NewViewContext(metaSource ParentSource, aSelector ParentExtras, batchData P SQLExec = sanitizer.TemplateSQL } result := &ViewContext{ - expander: expander, - Name: metaSource.ViewName(), - Alias: metaSource.TableAlias(), - Table: metaSource.TableName(), - Limit: limit, - Page: page, - Offset: offset, - Args: args, - NonWindowSQL: SQLExec, - DataUnit: NewDataUnit(metaSource), - ParentValues: colInArgs, + expander: expander, + Name: metaSource.ViewName(), + Alias: metaSource.TableAlias(), + Table: metaSource.TableName(), + Limit: limit, + Page: page, + Offset: offset, + Args: args, + NonWindowSQL: SQLExec, + DataUnit: NewDataUnit(metaSource), + ParentValues: colInArgs, + ParentCompositeValues: compositeArgs, } return result diff --git a/service/executor/expand/parent_composite_test.go b/service/executor/expand/parent_composite_test.go new file mode 100644 index 00000000..2899f54e --- /dev/null +++ b/service/executor/expand/parent_composite_test.go @@ -0,0 +1,68 @@ +package expand + +import ( + "database/sql" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type compositeBatch struct { + rows [][]interface{} +} + +type mockParentSource struct{} + +func (m *mockParentSource) Db() (*sql.DB, error) { return nil, nil } +func (m *mockParentSource) ViewName() string { return "test" } +func (m *mockParentSource) TableAlias() string { return "t" } +func (m *mockParentSource) TableName() string { return "TEST" } +func (m *mockParentSource) ResultLimit() int { return 100 } + +func (b *compositeBatch) ColIn() []interface{} { return nil } +func (b *compositeBatch) ColInBatch() []interface{} { return nil } +func (b *compositeBatch) CompositeIn() [][]interface{} { return b.rows } +func (b *compositeBatch) CompositeInBatch() [][]interface{} { return b.rows } +func (b *compositeBatch) HasComposite() bool { return len(b.rows) > 0 } + +func TestViewContext_ParentCompositeJoinOn(t *testing.T) { + viewCtx := NewViewContext(&mockParentSource{}, nil, &compositeBatch{ + rows: [][]interface{}{ + {101, "A"}, + {202, "B"}, + }, + }, &DataUnit{}) + require.NotNil(t, viewCtx) + require.NotNil(t, viewCtx.DataUnit) + + sqlFragment, err := viewCtx.ParentCompositeJoinOn("AND", "t.advertiser_id", "t.val") + require.NoError(t, err) + assert.Equal(t, "AND (t.advertiser_id, t.val) IN ((?, ?), (?, ?))", sqlFragment) + assert.Equal(t, []interface{}{101, "A", 202, "B"}, viewCtx.DataUnit.ParamsGroup) +} + +func TestViewContext_ParentJoinOn_CompositeArgs(t *testing.T) { + viewCtx := NewViewContext(&mockParentSource{}, nil, &compositeBatch{ + rows: [][]interface{}{ + {101, "A"}, + {202, "B"}, + }, + }, &DataUnit{}) + require.NotNil(t, viewCtx) + + sqlFragment, err := viewCtx.ParentJoinOn("AND", "t.advertiser_id", "t.val") + require.NoError(t, err) + assert.Equal(t, "AND (t.advertiser_id, t.val) IN ((?, ?), (?, ?))", sqlFragment) + assert.Equal(t, []interface{}{101, "A", 202, "B"}, viewCtx.DataUnit.ParamsGroup) +} + +func TestViewContext_ParentCompositeJoinOn_EmptyRows(t *testing.T) { + viewCtx := NewViewContext(&mockParentSource{}, nil, &compositeBatch{}, &DataUnit{}) + require.NotNil(t, viewCtx) + + sqlFragment, err := viewCtx.ParentCompositeJoinOn("AND", "t.advertiser_id", "t.val") + require.NoError(t, err) + assert.True(t, strings.Contains(sqlFragment, "1 = 0")) +} diff --git a/service/reader/service.go b/service/reader/service.go index d362ec89..56074977 100644 --- a/service/reader/service.go +++ b/service/reader/service.go @@ -235,8 +235,12 @@ func (s *Service) afterReadAll(collectorFetchEmitted bool, collector *view.Colle func (s *Service) batchData(collector *view.Collector) *view.BatchData { batchData := &view.BatchData{} - batchData.Values, batchData.ColumnNames = collector.ParentPlaceholders() - batchData.ParentReadSize = len(batchData.Values) + batchData.Values, batchData.CompositeValues, batchData.ColumnNames = collector.ParentPlaceholders() + if batchData.HasComposite() { + batchData.ParentReadSize = len(batchData.CompositeValues) + } else { + batchData.ParentReadSize = len(batchData.Values) + } return batchData } @@ -257,7 +261,11 @@ func (s *Service) exhaustRead(ctx context.Context, view *view.View, selector *vi } func (s *Service) readObjects(ctx context.Context, session *Session, batchData *view.BatchData, view *view.View, collector *view.Collector, selector *view.Statelet, info *response.SQLExecutions) error { - batchData.ValuesBatch, batchData.Size = sliceWithLimit(batchData.Values, batchData.Size, batchData.Size+view.Batch.Size) + if batchData.HasComposite() { + batchData.CompositeValuesBatch, batchData.Size = sliceCompositeWithLimit(batchData.CompositeValues, batchData.Size, batchData.Size+view.Batch.Size) + } else { + batchData.ValuesBatch, batchData.Size = sliceWithLimit(batchData.Values, batchData.Size, batchData.Size+view.Batch.Size) + } visitor := collector.Visitor(ctx) for { err := s.queryInBatches(ctx, session, view, collector, visitor, info, batchData, selector) @@ -268,7 +276,11 @@ func (s *Service) readObjects(ctx context.Context, session *Session, batchData * break } var nextParents int - batchData.ValuesBatch, nextParents = sliceWithLimit(batchData.Values, batchData.Size, batchData.Size+view.Batch.Size) + if batchData.HasComposite() { + batchData.CompositeValuesBatch, nextParents = sliceCompositeWithLimit(batchData.CompositeValues, batchData.Size, batchData.Size+view.Batch.Size) + } else { + batchData.ValuesBatch, nextParents = sliceWithLimit(batchData.Values, batchData.Size, batchData.Size+view.Batch.Size) + } batchData.Size += nextParents } return nil diff --git a/service/reader/slice.go b/service/reader/slice.go index 07699951..129e3dc1 100644 --- a/service/reader/slice.go +++ b/service/reader/slice.go @@ -7,3 +7,10 @@ func sliceWithLimit(aSlice []interface{}, from, to int) ([]interface{}, int) { return aSlice[from:], len(aSlice) - from } + +func sliceCompositeWithLimit(aSlice [][]interface{}, from, to int) ([][]interface{}, int) { + if len(aSlice) > to { + return aSlice[from:to], to - from + } + return aSlice[from:], len(aSlice) - from +} diff --git a/service/reader/sql.go b/service/reader/sql.go index 555b2623..d2ed875c 100644 --- a/service/reader/sql.go +++ b/service/reader/sql.go @@ -18,6 +18,7 @@ import ( "github.com/viant/sqlparser/node" "github.com/viant/sqlparser/query" "github.com/viant/sqlx/io/read/cache" + "github.com/viant/sqlx/metadata/info" ) const ( @@ -46,6 +47,47 @@ func NewBuilder() *Builder { return &Builder{} } +func compositeDialect(ctx context.Context, aView *view.View) (*info.Dialect, error) { + if aView == nil || aView.Connector == nil { + return nil, nil + } + return aView.Connector.Dialect(ctx) +} + +func defaultCompositeIn(columns []string, rowCount int) string { + if len(columns) == 0 || rowCount <= 0 { + return "1 = 0" + } + builder := &strings.Builder{} + builder.WriteByte('(') + builder.WriteString(strings.Join(columns, ", ")) + builder.WriteString(") IN (") + for row := 0; row < rowCount; row++ { + if row > 0 { + builder.WriteString(", ") + } + builder.WriteByte('(') + for col := range columns { + if col > 0 { + builder.WriteString(", ") + } + builder.WriteByte('?') + } + builder.WriteByte(')') + } + builder.WriteByte(')') + return builder.String() +} + +func renderCompositeIn(dialect *info.Dialect, columns []string, rowCount int) string { + if renderer, ok := any(dialect).(interface { + CompositeIn([]string, int) string + }); ok { + return renderer.CompositeIn(columns, rowCount) + } + return defaultCompositeIn(columns, rowCount) +} + // Build builds SQL Select statement func (b *Builder) Build(ctx context.Context, opts ...BuilderOption) (*cache.ParmetrizedQuery, error) { options := newBuilderOptions(opts...) @@ -112,7 +154,9 @@ func (b *Builder) Build(ctx context.Context, opts ...BuilderOption) (*cache.Parm criteriaMeta := hasKeyword(state.Expanded, keywords.Criteria) hasCriteria := criteriaMeta.has() - b.updateColumnsIn(&commonParams, &batchData, exclude) + if err = b.updateColumnsIn(ctx, aView, &commonParams, &batchData, exclude); err != nil { + return nil, err + } if err = b.updatePagination(&commonParams, aView, statelet, exclude); err != nil { return nil, err @@ -520,18 +564,33 @@ func (b *Builder) appendCriteria(sb *strings.Builder, criteria string, addAnd bo } } -func (b *Builder) updateColumnsIn(params *view.CriteriaParam, batchData *view.BatchData, exclude *Exclude) { +func (b *Builder) updateColumnsIn(ctx context.Context, aView *view.View, params *view.CriteriaParam, batchData *view.BatchData, exclude *Exclude) error { if exclude.ColumnsIn { - return + return nil } if batchData == nil || len(batchData.ColumnNames) == 0 { - return + return nil } sb := strings.Builder{} sb.WriteString(" ") columns := len(batchData.ColumnNames) + if batchData.HasComposite() { + rowCount := len(batchData.CompositeValuesBatch) + if rowCount == 0 { + params.ColumnsIn = " 1 = 0" + return nil + } + dialect, err := compositeDialect(ctx, aView) + if err != nil { + return err + } + sb.WriteString(renderCompositeIn(dialect, batchData.ColumnNames, rowCount)) + params.ColumnsIn = sb.String() + return nil + } + switch columns { case 1: sb.WriteString(batchData.ColumnNames[0]) @@ -539,7 +598,7 @@ func (b *Builder) updateColumnsIn(params *view.CriteriaParam, batchData *view.Ba sb.WriteString("(") for i, column := range batchData.ColumnNames { if i > 0 { - sb.WriteString(",") + sb.WriteString(", ") } sb.WriteString(column) } @@ -566,6 +625,7 @@ func (b *Builder) updateColumnsIn(params *view.CriteriaParam, batchData *view.Ba } sb.WriteString(encloseFragment) params.ColumnsIn = sb.String() + return nil } func (b *Builder) appendOrderBy(sb *strings.Builder, aView *view.View, selector *view.Statelet) error { diff --git a/service/reader/sql_composite_test.go b/service/reader/sql_composite_test.go new file mode 100644 index 00000000..b111c4e4 --- /dev/null +++ b/service/reader/sql_composite_test.go @@ -0,0 +1,37 @@ +package reader + +import ( + "context" + "testing" + + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/datly/view" +) + +func TestBuilder_Build_CompositeColumnsIn_SQLite(t *testing.T) { + aView := view.NewView("adobe", "adobe", + view.WithConnector(view.NewConnector("test", "sqlite3", ":memory:")), + view.WithColumns(view.Columns{ + &view.Column{Name: "ADVERTISER_ID", DataType: "int"}, + &view.Column{Name: "DMP_ADOBE_VALUE", DataType: "string"}, + &view.Column{Name: "ID", DataType: "int"}, + }), + ) + require.NoError(t, aView.Init(context.Background(), view.EmptyResource())) + + query, err := NewBuilder().Build(context.Background(), + WithBuilderView(aView), + WithBuilderStatelet(view.NewStatelet()), + WithBuilderBatchData(&view.BatchData{ + ColumnNames: []string{"ADVERTISER_ID", "DMP_ADOBE_VALUE"}, + CompositeValues: [][]interface{}{{101, "A"}, {202, "B"}}, + CompositeValuesBatch: [][]interface{}{{101, "A"}, {202, "B"}}, + }), + ) + require.NoError(t, err) + require.NotNil(t, query) + assert.Contains(t, query.SQL, `(ADVERTISER_ID, DMP_ADOBE_VALUE) IN ((?, ?), (?, ?))`) + assert.Equal(t, []interface{}{101, "A", 202, "B"}, query.Args) +} diff --git a/service/reader/sql_groupable_test.go b/service/reader/sql_groupable_test.go index fc45d12c..30947ecd 100644 --- a/service/reader/sql_groupable_test.go +++ b/service/reader/sql_groupable_test.go @@ -264,6 +264,43 @@ func TestBuilder_appendRelationColumn_UsesProjectedRelationAliasForGroupedDerive }) } +func TestBuilder_appendRelationColumn_UsesProjectedAliasForQualifiedSourceRelation(t *testing.T) { + builder := NewBuilder() + aView := view.NewView("comscoreContextual", "comscoreContextual", + view.WithConnector(view.NewConnector("test", "sqlite3", ":memory:")), + view.WithColumns(view.Columns{ + &view.Column{Name: "COMSCORE_CONTEXTUAL_VALUE", DataType: "string", Tag: `source:"t2.SEGMENT_ID"`}, + &view.Column{Name: "NAME", DataType: "string"}, + }), + ) + require.NoError(t, aView.Init(context.Background(), view.EmptyResource())) + + relation := &view.Relation{ + Of: &view.ReferenceView{ + On: view.Links{ + &view.Link{Field: "ComscoreContextualValue", Column: "t2.SEGMENT_ID"}, + }, + }, + } + + require.NoError(t, relation.Of.On.Init("comscoreContextual", aView)) + + t.Run("default projection does not append raw unqualified source column", func(t *testing.T) { + sb := &strings.Builder{} + require.NoError(t, builder.checkViewAndAppendRelColumn(sb, aView, relation)) + require.Equal(t, "", sb.String()) + }) + + t.Run("selector projection appends projected alias instead of raw source column", func(t *testing.T) { + sb := &strings.Builder{} + selector := view.NewStatelet() + selector.Columns = []string{"NAME"} + selector.Init(aView) + require.NoError(t, builder.checkSelectorAndAppendRelColumn(sb, aView, selector, relation)) + require.Equal(t, ", COMSCORE_CONTEXTUAL_VALUE", sb.String()) + }) +} + func newGroupableTestView(t *testing.T) *view.View { t.Helper() trueValue := true diff --git a/view/batch.go b/view/batch.go index 68f4c385..ab036de1 100644 --- a/view/batch.go +++ b/view/batch.go @@ -5,8 +5,10 @@ type BatchData struct { Size int ParentReadSize int - Values []interface{} //all values from parent - ValuesBatch []interface{} //batched values defined view.Batch.Size + Values []interface{} // all scalar values from parent + ValuesBatch []interface{} // batched scalar values + CompositeValues [][]interface{} // all composite parent tuples + CompositeValuesBatch [][]interface{} // batched composite tuples } func (b *BatchData) ColIn() []interface{} { @@ -16,3 +18,21 @@ func (b *BatchData) ColIn() []interface{} { func (b *BatchData) ColInBatch() []interface{} { return b.ValuesBatch } + +func (b *BatchData) HasComposite() bool { + return b != nil && len(b.CompositeValues) > 0 +} + +func (b *BatchData) CompositeIn() [][]interface{} { + if b == nil { + return nil + } + return b.CompositeValues +} + +func (b *BatchData) CompositeInBatch() [][]interface{} { + if b == nil { + return nil + } + return b.CompositeValuesBatch +} diff --git a/view/collector.go b/view/collector.go index 8e43229f..02dc32dc 100644 --- a/view/collector.go +++ b/view/collector.go @@ -10,6 +10,7 @@ import ( "github.com/viant/xdatly/handler" "github.com/viant/xunsafe" "reflect" + "strings" "sync" "unsafe" ) @@ -17,20 +18,23 @@ import ( // VisitorFn represents visitor function type VisitorFn func(value interface{}) error +type compositeKey string + // Collector collects and build result from View fetched from Database // If View or any of the View.With MatchStrategy support Parallel fetching, it is important to call MergeData // when all needed View was fetched type Collector struct { - Id string - mutex sync.Mutex - parent *Collector - destValue reflect.Value - appender *xunsafe.Appender - valuePosition map[string]map[string]map[interface{}][]int //stores positions in main slice, based on _field name, indexed by _field value. - types map[string]*xunsafe.Type - relation *Relation - dataSync *handler.DataSync - values map[string]*[]interface{} //acts like a buffer. Output resolved with Resolve method can't be put to the value position map + Id string + mutex sync.Mutex + parent *Collector + destValue reflect.Value + appender *xunsafe.Appender + valuePosition map[string]map[string]map[interface{}][]int //stores positions in main slice, based on _field name, indexed by _field value. + compositeValuePosition map[string]map[compositeKey][]int + types map[string]*xunsafe.Type + relation *Relation + dataSync *handler.DataSync + values map[string]*[]interface{} //acts like a buffer. Output resolved with Resolve method can't be put to the value position map // because value fetched from database was not scanned into yet. Putting value to the map as a key, would create key as a pointer to the zero value. slice *xunsafe.Slice @@ -49,6 +53,80 @@ type Collector struct { viewMetaHandler viewSummaryHandlerFn } +func relationCompositeSignature(links Links) string { + parts := make([]string, 0, len(links)) + for _, link := range links { + if link == nil { + continue + } + parts = append(parts, link.Namespace+"."+link.Column) + } + return strings.Join(parts, "|") +} + +func buildCompositeKey(values []interface{}) compositeKey { + parts := make([]string, len(values)) + for i, value := range values { + parts[i] = fmt.Sprintf("%#v", io.NormalizeKey(value)) + } + return compositeKey(strings.Join(parts, "\x1f")) +} + +func normalizeValues(value interface{}) []interface{} { + switch actual := value.(type) { + case []int: + result := make([]interface{}, 0, len(actual)) + for _, item := range actual { + result = append(result, io.NormalizeKey(item)) + } + return result + case []*int64: + result := make([]interface{}, 0, len(actual)) + for _, item := range actual { + if item == nil { + continue + } + result = append(result, io.NormalizeKey(int(*item))) + } + return result + case []int64: + result := make([]interface{}, 0, len(actual)) + for _, item := range actual { + result = append(result, io.NormalizeKey(int(item))) + } + return result + case []string: + result := make([]interface{}, 0, len(actual)) + for _, item := range actual { + result = append(result, io.NormalizeKey(item)) + } + return result + default: + return []interface{}{io.NormalizeKey(value)} + } +} + +func compositeRows(parts [][]interface{}) [][]interface{} { + if len(parts) == 0 { + return nil + } + result := make([][]interface{}, 1) + for _, values := range parts { + if len(values) == 0 { + return nil + } + next := make([][]interface{}, 0, len(result)*len(values)) + for _, existing := range result { + for _, value := range values { + row := append(append([]interface{}{}, existing...), value) + next = append(next, row) + } + } + result = next + } + return result +} + func (r *Collector) SetDest(dest interface{}) { destValue := reflect.ValueOf(dest) if destValue.Kind() == reflect.Ptr { @@ -63,27 +141,28 @@ func (r *Collector) Clone() *Collector { dest := reflect.MakeSlice(r.view.Schema.SliceType(), 0, 1) slicePtrValue.Elem().Set(dest) return &Collector{ - Id: uuid.New().String(), - parent: r.parent, - destValue: slicePtrValue, - appender: r.slice.Appender(xunsafe.ValuePointer(&slicePtrValue)), - valuePosition: r.valuePosition, - types: r.types, - relation: r.relation, - values: r.values, - slice: r.slice, - view: r.view, - relations: r.relations, - dataSync: r.dataSync, - wg: r.wg, - readAll: r.readAll, - wgDelta: r.wgDelta, - indexCounter: r.indexCounter, - manyCounter: r.manyCounter, - codecSlice: r.codecSlice, - codecSliceDest: r.codecSliceDest, - codecAppender: r.codecAppender, - viewMetaHandler: r.viewMetaHandler, + Id: uuid.New().String(), + parent: r.parent, + destValue: slicePtrValue, + appender: r.slice.Appender(xunsafe.ValuePointer(&slicePtrValue)), + valuePosition: r.valuePosition, + compositeValuePosition: r.compositeValuePosition, + types: r.types, + relation: r.relation, + values: r.values, + slice: r.slice, + view: r.view, + relations: r.relations, + dataSync: r.dataSync, + wg: r.wg, + readAll: r.readAll, + wgDelta: r.wgDelta, + indexCounter: r.indexCounter, + manyCounter: r.manyCounter, + codecSlice: r.codecSlice, + codecSliceDest: r.codecSliceDest, + codecAppender: r.codecAppender, + viewMetaHandler: r.viewMetaHandler, } } @@ -150,25 +229,36 @@ func (r *Collector) parentValuesPositions(ns string, columnName string) map[inte return result } +func (r *Collector) parentCompositePositions(relation *Relation) map[compositeKey][]int { + signature := relationCompositeSignature(relation.On) + result, ok := r.parent.compositeValuePosition[signature] + if !ok || len(result) == 0 { + r.indexParentCompositePositions(relation) + result = r.parent.compositeValuePosition[signature] + } + return result +} + // NewCollector creates a collector func NewCollector(slice *xunsafe.Slice, view *View, dest interface{}, viewMetaHandler viewSummaryHandlerFn, readAll bool) *Collector { ensuredDest := ensureDest(dest, view) wg := sync.WaitGroup{} wg.Add(1) return &Collector{ - Id: uuid.New().String(), - destValue: reflect.ValueOf(ensuredDest), - valuePosition: make(map[string]map[string]map[interface{}][]int), - appender: slice.Appender(xunsafe.AsPointer(ensuredDest)), - slice: slice, - view: view, - types: make(map[string]*xunsafe.Type), - values: make(map[string]*[]interface{}), - readAll: readAll, - wg: &wg, - dataSync: handler.NewDataSync(), - wgDelta: 1, - viewMetaHandler: viewMetaHandler, + Id: uuid.New().String(), + destValue: reflect.ValueOf(ensuredDest), + valuePosition: make(map[string]map[string]map[interface{}][]int), + compositeValuePosition: make(map[string]map[compositeKey][]int), + appender: slice.Appender(xunsafe.AsPointer(ensuredDest)), + slice: slice, + view: view, + types: make(map[string]*xunsafe.Type), + values: make(map[string]*[]interface{}), + readAll: readAll, + wg: &wg, + dataSync: handler.NewDataSync(), + wgDelta: 1, + viewMetaHandler: viewMetaHandler, } } @@ -186,6 +276,13 @@ func (r *Collector) Visitor(ctx context.Context) VisitorFn { relation := r.relation visitorRelations := RelationsSlice(r.view.With).PopulateWithVisitor() for _, rel := range visitorRelations { + if rel.IsComposite() { + signature := relationCompositeSignature(rel.On) + if _, ok := r.compositeValuePosition[signature]; !ok { + r.compositeValuePosition[signature] = map[compositeKey][]int{} + } + continue + } for _, item := range rel.On { if _, ok := r.valuePosition[item.Namespace]; !ok { r.valuePosition[item.Namespace] = map[string]map[interface{}][]int{} @@ -219,8 +316,18 @@ func (r *Collector) Visitor(ctx context.Context) VisitorFn { func (r *Collector) valueIndexer(ctx context.Context, visitorRelations []*Relation) func(value interface{}) error { distinctRelations := make([]*Relation, 0) presenceMap := map[string]map[string]bool{} + compositePresence := map[string]bool{} for i := range visitorRelations { + if visitorRelations[i].IsComposite() { + signature := relationCompositeSignature(visitorRelations[i].On) + if compositePresence[signature] { + continue + } + distinctRelations = append(distinctRelations, visitorRelations[i]) + compositePresence[signature] = true + continue + } for _, item := range visitorRelations[i].On { if _, ok := presenceMap[item.Namespace]; !ok { presenceMap[item.Namespace] = map[string]bool{} @@ -236,6 +343,10 @@ func (r *Collector) valueIndexer(ctx context.Context, visitorRelations []*Relati return func(value interface{}) error { ptr := xunsafe.AsPointer(value) for _, rel := range distinctRelations { + if rel.IsComposite() { + r.indexCompositeValueByRel(ptr, rel, r.indexCounter) + continue + } for _, link := range rel.On { if field := link.xField; field != nil { fieldValue := field.Value(ptr) @@ -252,6 +363,25 @@ func (r *Collector) valueIndexer(ctx context.Context, visitorRelations []*Relati } } +func (r *Collector) indexCompositeValueByRel(ptr unsafe.Pointer, rel *Relation, counter int) { + signature := relationCompositeSignature(rel.On) + index := r.compositeValuePosition[signature] + if index == nil { + index = map[compositeKey][]int{} + r.compositeValuePosition[signature] = index + } + valueSets := make([][]interface{}, 0, len(rel.On)) + for _, link := range rel.On { + if link == nil || link.xField == nil { + return + } + valueSets = append(valueSets, normalizeValues(link.xField.Value(ptr))) + } + for _, row := range compositeRows(valueSets) { + index[buildCompositeKey(row)] = append(index[buildCompositeKey(row)], counter) + } +} + func (r *Collector) indexValueByRel(fieldValue interface{}, rel *Relation, counter int) { switch actual := fieldValue.(type) { case []int: @@ -307,6 +437,24 @@ func (r *Collector) visitorOne(relation *Relation) func(value interface{}) error var aKey interface{} return func(owner interface{}) error { + if relation.IsComposite() { + keyParts := make([]interface{}, 0, len(links)) + for _, link := range links { + if link.xField == nil { + return fmt.Errorf("link %v field %v is not found", relation.Name, link.Column) + } + keyParts = append(keyParts, io.NormalizeKey(link.xField.Interface(xunsafe.AsPointer(owner)))) + } + positions, ok := r.parentCompositePositions(relation)[buildCompositeKey(keyParts)] + if !ok { + return nil + } + for _, index := range positions { + item := r.parent.slice.ValuePointerAt(destPtr, index) + holderField.SetValue(xunsafe.AsPointer(item), owner) + } + return nil + } for j, link := range links { if link.xField == nil { return fmt.Errorf("link %v field %v is not found", relation.Name, link.Column) @@ -374,6 +522,31 @@ func (r *Collector) ParentRow(relation *Relation) func(value interface{}) (inter } return func(child interface{}) (interface{}, error) { + if relation.IsComposite() { + keyParts := make([]interface{}, 0, len(links)) + for _, link := range links { + keyField := link.xField + if keyField == nil && xType == nil { + xType = r.types[link.Column] + values = r.values[link.Column] + } + var key interface{} + if keyField != nil { + key = keyField.Interface(xunsafe.AsPointer(child)) + } else { + key = xType.Deref((*values)[r.manyCounter]) + } + keyParts = append(keyParts, io.NormalizeKey(key)) + } + positions, ok := r.parentCompositePositions(relation)[buildCompositeKey(keyParts)] + if !ok { + return nil, fmt.Errorf(`composite key "%v" is not found`, keyParts) + } + if len(positions) > 1 { + return nil, fmt.Errorf(`composite key "%v" has more than one value`, keyParts) + } + return r.parent.slice.ValuePointerAt(destPtr, positions[0]), nil + } var key interface{} var parentPosition int for i, link := range links { @@ -413,6 +586,39 @@ func (r *Collector) visitorMany(relation *Relation) func(value interface{}) erro destPtr := xunsafe.AsPointer(dest) return func(owner interface{}) error { + if relation.IsComposite() { + keyParts := make([]interface{}, 0, len(links)) + for _, link := range links { + keyField := link.xField + if keyField == nil && xType == nil { + xType = r.types[link.Column] + values = r.values[link.Column] + } + var key interface{} + if keyField != nil { + key = keyField.Interface(xunsafe.AsPointer(owner)) + } else { + key = xType.Deref((*values)[r.manyCounter]) + r.manyCounter++ + } + keyParts = append(keyParts, io.NormalizeKey(key)) + } + positions, ok := r.parentCompositePositions(relation)[buildCompositeKey(keyParts)] + if !ok { + return nil + } + for _, index := range positions { + parentItem := r.parent.slice.ValuePointerAt(destPtr, index) + r.Lock().Lock() + sliceAddPtr := holderField.Pointer(xunsafe.AsPointer(parentItem)) + slice := relation.Of.Schema.Slice() + appender := slice.Appender(sliceAddPtr) + appender.Append(owner) + r.Lock().Unlock() + r.view.Logger.ObjectReconciling(dest, owner, parentItem, index) + } + return nil + } var key interface{} for i, link := range links { keyField := link.xField @@ -476,6 +682,13 @@ func (r *Collector) indexParentPositions(ns, name string) { r.parent.indexPositions(ns, name) } +func (r *Collector) indexParentCompositePositions(relation *Relation) { + if r.parent == nil || relation == nil { + return + } + r.parent.indexCompositePositions(relation) +} + func (r *Collector) indexPositions(ns, name string) { values := r.values[name] if values == nil { @@ -508,6 +721,46 @@ func (r *Collector) indexPositions(ns, name string) { } } +func (r *Collector) indexCompositePositions(relation *Relation) { + if relation == nil { + return + } + signature := relationCompositeSignature(relation.On) + index := r.compositeValuePosition[signature] + if index == nil { + index = map[compositeKey][]int{} + r.compositeValuePosition[signature] = index + } + destPtr := xunsafe.AsPointer(r.DestPtr()) + for position := 0; position < r.slice.Len(destPtr); position++ { + parent := r.slice.ValuePointerAt(destPtr, position) + valueSets := make([][]interface{}, 0, len(relation.On)) + for _, link := range relation.On { + if link == nil { + continue + } + if link.xField != nil { + valueSets = append(valueSets, normalizeValues(link.xField.Value(xunsafe.AsPointer(parent)))) + continue + } + values := r.values[link.Column] + if values == nil || position >= len(*values) { + valueSets = nil + break + } + xType := r.types[link.Column] + if xType == nil { + valueSets = nil + break + } + valueSets = append(valueSets, normalizeValues(xType.Deref((*values)[position]))) + } + for _, row := range compositeRows(valueSets) { + index[buildCompositeKey(row)] = append(index[buildCompositeKey(row)], position) + } + } +} + // Relations creates and register new Collector for each Relation present in the Template.Columns if View allows use Template.Columns func (r *Collector) Relations(selector *Statelet) ([]*Collector, error) { result := make([]*Collector, len(r.view.With)) @@ -539,21 +792,22 @@ func (r *Collector) Relations(selector *Statelet) ([]*Collector, error) { return nil, err } result[counter] = &Collector{ - Id: uuid.New().String(), - parent: r, - viewMetaHandler: aHandler, - destValue: destPtr, - dataSync: handler.NewDataSync(), - appender: slice.Appender(xunsafe.ValuePointer(&destPtr)), - valuePosition: make(map[string]map[string]map[interface{}][]int), - types: make(map[string]*xunsafe.Type), - values: make(map[string]*[]interface{}), - slice: slice, - view: &r.view.With[i].Of.View, - relation: r.view.With[i], - readAll: r.view.With[i].Of.MatchStrategy.ReadAll(), - wg: &wg, - wgDelta: delta, + Id: uuid.New().String(), + parent: r, + viewMetaHandler: aHandler, + destValue: destPtr, + dataSync: handler.NewDataSync(), + appender: slice.Appender(xunsafe.ValuePointer(&destPtr)), + valuePosition: make(map[string]map[string]map[interface{}][]int), + compositeValuePosition: make(map[string]map[compositeKey][]int), + types: make(map[string]*xunsafe.Type), + values: make(map[string]*[]interface{}), + slice: slice, + view: &r.view.With[i].Of.View, + relation: r.view.With[i], + readAll: r.view.With[i].Of.MatchStrategy.ReadAll(), + wg: &wg, + wgDelta: delta, } counter++ } @@ -661,6 +915,40 @@ func (r *Collector) MergeData() { func (r *Collector) mergeToParent() { links := r.relation.Of.On + if r.relation.IsComposite() { + destPtr := xunsafe.AsPointer(r.DestPtr()) + holderField := r.relation.holderField + parentSlice := r.parent.slice + parentDestPtr := xunsafe.AsPointer(r.parent.DestPtr()) + valuePositions := r.parentCompositePositions(r.relation) + + for i := 0; i < r.slice.Len(destPtr); i++ { + value := r.slice.ValuePointerAt(destPtr, i) + keyParts := make([]interface{}, 0, len(links)) + for _, link := range links { + keyParts = append(keyParts, io.NormalizeKey(link.xField.Value(xunsafe.AsPointer(value)))) + } + positions, ok := valuePositions[buildCompositeKey(keyParts)] + if !ok { + continue + } + for _, position := range positions { + parentValue := parentSlice.ValuePointerAt(parentDestPtr, position) + if r.relation.Cardinality == state.One { + at := r.slice.ValuePointerAt(destPtr, i) + holderField.SetValue(xunsafe.AsPointer(parentValue), at) + } else if r.relation.Cardinality == state.Many { + r.Lock().Lock() + appender := r.slice.Appender(holderField.ValuePointer(xunsafe.AsPointer(parentValue))) + appender.Append(value) + r.Lock().Unlock() + r.view.Logger.ObjectReconciling(r.Dest(), value, parentValue, position) + } + } + } + return + } + for i, link := range links { valuePositions := r.parentValuesPositions(r.relation.On[i].Namespace, r.relation.On[i].Column) destPtr := xunsafe.AsPointer(r.DestPtr()) @@ -698,12 +986,46 @@ func (r *Collector) mergeToParent() { // that the relation was created from, otherwise empty slice and empty string // i.e. if locators Collector collects Employee{AccountId: int}, Column.Name is account_id and Collector collects Account // it will extract and return all the AccountId that were accumulated and account_id -func (r *Collector) ParentPlaceholders() ([]interface{}, []string) { +func (r *Collector) ParentPlaceholders() ([]interface{}, [][]interface{}, []string) { if r.parent == nil || r.ReadAll() { - return []interface{}{}, nil + return []interface{}{}, nil, nil } destPtr := xunsafe.AsPointer(r.parent.DestPtr()) sliceLen := r.parent.slice.Len(destPtr) + if r.relation.IsComposite() { + result := make([][]interface{}, 0) + unique := map[compositeKey]bool{} + for i := 0; i < sliceLen; i++ { + parent := r.parent.slice.ValuePointerAt(destPtr, i) + valueSets := make([][]interface{}, 0, len(r.relation.On)) + for _, link := range r.relation.On { + field := link.xField + if field != nil { + valueSets = append(valueSets, normalizeValues(field.Value(xunsafe.AsPointer(parent)))) + continue + } + positions := r.parentValuesPositions(link.Namespace, link.Column) + if len(positions) == 0 { + valueSets = nil + break + } + values := make([]interface{}, 0, len(positions)) + for key := range positions { + values = append(values, key) + } + valueSets = append(valueSets, values) + } + for _, row := range compositeRows(valueSets) { + key := buildCompositeKey(row) + if unique[key] { + continue + } + unique[key] = true + result = append(result, row) + } + } + return nil, result, r.relation.Of.On.InColumnExpression() + } result := make([]interface{}, 0) var unique = make(map[any]bool) outer: @@ -772,7 +1094,7 @@ outer: continue outer } } - return result, r.relation.Of.On.InColumnExpression() + return result, nil, r.relation.Of.On.InColumnExpression() } func (r *Collector) WaitIfNeeded() { diff --git a/view/collector_composite_test.go b/view/collector_composite_test.go new file mode 100644 index 00000000..2e49412d --- /dev/null +++ b/view/collector_composite_test.go @@ -0,0 +1,99 @@ +package view + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/datly/view/state" + "github.com/viant/xunsafe" +) + +type compositeParentRow struct { + AdvertiserID int + DmpAdobeValues string + Adobe []*compositeChildRow +} + +type compositeChildRow struct { + AdvertiserID int + DmpAdobeValue string +} + +func TestCollector_ParentPlaceholders_Composite(t *testing.T) { + parentView := &View{Schema: state.NewSchema(reflect.TypeOf([]*compositeParentRow{}))} + parentDest := []*compositeParentRow{ + {AdvertiserID: 101, DmpAdobeValues: "A"}, + {AdvertiserID: 202, DmpAdobeValues: "B"}, + } + parentCollector := NewCollector(parentView.Schema.Slice(), parentView, &parentDest, nil, false) + + relation := &Relation{ + Composite: true, + On: Links{ + &Link{Field: "AdvertiserID", Column: "ADVERTISER_ID", xField: xunsafe.FieldByName(reflect.TypeOf(compositeParentRow{}), "AdvertiserID")}, + &Link{Field: "DmpAdobeValues", Column: "DMP_ADOBE_VALUES", xField: xunsafe.FieldByName(reflect.TypeOf(compositeParentRow{}), "DmpAdobeValues")}, + }, + Of: &ReferenceView{ + On: Links{ + &Link{Field: "AdvertiserID", Column: "ADVERTISER_ID"}, + &Link{Field: "DmpAdobeValue", Column: "DMP_ADOBE_VALUE"}, + }, + }, + } + childCollector := &Collector{parent: parentCollector, relation: relation} + + values, composite, columns := childCollector.ParentPlaceholders() + assert.Nil(t, values) + assert.Equal(t, []string{"ADVERTISER_ID", "DMP_ADOBE_VALUE"}, columns) + assert.Equal(t, [][]interface{}{{101, "A"}, {202, "B"}}, composite) +} + +func TestCollector_MergeToParent_Composite(t *testing.T) { + parentView := &View{Schema: state.NewSchema(reflect.TypeOf([]*compositeParentRow{}))} + parentDest := []*compositeParentRow{ + {AdvertiserID: 101, DmpAdobeValues: "A"}, + {AdvertiserID: 202, DmpAdobeValues: "B"}, + } + parentCollector := NewCollector(parentView.Schema.Slice(), parentView, &parentDest, nil, false) + + childView := &View{ + Schema: state.NewSchema(reflect.TypeOf([]*compositeChildRow{})), + } + relation := &Relation{ + Composite: true, + Cardinality: state.Many, + Holder: "Adobe", + holderField: xunsafe.FieldByName(reflect.TypeOf(compositeParentRow{}), "Adobe"), + On: Links{ + &Link{Field: "AdvertiserID", Column: "ADVERTISER_ID", xField: xunsafe.FieldByName(reflect.TypeOf(compositeParentRow{}), "AdvertiserID")}, + &Link{Field: "DmpAdobeValues", Column: "DMP_ADOBE_VALUES", xField: xunsafe.FieldByName(reflect.TypeOf(compositeParentRow{}), "DmpAdobeValues")}, + }, + Of: &ReferenceView{ + View: View{Schema: state.NewSchema(reflect.TypeOf([]*compositeChildRow{}))}, + On: Links{ + &Link{Field: "AdvertiserID", Column: "ADVERTISER_ID", xField: xunsafe.FieldByName(reflect.TypeOf(compositeChildRow{}), "AdvertiserID")}, + &Link{Field: "DmpAdobeValue", Column: "DMP_ADOBE_VALUE", xField: xunsafe.FieldByName(reflect.TypeOf(compositeChildRow{}), "DmpAdobeValue")}, + }, + }, + } + + childDest := []*compositeChildRow{ + {AdvertiserID: 101, DmpAdobeValue: "A"}, + {AdvertiserID: 202, DmpAdobeValue: "B"}, + {AdvertiserID: 101, DmpAdobeValue: "Z"}, + } + childCollector := NewCollector(childView.Schema.Slice(), childView, &childDest, nil, true) + childCollector.parent = parentCollector + childCollector.relation = relation + childCollector.view = childView + childCollector.slice = childView.Schema.Slice() + + childCollector.mergeToParent() + + require.Len(t, parentDest[0].Adobe, 1) + assert.Equal(t, "A", parentDest[0].Adobe[0].DmpAdobeValue) + require.Len(t, parentDest[1].Adobe, 1) + assert.Equal(t, "B", parentDest[1].Adobe[0].DmpAdobeValue) +} diff --git a/view/column_lookup_test.go b/view/column_lookup_test.go index db598661..3e637711 100644 --- a/view/column_lookup_test.go +++ b/view/column_lookup_test.go @@ -25,3 +25,22 @@ func TestView_ColumnByName_UsesIndexedLookup(t *testing.T) { require.True(t, ok) require.Equal(t, "TAXONOMY_ID", column.Name) } + +func TestView_ColumnByName_UsesUnqualifiedSourceLookup(t *testing.T) { + aView := NewView("comscore", "comscore", + WithConnector(NewConnector("test", "sqlite3", ":memory:")), + WithColumns(Columns{ + &Column{Name: "COMSCORE_CONTEXTUAL_VALUE", DataType: "string", Tag: `source:"t2.SEGMENT_ID"`}, + &Column{Name: "NAME", DataType: "string"}, + }), + ) + require.NoError(t, aView.Init(context.Background(), EmptyResource())) + + column, ok := aView.ColumnByName("t2.SEGMENT_ID") + require.True(t, ok) + require.Equal(t, "COMSCORE_CONTEXTUAL_VALUE", column.Name) + + column, ok = aView.ColumnByName("SEGMENT_ID") + require.True(t, ok) + require.Equal(t, "COMSCORE_CONTEXTUAL_VALUE", column.Name) +} diff --git a/view/columns.go b/view/columns.go index bb9d85dc..0d6e762e 100644 --- a/view/columns.go +++ b/view/columns.go @@ -23,6 +23,9 @@ func (c Columns) Index(formatCase text.CaseFormat) NamedColumns { if aTag := c[i].Tag; aTag != "" { if src := reflect.StructTag(aTag).Get("source"); src != "" { result[strings.ToLower(src)] = c[i] + if index := strings.LastIndex(src, "."); index != -1 && index+1 < len(src) { + result.RegisterWithName(src[index+1:], c[i]) + } } } result.Register(formatCase, c[i]) diff --git a/view/relation.go b/view/relation.go index b2063128..eaeb838d 100644 --- a/view/relation.go +++ b/view/relation.go @@ -22,6 +22,7 @@ type ( Of *ReferenceView `json:",omitempty"` Caser text.CaseFormat `json:",omitempty"` Cardinality state.Cardinality `json:",omitempty"` //IsToOne, or Many + Composite bool `json:",omitempty"` On Links Holder string `json:",omitempty"` //Represents column created due to the merging. In our example it would be Employee#Account IncludeColumn bool `json:",omitempty"` //tells if Column _field should be kept in the struct type. In our example, if set false in produced Employee would be also AccountId _field @@ -267,6 +268,10 @@ func (r *Relation) TagLink() tags.LinkOn { return links } +func (r *Relation) IsComposite() bool { + return r != nil && (r.Composite || len(r.On) > 1) +} + func (l *Link) EncodeLinkTag() string { result := "" if l.Field != "" { diff --git a/view/tags/tag.go b/view/tags/tag.go index 0cb51503..d2541592 100644 --- a/view/tags/tag.go +++ b/view/tags/tag.go @@ -114,7 +114,7 @@ func (t *Tag) UpdateTag(tag reflect.StructTag) reflect.StructTag { if t.View != nil { t.appendTag(t.View, &ret) if t.View.CustomTag != "" { - rawTag = t.View.CustomTag + rawTag = normalizeCustomTag(t.View.CustomTag) } } t.appendTag(t.LinkOn, &ret) @@ -144,6 +144,21 @@ func (t *Tag) UpdateTag(tag reflect.StructTag) reflect.StructTag { return reflect.StructTag(structTag) } +func normalizeCustomTag(tag string) string { + tag = strings.TrimSpace(tag) + tag = strings.Trim(tag, "`") + if strings.HasPrefix(tag, "'") { + tag = tag[1:] + } + if strings.HasSuffix(tag, "'") { + tag = tag[:len(tag)-1] + } + if len(tag) >= 2 && strings.HasPrefix(tag, "\"") && strings.HasSuffix(tag, "\"") { + tag = tag[1 : len(tag)-1] + } + return strings.TrimSpace(tag) +} + func getTagPriority(tag *tags.Tag) int { switch tag.Name { case ParameterTag: diff --git a/view/tags/tag_custom_test.go b/view/tags/tag_custom_test.go new file mode 100644 index 00000000..09b74e96 --- /dev/null +++ b/view/tags/tag_custom_test.go @@ -0,0 +1,21 @@ +package tags + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestTag_UpdateTag_NormalizesCustomTagQuotes(t *testing.T) { + tag := &Tag{ + View: &View{ + Name: "siteType", + CustomTag: `'json:",omitempty"'`, + }, + } + + actual := string(tag.UpdateTag(reflect.StructTag(""))) + require.Contains(t, actual, `json:",omitempty"`) + require.NotContains(t, actual, `'json:",omitempty"`) +} diff --git a/view/template.go b/view/template.go index abf54916..91a2e9ad 100644 --- a/view/template.go +++ b/view/template.go @@ -460,7 +460,13 @@ func (t *Template) replacementEntry(key string, params CriteriaParam, selector * return key, criteriaExpanded, nil case keywords.ColumnsIn[1:]: - *placeholders = append(*placeholders, batchData.ValuesBatch...) + if batchData != nil && batchData.HasComposite() { + for _, row := range batchData.CompositeValuesBatch { + *placeholders = append(*placeholders, row...) + } + } else { + *placeholders = append(*placeholders, batchData.ValuesBatch...) + } return key, params.ColumnsIn, nil case keywords.SelectorCriteria[1:]: *placeholders = append(*placeholders, selector.Placeholders...)